1 module golem.nn;
2 
3 import golem.tensor;
4 import golem.random;
5 
6 import mir.ndslice;
7 
8 import std.meta;
9 
/// Evaluates to `true` when `T` exposes a `parameters` member — the
/// convention every layer type in this module follows — i.e. when the
/// expression `T.init.parameters` compiles.
enum hasParameters(T) = __traits(compiles, { auto ps = T.init.parameters; });
11 
12 
/**
Fully connected (affine) layer computing `y = x * weights + bias`.

Params:
    T = element type of the tensors
    InputDim = number of input features
    OutputDim = number of output features
    useGradient = whether `weights`/`bias` track gradients
*/
class Linear(T, size_t InputDim, size_t OutputDim, UseGradient useGradient = UseGradient.yes)
{
    Tensor!(T, [InputDim, OutputDim], useGradient) weights;
    Tensor!(T, [OutputDim], useGradient) bias;

    alias parameters = AliasSeq!(weights, bias);

    /// Initializes both parameters with uniform random values.
    this()
    {
        weights = uniform!(T, [InputDim, OutputDim], useGradient);
        bias = uniform!(T, [OutputDim], useGradient);
    }

    /// Fills weights and bias with the same constant value.
    this(T initial)
    {
        this(initial, initial);
    }

    /// Fills weights and bias with separate constant values.
    this(T initialWeight, T initialBias)
    {
        weights = new Tensor!(T, [InputDim, OutputDim], useGradient)(initialWeight);
        bias = new Tensor!(T, [OutputDim], useGradient)(initialBias);
    }

    /// Applies the affine transform to a batch `x` of shape [N, InputDim].
    auto opCall(size_t[2] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
            if (Shape[1] == InputDim)
    {
        import golem.math : linear;

        return linear(x, this.weights, this.bias);
    }

    static if (useGradient)
    {
        /// Clears accumulated gradients on both parameters.
        void resetGrads()
        {
            weights.resetGrads();
            bias.resetGrads();
        }
    }
}
55 
// Linear accepts a flattened batch and produces a [N, OutputDim] result.
unittest
{
    import golem.math : flatten;

    auto fc1 = new Linear!(float, 2 * 2, 1);

    Tensor!(float, [3, 2, 2]) x = new Tensor!(float, [3, 2, 2])(0.0f);
    Tensor!(float, [3, 4]) b = flatten(x);
    Tensor!(float, [3, 1]) y = fc1(b);
    assert(y.value.shape == [3, 1]);
}
67 
// Linear satisfies the `hasParameters` trait used by optimizers.
unittest
{
    auto fc = new Linear!(float, 2, 1);
    static assert(hasParameters!(typeof(fc)));
}
73 
// The constant-initialization constructors fill weights/bias as specified.
unittest
{
    auto fc0 = new Linear!(float, 2, 1)(-1);
    assert(fc0.weights.value[0, 0] == -1);
    assert(fc0.weights.value[1, 0] == -1);
    assert(fc0.bias.value[0] == -1);

    auto fc1 = new Linear!(float, 2, 1)(0, 1);
    assert(fc1.weights.value[0, 0] == 0);
    assert(fc1.weights.value[1, 0] == 0);
    assert(fc1.bias.value[0] == 1);

    auto fc2 = new Linear!(float, 2, 1)(1, 0);
    assert(fc2.weights.value[0, 0] == 1);
    assert(fc2.weights.value[1, 0] == 1);
    assert(fc2.bias.value[0] == 0);
}
91 
// The output carries gradient support if either the layer or the input
// does; only a no-gradient layer applied to a no-gradient input yields
// a result that cannot backpropagate.
unittest
{
    auto fc1 = new Linear!(float, 2, 1, UseGradient.yes);
    auto fc2 = new Linear!(float, 2, 1, UseGradient.no);

    auto x = new Tensor!(float, [2, 2], UseGradient.yes)(1.0f);
    auto y = new Tensor!(float, [2, 2], UseGradient.no)(1.0f);

    auto a = fc1(x);
    auto b = fc1(y);
    auto c = fc2(x);
    auto d = fc2(y);

    static assert(canBackward!(typeof(a)));
    static assert(canBackward!(typeof(b)));
    static assert(canBackward!(typeof(c)));
    static assert(!canBackward!(typeof(d)));
}
110 
/**
Batch normalization over the leading (batch) dimension.

Keeps exponential running estimates of mean and variance
(`momentum` = 0.9). In training mode with batch size > 1 the batch
statistics are used for normalization and the running estimates are
updated; otherwise the running estimates are used directly.

Params:
    T = element type
    Shape = per-sample feature shape (input trailing dims must match)
    useGrad = whether `factor`/`offset` track gradients
*/
class BatchNorm(T, size_t[] Shape, UseGradient useGrad = UseGradient.yes)
{
    // Running statistics (never receive gradients).
    Tensor!(T, Shape, UseGradient.no) mean;
    Tensor!(T, Shape, UseGradient.no) var;
    // Learnable scale and shift.
    Tensor!(T, Shape, useGrad) factor;
    Tensor!(T, Shape, useGrad) offset;

