module golem.optimizer;

import golem.tensor;
import golem.nn;

import numir;

import std.meta;

@("train XOR")
unittest
{
    import golem.tensor;
    import golem.math;
    import golem.nn;

    // dataset
    auto inputs = tensor!([0, 2])([
            0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 0.0f, 1.0f, 1.0f,
            ]);
    auto labels = tensor!([0, 1])([0.0f, 1.0f, 1.0f, 0.0f,]);
    inputs.requireGrad = false;
    labels.requireGrad = false;

    // model
    auto fc1 = new Linear!(float, 2, 6);
    auto fc2 = new Linear!(float, 6, 1);

    auto optimizer = createOptimizer!SGD(fc1, fc2);

    auto forward(T)(T x)
    {
        auto h = sigmoid(fc1(x));
        auto o = sigmoid(fc2(h));
        return o;
    }

    // loss
    auto mse(T)(T output, T labels)
    {
        auto t = labels - output;
        auto t2 = t * t;
        auto l = sum(t2);
        return l;
    }

    auto lossFirst = mse(forward(inputs), labels);
    // train
    foreach (_; 0 .. 10)
    {
        auto output = forward(inputs);
        auto loss = mse(output, labels);

        optimizer.resetGrads();

        loss.backward();

        optimizer.trainStep();
    }
    auto lossLast = mse(forward(inputs), labels);

    assert(lossLast.shape == [1]);
    assert(lossLast.value[0] < lossFirst.value[0]);
}

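/// Common interface implemented by the optimizers in this module.
/// `resetGrads` clears the accumulated gradients of all registered parameters,
/// and `trainStep` applies one parameter update using the gradients produced
/// by a preceding `backward` call.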
interface Optimizer
{
    void resetGrads();

    void trainStep();
}

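/// Hyperparameters for `SGD`. With non-zero `momentumRate` and `weightDecay`,
/// each step updates a velocity `diff = momentumRate * diff + grad` and then
/// applies `value -= learningRate * diff + weightDecay * value`.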
struct SGDConfig
{
    float learningRate = 0.01;
    float momentumRate = 0.9;
    float weightDecay = 0;
}

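/// Stochastic gradient descent with optional momentum and weight decay
/// (see `SGDConfig`). One velocity buffer is kept per parameter.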
class SGD(Params...) : Optimizer
{
    SGDConfig config;
    Params params;
    staticMap!(mapValue, Params) diffs;

    this(Params params)
    {
        this.params = params;
        static foreach (i; 0 .. Params.length)
        {
            this.diffs[i] = zeros_like(params[i].value);
        }
    }

    void resetGrads()
    {
        foreach (p; params)
        {
            p.resetGrads();
        }
    }

    void trainStep()
    {
        const learningRate = config.learningRate;
        const momentumRate = config.momentumRate;
        const weightDecay = config.weightDecay;

        if (momentumRate != 0 && weightDecay != 0)
        {
            foreach (i, p; params)
            {
                diffs[i][] = momentumRate * diffs[i][] + p.grads[];
                p.value[] -= learningRate * diffs[i][] + weightDecay * p.value[];
            }
        }
        else if (momentumRate != 0)
        {
            foreach (i, p; params)
            {
                diffs[i][] = momentumRate * diffs[i][] + p.grads[];
                p.value[] -= learningRate * diffs[i][];
            }
        }
        else if (weightDecay != 0)
        {
            foreach (i, p; params)
            {
                p.value[] -= learningRate * p.grads[] + weightDecay * p.value[];
            }
        }
        else
        {
            foreach (i, p; params)
            {
                p.value[] -= learningRate * p.grads[];
            }
        }
    }
}
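
// A minimal usage sketch (illustrative, not part of the original test suite):
// `createOptimizer` returns the concrete `SGD!(...)` instance, so its public
// `config` field can be tuned after construction and before training.
unittest
{
    import golem.nn : Linear;

    auto fc = new Linear!(float, 2, 1);
    auto optimizer = createOptimizer!SGD(fc);

    // plain SGD: larger step size, no momentum, no weight decay
    optimizer.config.learningRate = 0.1f;
    optimizer.config.momentumRate = 0;
    optimizer.config.weightDecay = 0;

    optimizer.resetGrads();
    optimizer.trainStep();
}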
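/// Hyperparameters shared by `Adam` and `AdaBelief`.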
struct AdamConfig
{
    float learningRate = 0.001;
    float beta1 = 0.9;
    float beta2 = 0.999;
    float eps = 1e-8;
    float weightDecay = 0;
}

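/// Adam optimizer. Keeps exponential moving averages of the gradient (`ms`)
/// and of its square (`vs`), corrects both for initialization bias, and
/// applies `value -= learningRate * mbar / sqrt(vbar + eps)`
/// (plus `weightDecay * value` when weight decay is enabled).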
class Adam(Params...) : Optimizer
{
    alias Values = staticMap!(mapValue, Params);

    AdamConfig config;
    Params params;
    Values ms;
    Values vs;
    size_t trainCount;

    this(Params params)
    {
        this.params = params;
        static foreach (i; 0 .. Params.length)
        {
            this.ms[i] = zeros_like(params[i].value);
            this.vs[i] = zeros_like(params[i].value);
        }
    }

    void resetGrads()
    {
        foreach (p; params)
        {
            p.resetGrads();
        }
    }

    void trainStep()
    {
        import core.math : sqrt;
        import mir.ndslice : map;

        ++trainCount;

        const learningRate = config.learningRate;
        const beta1 = config.beta1;
        const beta1_m = 1.0f - beta1;
        const c1 = 1.0f / (1.0f - beta1 ^^ trainCount);
        const beta2 = config.beta2;
        const beta2_m = 1.0f - beta2;
        const c2 = 1.0f / (1.0f - beta2 ^^ trainCount);
        const eps = config.eps;
        const weightDecay = config.weightDecay;

        foreach (i, p; params)
        {
            this.ms[i][] = beta1 * ms[i][] + beta1_m * p.grads[];
            this.vs[i][] = beta2 * vs[i][] + beta2_m * (p.grads[] * p.grads[]);
        }

        if (weightDecay != 0)
        {
            foreach (i, p; params)
            {
                auto mbar = ms[i] * c1;
                auto vbar = vs[i] * c2;

                p.value[] -= learningRate * mbar[] / vbar[].map!(a => sqrt(a + eps)) + weightDecay * p.value[];
            }
        }
        else
        {
            foreach (i, p; params)
            {
                auto mbar = ms[i] * c1;
                auto vbar = vs[i] * c2;

                p.value[] -= learningRate * mbar[] / vbar[].map!(a => sqrt(a + eps));
            }
        }
    }
}
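
// A small sketch (illustrative, not part of the original test suite): every
// call to `trainStep` advances the step counter used for bias correction, and
// the counter is visible on the concrete optimizer instance.
unittest
{
    import golem.nn : Linear;

    auto fc = new Linear!(float, 2, 1);
    auto optimizer = createOptimizer!Adam(fc);

    optimizer.resetGrads();
    optimizer.trainStep();
    optimizer.trainStep();
    assert(optimizer.trainCount == 2);
}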
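/// AdaBelief optimizer. It reuses `AdamConfig` and differs from `Adam` only in
/// the second-moment estimate, which tracks the squared deviation of the
/// gradient from its running mean, `(grad - m) ^^ 2`, instead of `grad ^^ 2`.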
class AdaBelief(Params...) : Optimizer
{
    alias Values = staticMap!(mapValue, Params);

    AdamConfig config;
    Params params;
    Values ms;
    Values vs;
    size_t trainCount;

    this(Params params)
    {
        this.params = params;
        static foreach (i; 0 .. Params.length)
        {
            this.ms[i] = zeros_like(params[i].value);
            this.vs[i] = zeros_like(params[i].value);
        }
    }

    void resetGrads()
    {
        foreach (p; params)
        {
            p.resetGrads();
        }
    }

    void trainStep()
    {
        import core.math : sqrt;
        import mir.ndslice : map;

        ++trainCount;

        const learningRate = config.learningRate;
        const beta1 = config.beta1;
        const beta1_m = 1.0f - beta1;
        const c1 = 1.0f / (1.0f - beta1 ^^ trainCount);
        const beta2 = config.beta2;
        const beta2_m = 1.0f - beta2;
        const c2 = 1.0f / (1.0f - beta2 ^^ trainCount);
        const eps = config.eps;
        const weightDecay = config.weightDecay;

        foreach (i, p; params)
        {
            this.ms[i][] = beta1 * ms[i][] + beta1_m * p.grads[];
            this.vs[i][] = beta2 * vs[i][] + beta2_m * (p.grads[] - ms[i][]) ^^ 2;
        }

        if (weightDecay != 0)
        {
            foreach (i, p; params)
            {
                auto mbar = ms[i] * c1;
                auto vbar = vs[i] * c2;

                p.value[] -= learningRate * mbar[] / vbar[].map!(a => sqrt(a + eps)) + weightDecay * p.value[];
            }
        }
        else
        {
            foreach (i, p; params)
            {
                auto mbar = ms[i] * c1;
                auto vbar = vs[i] * c2;

                p.value[] -= learningRate * mbar[] / vbar[].map!(a => sqrt(a + eps));
            }
        }
    }
}

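/// Builds an optimizer of the given kind (e.g. `SGD`, `Adam`, `AdaBelief`)
/// from any mix of trainable tensors and aggregates exposing a `parameters`
/// alias. Aggregates are expanded recursively, and tensors created with
/// `UseGradient.no` are skipped.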
auto createOptimizer(alias Optimizer, Params...)(Params params) if (Params.length > 0)
{
    import golem.util : staticIndexOf;

    enum firstPos = staticIndexOf!(hasParameters, Params);

    static if (firstPos != -1)
    {
        // dfmt off
        return createOptimizer!Optimizer(
            params[0 .. firstPos],
            params[firstPos].parameters,
            params[firstPos + 1 .. $]
        );
        // dfmt on
    }
    else
    {
        static if (allSatisfy!(isTensor, Params))
        {
            static if (allSatisfy!(canBackward, Params))
            {
                alias OptimizerImpl = Optimizer!(Params);
                return new OptimizerImpl(params);
            }
            else
            {
                // drop the first tensor that cannot backpropagate and recurse
                enum nonTrainablePos = staticIndexOf!(canNotBackward, Params);

                // dfmt off
                return createOptimizer!Optimizer(
                    params[0 .. nonTrainablePos],
                    params[nonTrainablePos + 1 .. $]
                );
                // dfmt on
            }
        }
        else
        {
            static assert(false, "every argument must be a Tensor or expose a `parameters` alias");
        }
    }
}

unittest
{
    class Model
    {
        Tensor!(float, [2, 2]) weight;

        alias parameters = AliasSeq!(weight);

        this()
        {
            weight = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        }
    }

    auto model = new Model;
    auto optimizer = createOptimizer!SGD(model);
    assert(optimizer !is null);

    model.weight.grads[] = 1.0f;
    assert(model.weight.grads == [[1.0f, 1.0f], [1.0f, 1.0f]]);
    optimizer.resetGrads();
    assert(model.weight.grads == [[0.0f, 0.0f], [0.0f, 0.0f]]);
}

unittest
{
    class Model
    {
        Tensor!(float, [2, 2]) weight;

        alias parameters = AliasSeq!(weight);

        this()
        {
            weight = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        }
    }

    auto model = new Model;
    auto optimizer = createOptimizer!Adam(model);
    assert(optimizer !is null);

    model.weight.grads[] = 1.0f;
    assert(model.weight.grads == [[1.0f, 1.0f], [1.0f, 1.0f]]);
    optimizer.resetGrads();
    assert(model.weight.grads == [[0.0f, 0.0f], [0.0f, 0.0f]]);
}

unittest
{
    import golem.nn : Linear;

    auto fc1 = new Linear!(float, 4, 4);
    auto fc2 = new Linear!(float, 4, 2);

    auto optimizer = createOptimizer!SGD(fc1, fc2);
    assert(optimizer !is null);
}

unittest
{
    import golem.nn : Linear;

    auto fc1 = new Linear!(float, 2, 2);
    auto fc2 = new Linear!(float, 2, 1);

    auto optimizer = createOptimizer!Adam(fc1, fc2);
    assert(optimizer !is null);
}

unittest
{
    import golem.nn : Linear;

    auto fc1 = new Linear!(float, 2, 2, UseGradient.no);
    auto fc2 = new Linear!(float, 2, 1);

    auto optimizer = createOptimizer!Adam(fc1, fc2);
    assert(optimizer !is null);
}

unittest
{
    import golem.nn : Linear;

    auto fc1 = new Linear!(float, 2, 2);
    auto fc2 = new Linear!(float, 2, 1, UseGradient.no);

    auto optimizer = createOptimizer!Adam(fc1, fc2);
    assert(optimizer !is null);
}

unittest
{
    import golem.nn : Linear, BatchNorm;

    class Model
    {
        Linear!(float, 2, 2) fc1;
        BatchNorm!(float, [2]) bn1;

        alias parameters = AliasSeq!(fc1, bn1);

        this()
        {
            foreach (ref p; parameters)
                p = new typeof(p);
        }
    }

    auto model = new Model;
    auto optimizer = createOptimizer!SGD(model);
    assert(optimizer !is null);
}

unittest
{
    enum OptimizerKind
    {
        SGD,
        Adam,
        AdaBelief,
    }

    auto fc = new Linear!(float, 2, 1)(0);

    Optimizer optimizer;
    OptimizerKind kind = OptimizerKind.Adam;

    final switch (kind)
    {
    case OptimizerKind.SGD:
        optimizer = createOptimizer!SGD(fc);
        break;
    case OptimizerKind.Adam:
        optimizer = createOptimizer!Adam(fc);
        break;
    case OptimizerKind.AdaBelief:
        optimizer = createOptimizer!AdaBelief(fc);
        break;
    }
}

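/// Maps a tensor type to the type of its `value` storage, used for the
/// optimizer state buffers (`diffs`, `ms`, `vs`).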
private alias mapValue(T) = T.Value;

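/// Negation of `canBackward`, used by `createOptimizer` to skip non-trainable tensors.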
private template canNotBackward(T)
{
    enum canNotBackward = !canBackward!(T);
}