1 module golem.model;
2 
3 import golem.tensor;
4 import golem.nn;
5 import golem.util;
6 
7 import std.meta;
8 
/**
 * Serializes the values of the given parameters into a MessagePack byte array.
 *
 * Arguments may be tensors or any type satisfying `hasParameters`. The latter
 * are expanded recursively, in argument order, through their `parameters`
 * alias until only tensors remain. The output is a MessagePack array with one
 * entry per tensor. Each entry holds that tensor's elements flattened into a
 * one-dimensional array. Shapes are not written, so the reader must supply
 * them (see `unpackParameters`).
 *
 * Params:
 *     params = tensors and/or parameter containers to serialize.
 * Returns: the encoded bytes.
 */
ubyte[] packParameters(Params...)(Params params)
{
    import golem.util : staticIndexOf;

    // Index of the first argument that is a parameter container (-1 if none).
    enum firstPos = staticIndexOf!(hasParameters, Params);

    static if (firstPos != -1)
    {
        // Replace that container by its parameters and recurse until only tensors remain.
        // dfmt off
        return packParameters(
            params[0 .. firstPos],
            params[firstPos].parameters,
            params[firstPos + 1 .. $]
        );
        // dfmt on
    }
    else static if (allSatisfy!(isTensor, Params))
    {
        import msgpack : Packer;
        import mir.ndslice : flattened, ndarray;

        Packer packer;
        packer.beginArray(params.length);
        foreach (p; params)
        {
            // Only the flat element data is stored; the shape is not.
            packer.pack(p.value.flattened[].ndarray());
        }
        return packer.stream.data;
    }
    else
    {
        static assert(false, "packParameters: unsupported argument types " ~ Params.stringof
                ~ "; every argument must be a tensor or a type satisfying hasParameters");
    }
}
46 
/**
 * Restores parameter values previously written by `packParameters`.
 *
 * Arguments are expanded exactly as in `packParameters`: containers
 * satisfying `hasParameters` are replaced by their `parameters`. The
 * destination must therefore list the same tensors in the same order as the
 * packed source. Each tensor's `value` is replaced by the decoded data,
 * reshaped to that tensor's own `shape`. Only the element count has to match,
 * so data may be loaded into a tensor of a different shape with the same size.
 *
 * Params:
 *     data = bytes produced by `packParameters` (typically read from a file).
 *     params = destination tensors and/or parameter containers.
 * Throws: `Exception` if `data` runs out of arrays before every destination
 *     tensor is filled, or if an array's length differs from the destination
 *     tensor's element count. Errors raised by msgpack while decoding are
 *     propagated unchanged.
 */