    // Scratch buffers reused on every call to avoid reallocation.
    Tensor!(T, Shape, UseGradient.no) tempMean;
    Tensor!(T, Shape, UseGradient.no) tempVar;
    Tensor!(T, Shape, UseGradient.no) temps;
    T momentum = 0.9;

    alias parameters = AliasSeq!(mean, var, factor, offset);

    this()
    {
        mean = zeros!(T, Shape);
        var = zeros!(T, Shape);
        factor = ones!(T, Shape, useGrad);
        offset = zeros!(T, Shape, useGrad);
        tempMean = zeros!(T, Shape);
        tempVar = zeros!(T, Shape);
        temps = zeros!(T, Shape);
    }

    /// Normalizes `x` as `(x - mean) * (factor / std) + offset`,
    /// broadcasting the per-feature statistics over the batch axis.
    auto opCall(size_t[] ShapeX, UseGradient useGradX)(Tensor!(T, ShapeX, useGradX) x, bool isTrain)
    {
        static assert(ShapeX[1 .. $] == Shape);

        assert(x.shape[0] != 0);

        import std.math : sqrt;
        import golem.math : broadcastOp;

        enum eps = 1e-7;

        if (isTrain && x.shape[0] > 1)
        {
            import mir.math.sum : mirsum = sum;
            import mir.ndslice : transposed;
            import golem.util : expandIndex;

            // Batch mean, accumulated into the scratch buffer.
            auto batchMean = tempMean.value;
            batchMean.flattened[] = 0;
            foreach (sample; x.value.ipack!1)
            {
                batchMean[] += sample[];
            }
            batchMean[] /= x.shape[0];

            // Batch (biased) variance, accumulated into the scratch buffer.
            auto batchVar = tempVar.value;
            batchVar.flattened[] = 0;
            foreach (sample; x.value.ipack!1)
            {
                batchVar[] += (sample[] - batchMean[]).map!(a => a * a);
            }
            batchVar[] /= x.shape[0];

            // Exponential moving-average update of the running statistics.
            this.mean.value[] = momentum * this.mean.value[] + (1 - momentum) * batchMean[];
            this.var.value[] = momentum * this.var.value[] + (1 - momentum) * batchVar[];

            // NOTE(review): `temps` is refreshed here but unused by the
            // return expression below; kept to preserve observable state.
            this.temps.value[] = this.var.value.map!(a => sqrt(a + eps));
            // Turn the batch variance into a standard deviation in place.
            tempVar.value[] = tempVar.value.map!(a => sqrt(a + eps));
            return broadcastOp!"+"(broadcastOp!"*"(broadcastOp!"-"(x, this.tempMean), factor / this.tempVar), offset);
        }

        // Inference path: normalize with the running estimates.
        this.temps.value[] = this.var.value.map!(a => sqrt(a + eps));
        return broadcastOp!"+"(broadcastOp!"*"(broadcastOp!"-"(x, this.mean), factor / this.temps), offset);
    }
}
181 
// Training-mode forward pass: running mean/var take 10% of the batch
// statistics (momentum 0.9 from zero init), while the output is
// normalized with the full batch statistics.
unittest
{
    auto x = tensor!([0, 2, 2])([
        1.0f, 2.0f, 3.0f, 4.0f,
        2.0f, 3.0f, 4.0f, 5.0f,
        3.0f, 4.0f, 5.0f, 6.0f,
        4.0f, 5.0f, 6.0f, 7.0f,
    ]);

    auto bn = new BatchNorm!(float, [2, 2]);
    auto y = bn(x, true);

    import std.math : isClose;

    assert(bn.mean.value[0, 0].isClose(0.25f));
    assert(bn.mean.value[0, 1].isClose(0.35f));
    assert(bn.mean.value[1, 0].isClose(0.45f));
    assert(bn.mean.value[1, 1].isClose(0.55f));
    
    assert(bn.var.value[0, 0].isClose(0.125f));
    assert(bn.var.value[0, 1].isClose(0.125f));
    assert(bn.var.value[1, 0].isClose(0.125f));
    assert(bn.var.value[1, 1].isClose(0.125f));

    import std.math : sqrt;
    import std.conv : text;

    assert(y.value[0, 0, 0].isClose((1.0f - 2.5f) / sqrt(1.25f)), text(y.value[0, 0, 0]));
    assert(y.value[0, 0, 1].isClose((2.0f - 3.5f) / sqrt(1.25f)), text(y.value[0, 0, 1]));
    assert(y.value[0, 1, 0].isClose((3.0f - 4.5f) / sqrt(1.25f)), text(y.value[0, 1, 0]));
    assert(y.value[0, 1, 1].isClose((4.0f - 5.5f) / sqrt(1.25f)), text(y.value[0, 1, 1]));
}
214 
// Gradients flow through the batch-normalized output back to the input.
unittest
{
    auto x = tensor!([0, 2, 2])([
        1.0f, 2.0f, 3.0f, 4.0f,
        2.0f, 3.0f, 4.0f, 5.0f,
    ]);

    auto bn = new BatchNorm!(float, [2, 2]);
    auto y = bn(x, true);

    import std.conv : text;

    // Gradients start at zero before backward() runs.
    assert(x.grads[] == [[[0.0f, 0.0f], [0.0f, 0.0f]], [[0.0f, 0.0f], [0.0f, 0.0f]]], text(x.grads));
    y.backward();

    import std.math : isClose;

    assert(x.grads[0, 0, 0].isClose(2.0f));
    assert(x.grads[0, 0, 1].isClose(2.0f));
    assert(x.grads[0, 1, 0].isClose(2.0f));
    assert(x.grads[0, 1, 1].isClose(2.0f));
    assert(x.grads[1, 0, 0].isClose(2.0f));
    assert(x.grads[1, 0, 1].isClose(2.0f));
    assert(x.grads[1, 1, 0].isClose(2.0f));
    assert(x.grads[1, 1, 1].isClose(2.0f));
}
241 
/// Wraps a unary activation function as a callable, stateless layer so
/// it can be listed in a `Sequence` alongside parameterized layers.
struct Activation(alias f)
{
    import std.functional : unaryFun;

    alias fun = unaryFun!f;

    /// Forwards `input` to the wrapped function unchanged.
    auto opCall(T)(T input)
    {
        return fun(input);
    }
}
253 
// Activation-wrapped functions produce the same values as direct calls.
unittest
{
    import golem.math : sigmoid, tanh;

    Activation!sigmoid f1;
    Activation!tanh f2;

    auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    
    auto a = f1(x);
    auto b = f2(x);
    auto c = sigmoid(x);
    auto d = tanh(x);

    assert(a.value == c.value);
    assert(b.value == d.value);
}
271 
272 
/**
Composes a fixed list of layer types; calling the sequence applies each
layer in order. Only layers exposing `parameters` contribute to the
sequence's own `parameters`, so a sequence of stateless layers has none.
*/
class Sequence(Ts...)
{
    Ts layers;

    private alias isNetModule(alias m) = hasParameters!(typeof(m));

    static if (Filter!(isNetModule, AliasSeq!(layers)).length > 0)
    {
        alias parameters = Filter!(isNetModule, AliasSeq!(layers));

        this()
        {
            // Instantiate every parameterized layer; stateless layers
            // (e.g. Activation) need no construction.
            foreach (ref p; parameters)
                p = new typeof(p);
        }
    }


    /// Applies all layers, first to last, to `x`.
    auto opCall(T)(T x)
    {
        return opCall!0(x);
    }