void unpackParameters(Params...)(ubyte[] data, ref Params params)
{
    import golem.util : staticIndexOf;

    // Index of the first argument that is a parameter container (-1 if none).
    enum firstPos = staticIndexOf!(hasParameters, Params);

    static if (firstPos != -1)
    {
        // Replace that container by its parameters and recurse until only tensors remain.
        // dfmt off
        unpackParameters(
            data,
            params[0 .. firstPos],
            params[firstPos].parameters,
            params[firstPos + 1 .. $]
        );
        // dfmt on
    }
    else static if (allSatisfy!(isTensor, Params))
    {
        import msgpack : unpack;
        import mir.ndslice : sliced;
        import std.exception : enforce;
        import std.format : format;

        auto unpacked = unpack(data);
        foreach (i, p; params)
        {
            // `enforce`, not `assert`: `data` usually comes from disk, and these
            // checks must keep running in -release builds.
            enforce(!unpacked.empty,
                    format!"unpackParameters: packed data ends before tensor %s of %s"(
                        i + 1, params.length));

            auto temp = unpacked.front.as!(typeof(p).ElementType[]);
            enforce(elementSize(p.shape) == temp.length,
                    format!"unpackParameters: tensor %s expects %s elements but packed data has %s"(
                        i + 1, elementSize(p.shape), temp.length));

            p.value = temp.sliced(p.shape);
            unpacked.popFront();
        }
    }
    else
    {
        static assert(false, "unpackParameters: unsupported argument types " ~ Params.stringof
                ~ "; every argument must be a tensor or a type satisfying hasParameters");
    }
}
87 
unittest
{
    // A single tensor survives a pack/unpack round trip unchanged.
    auto source = tensor!([2, 3])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
    auto packed = packParameters(source);

    auto target = tensor!([2, 3])([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    unpackParameters(packed, target);

    assert(source.value == target.value);
}
98 
unittest
{
    // Loading into a tensor with a different element count must not succeed.
    auto source = tensor!([2, 3])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
    auto mismatched = tensor!([2, 2])([0.0, 0.0, 0.0, 0.0]);

    auto packed = packParameters(source);

    bool rejected = false;
    try
    {
        unpackParameters(packed, mismatched);
    }
    catch (Throwable)
    {
        rejected = true;
    }
    assert(rejected);
}
114 
unittest
{
    import mir.ndslice : flattened;

    // Data packed from one shape can be loaded into another shape of equal size.
    auto cube = tensor!([2, 2, 2])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
    auto packed = packParameters(cube);

    auto matrix = tensor!([4, 2])([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    unpackParameters(packed, matrix);

    assert(cube.value.flattened[] == matrix.value.flattened[]);
}
127 
unittest
{
    import golem.nn : Linear;
    import std.meta : AliasSeq;

    // Minimal container exposing two Linear layers through `parameters`.
    class TwoLayerNet
    {
        Linear!(float, 2, 2) fc1;
        Linear!(float, 2, 1) fc2;

        alias parameters = AliasSeq!(fc1, fc2);

        this()
        {
            foreach (ref layer; parameters)
                layer = new typeof(layer);
        }
    }

    // The container is expanded recursively down to each layer's tensors.
    auto original = new TwoLayerNet;
    auto packed = packParameters(original);

    auto restored = new TwoLayerNet;
    unpackParameters(packed, restored);

    assert(original.fc1.weights.value == restored.fc1.weights.value);
    assert(original.fc1.bias.value == restored.fc1.bias.value);
    assert(original.fc2.weights.value == restored.fc2.weights.value);
    assert(original.fc2.bias.value == restored.fc2.bias.value);
}
158 
/**
 * Saves parameter snapshots to timestamped files in a directory and reloads
 * the newest one.
 *
 * Snapshot files are named `<prefix>YYYYMMDD-HHMMSS.dat` (local time, from
 * `Clock.currTime`) inside `dirPath`. They contain the bytes produced by
 * `packParameters`.
 */
class ModelArchiver
{
    /// Directory holding the snapshot files.
    string dirPath;
    /// File-name prefix identifying this archiver's snapshots.
    string prefix;

    this(string dirPath = "model_data", string prefix = "model_")
    {
        this.dirPath = dirPath;
        this.prefix = prefix;
    }

    /**
     * Writes the parameters of `model` to a new snapshot named after the
     * current time, creating `dirPath` first if needed. Names have one-second
     * resolution, so two saves within the same second write to the same file.
     */
    void save(T)(T model)
    {
        static import std.file;

        prepare();
        std.file.write(makeCurrentPath(), packParameters(model));
    }

    /**
     * Loads parameters from the newest snapshot in `dirPath` into `model`.
     * Does nothing when `dirPath` does not exist or holds no matching snapshot.
     */
    void load(T)(T model)
    {
        static import std.file;

        if (!std.file.exists(dirPath))
            return;

        auto recentPath = findRecentModelPath();
        // findRecentModelPath returns an empty string when nothing matches.
        if (recentPath.length > 0 && std.file.exists(recentPath))
            unpackParameters(cast(ubyte[]) std.file.read(recentPath), model);
    }

    /// Creates `dirPath` (including missing parents) if it does not exist.
    protected void prepare()
    {
        import std.file : exists, mkdirRecurse;

        if (!exists(dirPath))
            mkdirRecurse(dirPath);
    }

    /// Path of the snapshot file for the current local time.
    protected string makeCurrentPath()
    {
        import std.path : buildNormalizedPath;
        import std.format : format;
        import std.datetime : Clock, DateTime;

        const DateTime now = cast(DateTime) Clock.currTime;
        const name = format!"%s%04d%02d%02d-%02d%02d%02d.dat"(prefix, now.year,
                now.month, now.day, now.hour, now.minute, now.second);

        return buildNormalizedPath(dirPath, name);
    }

    /**
     * Regex matching snapshot file names for `prefix`.
     * Group 1 captures the date digits (YYYYMMDD), group 2 the time digits (HHMMSS).
     */
    protected auto makePattern()
    {
        import std.regex : escaper, regex;
        import std.conv : to;

        // Escape the prefix so regex metacharacters in it match literally.
        const escapedPrefix = escaper(this.prefix).to!string();
        // `\.dat`: only a literal ".dat" extension may match.
        const pattern = "^" ~ escapedPrefix ~ `(\d{8})-(\d{6})\.dat$`;
        return regex(pattern);
    }

    /**
     * Returns the path of the newest snapshot in `dirPath` (non-recursive),
     * or an empty string if no entry name matches `makePattern()`.
     */
    protected string findRecentModelPath()
    {
        import std.path : baseName;
        import std.file : dirEntries, DirEntry, SpanMode;
        import std.regex : matchFirst;
        import std.typecons : Tuple, tuple;

        string recentPath;
        Tuple!(string, string) latest;

        const pattern = makePattern();
        foreach (DirEntry entry; dirEntries(dirPath, SpanMode.shallow))
        {
            auto m = matchFirst(baseName(entry.name), pattern);
            if (!m)
                continue;

            // (date, time) digit groups are fixed-width, so lexicographic
            // order is chronological order.
            auto stamp = tuple(m[1], m[2]);
            if (recentPath.length == 0 || latest < stamp)
            {
                recentPath = entry.name;
                latest = stamp;
            }
        }

        return recentPath;
    }
}
252 
253 
mixin template NetModule()
{
	// Generated `alias parameters = AliasSeq!(...)` naming every field of the
	// host type whose type satisfies hasParameters.
	mixin(parametersAliasSeqCode!(typeof(this)));

	/// Allocates every parameter member via its no-argument constructor.
	this()
	{
		foreach (ref member; parameters)
			member = new typeof(member);
	}
}
266 
/// Names (as strings) of the fields of `T`, as reported by
/// `std.traits.FieldNameTuple`, whose types satisfy `golem.nn.hasParameters`.
private template AllParameterMembersOf(T)
{
	import std.traits : FieldNameTuple;

	// True when the field called `fieldName` holds a parameter container.
	private template isParameterMember(string fieldName)
	{
		import golem.nn : hasParameters;

		enum isParameterMember = hasParameters!(typeof(__traits(getMember, T.init, fieldName)));
	}

	alias AllParameterMembersOf = Filter!(isParameterMember, FieldNameTuple!T);
}
283 
/**
 * Builds the D source `import std.meta : AliasSeq;
 * alias parameters = AliasSeq!(name1,name2,...);`, listing the fields of `T`
 * selected by `AllParameterMembersOf`. Intended for `mixin` inside `T`.
 *
 * A type with no such fields yields `alias parameters = AliasSeq!();`.
 */
string parametersAliasSeqCode(T)()
{
	import std.array : join;

	// Explicitly typed: for a type without parameter fields the literal is
	// `[]`, which as an inferred enum would be `void[]` and fail to compile.
	string[] names = [AllParameterMembersOf!T];

	return "import std.meta : AliasSeq;\nalias parameters = AliasSeq!(" ~ names.join(",") ~ ");";
}
299 
unittest
{
	static class Test
	{
		Linear!(float, 16, 8) fc1;
		BatchNorm!(float, [8]) bn1;

		// Behaves like writing by hand:
		//     alias parameters = AliasSeq!(fc1, bn1);
		//     this() { foreach (ref p; parameters) p = new typeof(p); }
		mixin NetModule;
	}

	auto net = new Test;
	static assert(net.parameters.length == 2);
	assert(net.parameters[0] == net.fc1);
	assert(net.parameters[1] == net.bn1);
	assert(net.fc1 !is null);
	assert(net.bn1 !is null);
}