    // Compile-time recursion over the layer index.
    private auto opCall(size_t i, T)(T x)
    {
        static if (i == Ts.length)
        {
            return x;
        }
        else
        {
            return opCall!(i + 1)(layers[i](x));
        }
    }
}
308 
// A sequence containing Linear layers exposes `parameters` and runs end to end.
unittest
{
    import golem.math : sigmoid;

    auto net = new Sequence!(
        Linear!(float, 2, 2),
        Activation!sigmoid,
        Linear!(float, 2, 2),
        Activation!sigmoid,
        Linear!(float, 2, 1),
        Activation!sigmoid,
    );

    static assert(hasParameters!(typeof(net)));

    auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto y = net(x);
}
327 
// A sequence of only stateless layers exposes no `parameters`.
unittest
{
    import golem.math : sigmoid;

    auto net = new Sequence!(
        Activation!sigmoid,
        Activation!sigmoid,
        Activation!sigmoid,
    );

    static assert(!hasParameters!(typeof(net)));
}
340 
/**
2D pooling layer based on the lifting scheme: the input is split into
even/odd samples along width and then height, a learnable `predict`
matrix estimates the odd samples from the even ones, and a learnable
`update` matrix smooths the even samples with the prediction residual.
All four matrices are initialized to the Haar wavelet (identity
predict, 0.5 * identity update).

Params:
    T = element type
    Height = input height (halved by the transform)
    Width = input width (halved by the transform)
    useGradient = whether the lifting matrices track gradients
*/
class LiftPool2D(T, size_t Height, size_t Width, UseGradient useGradient = UseGradient.yes)
{
    enum HalfH = Height / 2;
    enum HalfW = Width / 2;

    Tensor!(T, [HalfW, HalfW], useGradient) predictW;
    Tensor!(T, [HalfW, HalfW], useGradient) updateW;
    Tensor!(T, [HalfH, HalfH], useGradient) predictH;
    Tensor!(T, [HalfH, HalfH], useGradient) updateH;

    alias parameters = AliasSeq!(predictW, updateW, predictH, updateH);
    this()
    {
        import mir.ndslice : diagonal;

        // Haar wavelet
        predictW = zeros!(T, [HalfW, HalfW], useGradient)();
        predictW.value.diagonal[] = T(1);
        updateW = zeros!(T, [HalfW, HalfW], useGradient)();
        updateW.value.diagonal[] = T(0.5);

        predictH = zeros!(T, [HalfH, HalfH], useGradient)();
        predictH.value.diagonal[] = T(1);
        updateH = zeros!(T, [HalfH, HalfH], useGradient)();
        updateH.value.diagonal[] = T(0.5);
    }

    /// Applies one lifting step along width (axis 3) and then along
    /// height (axis 2) of a [N, C, Height, Width] input.
    ///
    /// Returns: a tuple of (pooled output, (odd-width samples, width
    /// prediction), (odd-height samples, height prediction)).
    /// NOTE(review): the trailing tuples presumably exist so callers
    /// can build auxiliary losses on the intermediates — confirm usage.
    auto liftUp(U)(U x)
    if (isTensor!U && U.staticShape.length == 4 && U.staticShape[2] == Height && U.staticShape[3] == Width)
    {
        import std.typecons : tuple;
        import golem.math : splitEvenOdd2D, concat2D, projection1D;

        // Width direction: split, predict odds from evens, update evens.
        auto xw = splitEvenOdd2D!3(x);
        auto xw_predict = projection1D!3(xw[0], predictW);
        auto xw_d = xw[1] - xw_predict;
        auto xw_c = xw[0] + projection1D!3(xw_d, updateW);
        auto hidden = concat2D(xw_c, xw_d);

        // Height direction: same lifting step on the intermediate result.
        auto xh = splitEvenOdd2D!2(hidden);
        auto xh_predict = projection1D!2(xh[0], predictH);
        auto xh_d = xh[1] - xh_predict;
        auto xh_c = xh[0] + projection1D!2(xh_d, updateH);
        auto output = concat2D(xh_c, xh_d);

        return tuple(output, tuple(xw[1], xw_predict), tuple(xh[1], xh_predict));
    }

}
390 
// With the Haar initialization, lifting a 4x4 ramp yields the expected
// averages in channel 0 and constant detail/residual channels; the
// result also supports backward().
unittest
{
    auto lift = new LiftPool2D!(double, 4, 4);
    auto images = tensor!([1, 1, 4, 4])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);

    auto y = lift.liftUp(images);

    assert(y[0].shape == [1, 4, 2, 2]);

    assert(y[0].value[0, 0, 0, 0] == 3.5);
    assert(y[0].value[0, 0, 0, 1] == 5.5);
    assert(y[0].value[0, 0, 1, 0] == 11.5);
    assert(y[0].value[0, 0, 1, 1] == 13.5);
    assert(y[0].value[0, 1, 0, 0] == 1);
    assert(y[0].value[0, 1, 0, 1] == 1);
    assert(y[0].value[0, 1, 1, 0] == 1);
    assert(y[0].value[0, 1, 1, 1] == 1);
    assert(y[0].value[0, 2, 0, 0] == 4);
    assert(y[0].value[0, 2, 0, 1] == 4);
    assert(y[0].value[0, 2, 1, 0] == 4);
    assert(y[0].value[0, 2, 1, 1] == 4);
    assert(y[0].value[0, 3, 0, 0] == 0);
    assert(y[0].value[0, 3, 0, 1] == 0);
    assert(y[0].value[0, 3, 1, 0] == 0);
    assert(y[0].value[0, 3, 1, 1] == 0);

    y[0].backward();
}
419 
420 
/// 2D convolution layer with no padding ("valid" convolution).
class Conv2D(T, size_t C_in, size_t C_out, size_t[] kernelSize, UseGradient useGrad = UseGradient.yes)
{
    mixin Conv2DImpl!(T, C_in, C_out, kernelSize, [0, 0], useGrad);
}
425 
/// 2D convolution layer with explicit zero-padding `[padH, padW]`.
class Conv2D(T, size_t C_in, size_t C_out, size_t[] kernelSize, size_t[] padding, UseGradient useGrad = UseGradient.yes)
{
    mixin Conv2DImpl!(T, C_in, C_out, kernelSize, padding, useGrad);
}
430 
// A 3x3 valid convolution shrinks each spatial dim by 2 (28 -> 26).
unittest
{
    import golem.random : uniform;

    auto images = uniform!(float, [1, 1, 28, 28]);
    auto conv1 = new Conv2D!(float, 1, 2, [3, 3]);
    auto y = conv1(images);
    assert(y.shape == [1, 2, 26, 26]);
    y.backward();
}
441 
// A 3x3 convolution with [1, 1] padding preserves the spatial dims.
unittest
{
    import golem.random : uniform;

    auto images = uniform!(float, [1, 1, 28, 28]);
    auto conv1 = new Conv2D!(float, 1, 2, [3, 3], [1, 1]);
    auto y = conv1(images);
    assert(y.shape == [1, 2, 28, 28]);
    y.backward();
}
452 
/// Shared implementation for both `Conv2D` layer variants.
///
/// Params:
///     T = element type
///     C_in = input channel count
///     C_out = output channel count
///     kernelSize = [height, width] of the convolution kernel
///     padding = [height, width] zero-padding applied to the input
///     useGrad = whether `weights`/`bias` track gradients
private mixin template Conv2DImpl(T, size_t C_in, size_t C_out, size_t[] kernelSize, size_t[] padding, UseGradient useGrad)
{
    enum size_t[] WeightShape = [C_out, C_in, kernelSize[0], kernelSize[1]];
    enum size_t[] BiasShape = [C_out];

    Tensor!(T, WeightShape, useGrad) weights;
    Tensor!(T, BiasShape, useGrad) bias;

    alias parameters = AliasSeq!(weights, bias);

    /// Initializes weights and bias with uniform random values.
    this()
    {
        // Fixed: removed an unused `import std.math : sqrt;` — dead
        // code, apparently left over from an abandoned scaled init.
        import golem.random : uniform;

        weights = uniform!(T, WeightShape, useGrad)();
        bias = uniform!(T, BiasShape, useGrad)();
    }

    /// Applies the convolution to a batch of shape [N, C_in, H, W].
    auto opCall(U)(U x)
    if (isTensor!U && U.staticShape.length == 4 && U.staticShape[1] == C_in)
    {
        import golem.math : conv2D;

        return conv2D!(padding)(x, weights, bias);
    }
}
479 
480 
/// Two-layer feed-forward network: `fc2(activateFn(fc1(x)))`.
///
/// Params:
///     T = element type
///     activateFn = activation applied after the hidden layer
///     InputDim = input feature count
///     HiddenDim = hidden layer width
///     OutputDim = output feature count
///     useGrad = whether both layers' parameters track gradients
class Perceptron(T, alias activateFn, size_t InputDim, size_t HiddenDim, size_t OutputDim, UseGradient useGrad = UseGradient.yes)
{
    Linear!(T, InputDim, HiddenDim, useGrad) fc1;
    Linear!(T, HiddenDim, OutputDim, useGrad) fc2;

    alias parameters = AliasSeq!(fc1, fc2);

    this()
    {
        fc1 = new typeof(fc1);
        fc2 = new typeof(fc2);
    }

    /// Forward pass: hidden activation followed by the output layer.
    auto opCall(U)(U x)
    {
        auto hidden = activateFn(fc1(x));
        return fc2(hidden);
    }
}
499 
// A gradient-tracking Perceptron produces [N, OutputDim] results for
// both gradient and no-gradient inputs.
unittest
{
    import golem.math : sigmoid;

    auto model = new Perceptron!(float, sigmoid, 2, 2, 1);
    
    auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto y = model(x);
    static assert(isTensor!(typeof(y), float, [0, 1]));

    auto z = tensor!([0, 2], UseGradient.no)([1.0f, 2.0f, 3.0f, 4.0f]);
    auto w = model(z);
    static assert(isTensor!(typeof(w), float, [0, 1]));
}
514 
// A no-gradient Perceptron applied to a no-gradient input yields a
// result without gradient support.
unittest
{
    import golem.math : sigmoid;

    auto model = new Perceptron!(float, sigmoid, 2, 2, 1, UseGradient.no);

    auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto y = model(x);
    static assert(isTensor!(typeof(y), float, [0, 1]));

    auto z = tensor!([0, 2], UseGradient.no)([1.0f, 2.0f, 3.0f, 4.0f]);
    auto w = model(z);
    static assert(isTensor!(typeof(w), float, [0, 1], UseGradient.no));
}