module golem.math;

import golem.tensor;
import golem.util;

import mir.ndslice;

import std.typecons : No, tuple;

version (all) // exp
{
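    /++
    Element-wise exponential: `y = exp(x)`.

    Because d/dx exp(x) = exp(x), the backward pass simply multiplies the
    incoming gradient by the cached forward result `y`.
    +/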
    Tensor!(T, Shape, useGradient) exp(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdexp = exp;

        auto y = slice(x.value.map!(a => stdexp(a)));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new Tensor!(T, Shape)(y, (Slice!(T*, Shape.length) grads) {
                x.backward(y * grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([-1.0f, 1.0f]);
        auto y = exp(x);

        import std.math : stdexp = exp, isClose;

        assert(y.value[0].isClose(stdexp(-1.0f)));
        assert(y.value[1].isClose(stdexp(1.0f)));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(y.value[0]), "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(y.value[1]), "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2, 2, 2])([
                0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f
                ]);

        auto y = flatten(x);

        assert(y.shape == [2, 4]);

        auto z = exp(y);
        z.backward();

        int err;
        assert(x.grads == z.value.reshape([2, 2, 2], err));
    }

    unittest
    {
        auto x = tensor!([2, 2], No.gradient)([1.0, 2.0, 3.0, 4.0]);
        auto y = exp(x);

        import std.math : stdexp = exp, isClose;

        assert(y.value[0, 0].isClose(stdexp(1.0f)));
        assert(y.value[0, 1].isClose(stdexp(2.0f)));
        assert(y.value[1, 0].isClose(stdexp(3.0f)));
        assert(y.value[1, 1].isClose(stdexp(4.0f)));
    }
}

version (all) // log
{
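    /++
    Element-wise natural logarithm: `y = log(x)`.

    Because d/dx log(x) = 1/x, the backward pass accumulates `grads / x`
    into the gradient of `x`.
    +/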
    Tensor!(T, Shape, useGradient) log(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdlog = log;
        import mir.ndslice : slice, map;

        auto y = slice(x.value.map!(a => T(stdlog(a))));

        static if (useGradient)
        {
            x.usedCount++;

            alias Return = typeof(return);
            alias Value = Return.Value;
            return new Return(y, (Value grads) {
                x.backward((ref xGrads) { xGrads[] += grads[] / x.value[]; });
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([0, 2])([
            [1.0, 2.0],
            [3.0, 4.0],
        ]);
        auto y = log(x);

        import std.math : stdlog = log, isClose;

        assert(y.value[0, 0].isClose(stdlog(1.0)));
        assert(y.value[0, 1].isClose(stdlog(2.0)));
        assert(y.value[1, 0].isClose(stdlog(3.0)));
        assert(y.value[1, 1].isClose(stdlog(4.0)));

        y.backward();

        assert(x.grads[0, 0].isClose(1.0 / 1.0));
        assert(x.grads[0, 1].isClose(1.0 / 2.0));
        assert(x.grads[1, 0].isClose(1.0 / 3.0));
        assert(x.grads[1, 1].isClose(1.0 / 4.0));
    }
}

version (all) // sigmoid
{
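    /++
    Element-wise logistic sigmoid: `y = 1 / (1 + exp(-x))`.

    The derivative can be written in terms of the output,
    d/dx sigmoid(x) = y * (1 - y), so the backward pass reuses `y`.
    +/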
    Tensor!(T, Shape, useGradient) sigmoid(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : exp;

        auto y = x.value.map!(a => T(1) / (T(1) + exp(-a))).slice();

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new Tensor!(T, Shape)(y, (Value grads) {
                x.backward(y * (T(1) - y) * grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    unittest
    {
        auto x = tensor!([3, 1])([-1.0f, 0.0f, 1.0f]);
        auto y = sigmoid(x);

        import std.format : format;
        import std.math : exp, isClose;

        assert(y.value[0, 0].isClose(1.0f / (1.0f + exp(+1.0f))), "%s".format(y.value));
        assert(y.value[1, 0].isClose(1.0f / (1.0f + exp(0.0f))), "%s".format(y.value));
        assert(y.value[2, 0].isClose(1.0f / (1.0f + exp(-1.0f))), "%s".format(y.value));

        y.backward();

        assert(x.grads[0, 0].isClose(y.value[0, 0] * (1.0 - y.value[0, 0])),
                "%s".format(x.grads));
        assert(x.grads[1, 0].isClose(y.value[1, 0] * (1.0 - y.value[1, 0])),
                "%s".format(x.grads));
        assert(x.grads[2, 0].isClose(y.value[2, 0] * (1.0 - y.value[2, 0])),
                "%s".format(x.grads));
    }

    unittest
    {
        auto x = tensor!([3, 1], No.gradient)([-1.0f, 0.0f, 1.0f]);
        auto y = sigmoid(x);

        import std.format : format;
        import std.math : exp, isClose;

        assert(y.value[0, 0].isClose(1.0f / (1.0f + exp(+1.0f))), "%s".format(y.value));
        assert(y.value[1, 0].isClose(1.0f / (1.0f + exp(0.0f))), "%s".format(y.value));
        assert(y.value[2, 0].isClose(1.0f / (1.0f + exp(-1.0f))), "%s".format(y.value));
    }
}

version (all) // tanh
{
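    /++
    Element-wise hyperbolic tangent: `y = tanh(x)`.

    The backward pass uses d/dx tanh(x) = 1 - tanh(x)^2, reusing the
    forward result `y`.
    +/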
    Tensor!(T, Shape, useGradient) tanh(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdtanh = tanh;

        auto y = slice(x.value.map!(a => stdtanh(a)));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Slice!(T*, Shape.length) grads) {
                x.backward((1 - y * y) * grads);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([-1.0f, 1.0f]);
        auto y = tanh(x);

        import std.math : stdtanh = tanh, isClose;

        assert(y.value[0].isClose(stdtanh(-1.0f)));
        assert(y.value[1].isClose(stdtanh(1.0f)));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(1 - y.value[0] ^^ 2),
                "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(1 - y.value[1] ^^ 2),
                "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2], No.gradient)([-1.0f, 1.0f]);
        auto y = tanh(x);

        import std.math : stdtanh = tanh, isClose;

        assert(y.value[0].isClose(stdtanh(-1.0f)));
        assert(y.value[1].isClose(stdtanh(1.0f)));
    }
}

version (all) // sinh
{
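    /++
    Element-wise hyperbolic sine: `y = sinh(x)`.

    The backward pass uses d/dx sinh(x) = cosh(x), recomputed from the
    input values.
    +/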
    Tensor!(T, Shape, useGradient) sinh(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdsinh = sinh, stdcosh = cosh;

        auto y = slice(x.value.map!(a => stdsinh(a)));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Slice!(T*, Shape.length) grads) {
                x.backward(x.value.map!stdcosh * grads);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([-1.0f, 1.0f]);
        auto y = sinh(x);

        import std.math : stdsinh = sinh, stdcosh = cosh, isClose;

        assert(y.value[0].isClose(stdsinh(-1.0f)));
        assert(y.value[1].isClose(stdsinh(1.0f)));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(stdcosh(-1.0f)),
                "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(stdcosh(1.0f)),
                "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2], No.gradient)([-1.0f, 1.0f]);
        auto y = sinh(x);

        import std.math : stdsinh = sinh, isClose;

        assert(y.value[0].isClose(stdsinh(-1.0f)));
        assert(y.value[1].isClose(stdsinh(1.0f)));
    }
}

version (all) // asinh
{
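    /++
    Element-wise inverse hyperbolic sine: `y = asinh(x)`.

    The backward pass uses d/dx asinh(x) = 1 / sqrt(x^2 + 1).
    +/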
    Tensor!(T, Shape, useGradient) asinh(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdasinh = asinh, stdsqrt = sqrt;

        auto y = slice(x.value.map!(a => stdasinh(a)));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Slice!(T*, Shape.length) grads) {
                x.backward(x.value.map!(a => T(1) / stdsqrt(a * a + T(1))) * grads);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([-1.0f, 1.0f]);
        auto y = asinh(x);

        import std.math : stdasinh = asinh, stdsqrt = sqrt, isClose;

        assert(y.value[0].isClose(stdasinh(-1.0f)));
        assert(y.value[1].isClose(stdasinh(1.0f)));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(1 / stdsqrt(-1.0f * -1.0f + 1)),
                "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(1 / stdsqrt(1.0f * 1.0f + 1)),
                "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2], No.gradient)([-1.0f, 1.0f]);
        auto y = asinh(x);

        import std.math : stdasinh = asinh, isClose;

        assert(y.value[0].isClose(stdasinh(-1.0f)));
        assert(y.value[1].isClose(stdasinh(1.0f)));
    }
}

version (all) // cosh
{
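    /++
    Element-wise hyperbolic cosine: `y = cosh(x)`.

    The backward pass uses d/dx cosh(x) = sinh(x).
    +/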
    Tensor!(T, Shape, useGradient) cosh(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdcosh = cosh, stdsinh = sinh;

        auto y = slice(x.value.map!(a => stdcosh(a)));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Slice!(T*, Shape.length) grads) {
                x.backward(x.value.map!stdsinh * grads);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([-1.0f, 1.0f]);
        auto y = cosh(x);

        import std.math : stdcosh = cosh, stdsinh = sinh, isClose;

        assert(y.value[0].isClose(stdcosh(-1.0f)));
        assert(y.value[1].isClose(stdcosh(1.0f)));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(stdsinh(-1.0f)),
                "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(stdsinh(1.0f)),
                "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2], No.gradient)([-1.0f, 1.0f]);
        auto y = cosh(x);

        import std.math : stdcosh = cosh, isClose;

        assert(y.value[0].isClose(stdcosh(-1.0f)));
        assert(y.value[1].isClose(stdcosh(1.0f)));
    }
}

version (all) // acosh
{
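    /++
    Element-wise inverse hyperbolic cosine: `y = acosh(x)`, defined for
    x >= 1.

    The backward pass uses d/dx acosh(x) = 1 / (sqrt(x - 1) * sqrt(x + 1)).
    +/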
    Tensor!(T, Shape, useGradient) acosh(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdacosh = acosh, stdsqrt = sqrt;

        auto y = slice(x.value.map!(a => stdacosh(a)));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Slice!(T*, Shape.length) grads) {
                x.backward(x.value.map!(a => 1 / (stdsqrt(a - 1) * stdsqrt(a + 1))) * grads);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([2.0f, 3.0f]);
        auto y = acosh(x);

        import std.math : stdacosh = acosh, stdsqrt = sqrt, isClose;

        assert(y.value[0].isClose(stdacosh(2.0f)));
        assert(y.value[1].isClose(stdacosh(3.0f)));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(1 / (stdsqrt(2.0f - 1) * stdsqrt(2.0f + 1))),
                "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(1 / (stdsqrt(3.0f - 1) * stdsqrt(3.0f + 1))),
                "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2], No.gradient)([2.0f, 3.0f]);
        auto y = acosh(x);

        import std.math : stdacosh = acosh, isClose;

        assert(y.value[0].isClose(stdacosh(2.0f)));
        assert(y.value[1].isClose(stdacosh(3.0f)));
    }
}

version (all) // softplus
{
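    /++
    Element-wise softplus: `y = log(1 + exp(x))`, a smooth approximation of
    relu.

    The backward pass uses d/dx softplus(x) = exp(x) / (exp(x) + 1), i.e.
    the logistic sigmoid of `x`.
    +/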
    Tensor!(T, Shape, useGradient) softplus(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.math : stdlog = log, stdexp = exp;

        auto y = slice(x.value.map!(a => T(stdlog(1 + stdexp(a)))));

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Slice!(T*, Shape.length) grads) {
                x.backward(x.value.map!(a => (stdexp(a) / (stdexp(a) + 1))) * grads);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2])([-1.0f, 1.0f]);
        auto y = softplus(x);

        import std.math : stdlog = log, stdexp = exp, isClose;

        assert(y.value[0].isClose(stdlog(1 + stdexp(-1.0f))));
        assert(y.value[1].isClose(stdlog(1 + stdexp(1.0f))));

        y.resetGrads();
        y.backward();

        import std : format;

        assert(x.grads[0].isClose(stdexp(-1.0f) / (stdexp(-1.0f) + 1)),
                "%s : %s".format(x.grads[0], y.value[0]));
        assert(x.grads[1].isClose(stdexp(1.0f) / (stdexp(1.0f) + 1)),
                "%s : %s".format(x.grads[1], y.value[1]));
    }

    unittest
    {
        auto x = tensor!([2], No.gradient)([-1.0f, 1.0f]);
        auto y = softplus(x);

        import std.math : stdlog = log, stdexp = exp, isClose;

        assert(y.value[0].isClose(stdlog(1 + stdexp(-1.0f))));
        assert(y.value[1].isClose(stdlog(1 + stdexp(1.0f))));
    }
}

version (all) // relu
{
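    /++
    Element-wise rectified linear unit: `y = max(0, x)`.

    The backward pass propagates the gradient only where `x > 0`; the
    subgradient at x = 0 is taken to be 0.
    +/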
    Tensor!(T, Shape, useGradient) relu(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        import std.algorithm : max;

        auto y = slice(x.value.map!(a => max(T(0), a)));

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Value grad) {
                x.backward(grad * x.value.map!(a => T(a > 0 ? 1 : 0)));
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2, 3])([-1.0, 0.0, 1.0, -2.0, 0.0, 2.0]);

        auto y = relu(x);

        y.backward();

        assert(y.value[0, 0] == 0);
        assert(y.value[0, 1] == 0);
        assert(y.value[0, 2] == 1.0);
        assert(y.value[1, 0] == 0);
        assert(y.value[1, 1] == 0);
        assert(y.value[1, 2] == 2.0);

        assert(x.grads[0, 0] == 0);
        assert(x.grads[0, 1] == 0);
        assert(x.grads[0, 2] == 1.0);
        assert(x.grads[1, 0] == 0);
        assert(x.grads[1, 1] == 0);
        assert(x.grads[1, 2] == 1.0);
    }

    unittest
    {
        auto x = tensor!([2, 3], No.gradient)([-1.0, 0.0, 1.0, -2.0, 0.0, 2.0]);
        auto y = relu(x);

        assert(y.value[0, 0] == 0);
        assert(y.value[0, 1] == 0);
        assert(y.value[0, 2] == 1.0);
        assert(y.value[1, 0] == 0);
        assert(y.value[1, 1] == 0);
        assert(y.value[1, 2] == 2.0);
    }
}

version (all) // leakyRelu
{
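    /++
    Element-wise leaky rectified linear unit: `y = x` for `x > 0`, otherwise
    `y = a * x` (default slope `a = 0.01`).

    The backward pass scales the gradient by 1 where `x > 0` and by `a`
    elsewhere.
    +/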
    Tensor!(T, Shape, useGradient) leakyRelu(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x, T a = 0.01)
    {
        auto y = slice(x.value.map!(t => t > 0 ? t : a * t));

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new typeof(return)(y, (Value grad) {
                x.backward(grad * x.value.map!(t => T(t > 0 ? 1 : a)));
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2, 3])([-1.0, 0.0, 1.0, -2.0, 0.0, 2.0]);

        auto y = leakyRelu(x, 0.02);

        y.backward();

        assert(y.value[0, 0] == -0.02);
        assert(y.value[0, 1] == 0);
        assert(y.value[0, 2] == 1.0);
        assert(y.value[1, 0] == -0.04);
        assert(y.value[1, 1] == 0);
        assert(y.value[1, 2] == 2.0);

        assert(x.grads[0, 0] == 0.02);
        assert(x.grads[0, 1] == 0.02);
        assert(x.grads[0, 2] == 1.0);
        assert(x.grads[1, 0] == 0.02);
        assert(x.grads[1, 1] == 0.02);
        assert(x.grads[1, 2] == 1.0);
    }

    unittest
    {
        auto x = tensor!([2, 3], No.gradient)([-1.0, 0.0, 1.0, -2.0, 0.0, 2.0]);
        auto y = leakyRelu(x, 0.2);

        assert(y.value[0, 0] == -0.2);
        assert(y.value[0, 1] == 0);
        assert(y.value[0, 2] == 1.0);
        assert(y.value[1, 0] == -0.4);
        assert(y.value[1, 1] == 0);
        assert(y.value[1, 2] == 2.0);
    }
}

version (all) // linear
{
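    /++
    Affine transformation `z = x * W + B`, where `x` has shape
    [batchSize, inputDim], `W` has shape [inputDim, outputDim], and `B` has
    shape [outputDim].

    The backward pass follows the usual matrix rules:
    dW = x^T * grad, dx = grad * W^T, and dB sums grad over the batch.

    A minimal usage sketch (shapes and values chosen only for illustration):
    ---
    auto W = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto B = tensor!([3])([0.0f, 0.0f, 0.0f]);
    auto x = tensor!([0, 2])([1.0f, 2.0f]); // dynamic batch size
    auto z = linear(x, W, B); // shape [batchSize, 3]
    ---
    +/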
    // dfmt off
    Tensor!(T, [ShapeX[0], ShapeW[1]], useGradX | useGradW | useGradB) linear(
        T,
        size_t[2] ShapeX, UseGradient useGradX,
        size_t[2] ShapeW, UseGradient useGradW,
        size_t[1] ShapeB, UseGradient useGradB
    )(
        Tensor!(T, ShapeX, useGradX) x,
        Tensor!(T, ShapeW, useGradW) W,
        Tensor!(T, ShapeB, useGradB) B
    )
    // dfmt on
    {
        static assert(ShapeX[1] == ShapeW[0]);
        static assert(ShapeW[1] == ShapeB[0]);

        enum OutputDim = ShapeW[1];

        const batchSize = x.value.shape[0];
        auto result = uninitSlice!T([batchSize, OutputDim]);
        foreach (i; 0 .. batchSize)
        {
            result[i, 0 .. $] = B.value[];
        }

        import mir.blas : gemm;

        gemm(T(1), x.value, W.value, T(1), result);

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (useGradW | useGradX | useGradB)
        {
            static if (canBackward!(typeof(W))) W.usedCount++;
            static if (canBackward!(typeof(x))) x.usedCount++;
            static if (canBackward!(typeof(B))) B.usedCount++;

            return new Return(result, (Value grad) {
                static if (canBackward!(typeof(W)))
                {
                    W.backward((ref wGrads) {
                        gemm(T(1), x.value.transposed, grad, T(1), wGrads);
                    });
                }
                static if (canBackward!(typeof(x)))
                {
                    x.backward((ref xGrads) {
                        gemm(T(1), grad, W.value.transposed, T(1), xGrads);
                    });
                }
                static if (canBackward!(typeof(B)))
                {
                    B.backward((ref bGrads) {
                        foreach (i; 0 .. batchSize)
                        {
                            bGrads[] += grad[i, 0 .. $];
                        }
                    });
                }
            });
        }
        else
        {
            return new Return(result);
        }
    }

    unittest
    {
        // static weights and bias
        Tensor!(float, [2, 2]) w = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        Tensor!(float, [2]) b = tensor!([2])([100.0f, 200.0f]);
        // dynamic batchSize
        Tensor!(float, [0, 2]) x = tensor!([0, 2])([1.0f, 2.0f]);

        // result
        Tensor!(float, [0, 2]) z = linear(x, w, b);

        assert(z.value[0, 0] == 1 * 1 + 2 * 3 + 100);
        assert(z.value[0, 1] == 1 * 2 + 2 * 4 + 200);

        z.backward();
    }

    unittest
    {
        // mismatched input dim
        Tensor!(float, [15, 3]) weights;
        Tensor!(float, [3]) bias;

        Tensor!(float, [0, 4, 4]) x;

        // compile error
        static assert(!__traits(compiles, {
                auto y = x.flatten().linear(weights, bias);
            }));
    }

    unittest
    {
        Tensor!(float, [16, 3]) weights;
        Tensor!(float, [3]) bias;

        Tensor!(float, [0, 4, 4]) x;

        static assert(__traits(compiles, {
                auto y = x.flatten().linear(weights, bias);
            }));
    }

    unittest
    {
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto b = tensor!([3])([100.0f, 200.0f, 300.0f]);

        auto z = linear(x, w, b);

        assert(z.value[0] == [1 * 1 + 2 * 4 + 100, 1 * 2 + 2 * 5 + 200, 1 * 3 + 2 * 6 + 300]);
        assert(z.value[1] == [3 * 1 + 4 * 4 + 100, 3 * 2 + 4 * 5 + 200, 3 * 3 + 4 * 6 + 300]);
    }

    unittest
    {
        auto w = tensor!([2, 3], No.gradient)([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto b = tensor!([3])([100.0f, 200.0f, 300.0f]);

        auto z = linear(x, w, b);

        assert(z.value[0] == [1 * 1 + 2 * 4 + 100, 1 * 2 + 2 * 5 + 200, 1 * 3 + 2 * 6 + 300]);
        assert(z.value[1] == [3 * 1 + 4 * 4 + 100, 3 * 2 + 4 * 5 + 200, 3 * 3 + 4 * 6 + 300]);
    }

    unittest
    {
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto x = tensor!([0, 2], No.gradient)([1.0f, 2.0f, 3.0f, 4.0f]);
        auto b = tensor!([3])([100.0f, 200.0f, 300.0f]);

        auto z = linear(x, w, b);

        assert(z.value[0] == [1 * 1 + 2 * 4 + 100, 1 * 2 + 2 * 5 + 200, 1 * 3 + 2 * 6 + 300]);
        assert(z.value[1] == [3 * 1 + 4 * 4 + 100, 3 * 2 + 4 * 5 + 200, 3 * 3 + 4 * 6 + 300]);
    }

    unittest
    {
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto b = tensor!([3], No.gradient)([100.0f, 200.0f, 300.0f]);

        auto z = linear(x, w, b);

        assert(z.value[0] == [1 * 1 + 2 * 4 + 100, 1 * 2 + 2 * 5 + 200, 1 * 3 + 2 * 6 + 300]);
        assert(z.value[1] == [3 * 1 + 4 * 4 + 100, 3 * 2 + 4 * 5 + 200, 3 * 3 + 4 * 6 + 300]);
    }

    unittest
    {
        auto w = tensor!([2, 3], No.gradient)([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto x = tensor!([0, 2], No.gradient)([1.0f, 2.0f, 3.0f, 4.0f]);
        auto b = tensor!([3], No.gradient)([100.0f, 200.0f, 300.0f]);

        auto z = linear(x, w, b);
        static assert(!canBackward!(typeof(z)));

        assert(z.value[0] == [1 * 1 + 2 * 4 + 100, 1 * 2 + 2 * 5 + 200, 1 * 3 + 2 * 6 + 300]);
        assert(z.value[1] == [3 * 1 + 4 * 4 + 100, 3 * 2 + 4 * 5 + 200, 3 * 3 + 4 * 6 + 300]);
    }
}

version (all) // sum
{
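    /++
    Sums all elements of `x`, yielding a tensor of shape [1].

    This overload handles inputs of shape [N, 1]; the backward pass adds
    the single incoming gradient to every element of `x`.
    +/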
    Tensor!(T, [1], useGradient) sum(alias mode = "fast", T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
            if (Shape.length == 2 && Shape[1] == 1)
    {
        import mir.math.sum : mirsum = sum;

        auto y = slice!T([1], mirsum!mode(x.value));

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new Return(y, (Value grad) {
                x.backward((ref xGrads) { xGrads[] += grad; });
            });
        }
        else
        {
            return new Return(y);
        }
    }

    unittest
    {
        auto x = tensor!([4, 1])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto s = sum(x);
        assert(s.value[0] == 10.0f);

        assert(x.grads == [[0.0f], [0.0f], [0.0f], [0.0f]]);
        s.backward();
        assert(x.grads == [[1.0f], [1.0f], [1.0f], [1.0f]]);
    }

    unittest
    {
        auto x = tensor!([4, 1], No.gradient)([1.0f, 2.0f, 3.0f, 4.0f]);
        auto s = sum(x);
        assert(s.value[0] == 10.0f);
    }

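    /++
    Sums each sample over all non-batch dimensions, yielding a tensor of
    shape [N, 1].

    The backward pass writes `grad[i, 0]` across every gradient entry of
    sample `i`.
    +/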
    Tensor!(T, [Shape[0], 1], useGradient) sum(alias mode = "fast", T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
            if ((Shape.length == 2 && Shape[1] != 1) || (Shape.length > 2))
    {
        import mir.math.sum : mirsum = sum;

        const batchSize = x.value.shape[0];
        auto y = uninitSlice!T([batchSize, 1]);
        foreach (i; 0 .. batchSize)
        {
            y[i, 0] = mirsum!mode(x.value[i]);
        }

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new Return(y, (Value grad) {
                x.backward((ref xGrads) {
                    foreach (i; 0 .. xGrads.shape[0])
                    {
                        xGrads[i].flattened[] = grad[i, 0];
                    }
                });
            });
        }
        else
        {
            return new Return(y);
        }
    }

    unittest
    {
        import std.format : format;

        auto x = tensor!([0, 4])([0.5, 1.0, 1.5, 2.0]);
        auto y = sum(x);

        assert(y.staticShape == [0, 1]);
        assert(y.value[0, 0] == 5.0);

        assert(x.grads == [[0, 0, 0, 0]], "%s".format(x.grads));
        y.backward();
        assert(x.grads == [[1, 1, 1, 1]], "%s".format(x.grads));
    }

    unittest
    {
        auto x = tensor!([2, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Tensor!(double, [2, 1]) y = sum(x);

        assert(y.value[0, 0] == 10.0);
        assert(y.value[1, 0] == 26.0);

        y.backward();
        assert(x.grads[0, 0, 0] == 1.0);
        assert(x.grads[0, 0, 1] == 1.0);
        assert(x.grads[0, 1, 0] == 1.0);
        assert(x.grads[0, 1, 1] == 1.0);
        assert(x.grads[1, 0, 0] == 1.0);
        assert(x.grads[1, 0, 1] == 1.0);
        assert(x.grads[1, 1, 0] == 1.0);
        assert(x.grads[1, 1, 1] == 1.0);
    }

    unittest
    {
        auto x = tensor!([2, 2, 2], No.gradient)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Tensor!(double, [2, 1], No.gradient) y = sum(x);

        assert(y.value[0, 0] == 10.0);
        assert(y.value[1, 0] == 26.0);
    }
}

version (all) // mean
{
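    /++
    Averages all elements of `x`, yielding a tensor of shape [1].

    This overload handles inputs of shape [N, 1]; the backward pass
    distributes `grad / n` to every element, where `n` is the element count.
    +/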
    Tensor!(T, [1], useGradient) mean(alias mode = "fast", T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
            if (Shape.length == 2 && Shape[1] == 1)
    {
        import mir.math.sum : mirsum = sum;

        const n = elementSize(x.value.shape);
        auto y = slice!T([1], mirsum!mode(x.value) / n);

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new Return(y, (Value grad) {
                x.backward((ref xGrads) { xGrads[] += grad / n; });
            });
        }
        else
        {
            return new Return(y);
        }
    }

    unittest
    {
        auto x = tensor!([4, 1])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto s = mean(x);
        assert(s.value[0] == 2.5f);

        assert(x.grads == [[0.0f], [0.0f], [0.0f], [0.0f]]);
        s.backward();
        assert(x.grads == [[0.25f], [0.25f], [0.25f], [0.25f]]);
    }

    unittest
    {
        auto x = tensor!([4, 1], No.gradient)([1.0f, 2.0f, 3.0f, 4.0f]);
        auto s = mean(x);
        assert(s.value[0] == 2.5f);
    }

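    /++
    Averages each sample over all non-batch dimensions, yielding a tensor
    of shape [N, 1].

    The backward pass writes `grad[i, 0] / n` across every gradient entry
    of sample `i`, where `n` is the per-sample element count.
    +/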
    Tensor!(T, [Shape[0], 1], useGradient) mean(alias mode = "fast", T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
            if ((Shape.length == 2 && Shape[1] != 1) || (Shape.length > 2))
    {
        import mir.math.sum : mirsum = sum;

        const batchSize = x.value.shape[0];
        const n = elementSize(x.value.shape[1 .. $]);
        auto y = uninitSlice!T([batchSize, 1]);
        foreach (i; 0 .. batchSize)
        {
            y[i, 0] = mirsum!mode(x.value[i]) / n;
        }

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            return new Return(y, (Value grad) {
                x.backward((ref xGrads) {
                    foreach (i; 0 .. xGrads.shape[0])
                    {
                        xGrads[i].flattened[] = grad[i, 0] / n;
                    }
                });
            });
        }
        else
        {
            return new Return(y);
        }
    }

    unittest
    {
        import std.format : format;

        auto x = tensor!([0, 4])([0.5, 1.0, 1.5, 2.0]);
        auto y = mean(x);

        assert(y.staticShape == [0, 1]);
        assert(y.value[0, 0] == 1.25);

        assert(x.grads == [[0, 0, 0, 0]], "%s".format(x.grads));
        y.backward();
        assert(x.grads == [[0.25, 0.25, 0.25, 0.25]], "%s".format(x.grads));
    }

    unittest
    {
        auto x = tensor!([2, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Tensor!(double, [2, 1]) y = mean(x);

        assert(y.value[0, 0] == 2.5);
        assert(y.value[1, 0] == 6.5);

        y.backward();
        assert(x.grads[0, 0, 0] == 0.25);
        assert(x.grads[0, 0, 1] == 0.25);
        assert(x.grads[0, 1, 0] == 0.25);
        assert(x.grads[0, 1, 1] == 0.25);
        assert(x.grads[1, 0, 0] == 0.25);
        assert(x.grads[1, 0, 1] == 0.25);
        assert(x.grads[1, 1, 0] == 0.25);
        assert(x.grads[1, 1, 1] == 0.25);
    }

    unittest
    {
        auto x = tensor!([2, 2, 2], No.gradient)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Tensor!(double, [2, 1], No.gradient) y = mean(x);

        assert(y.value[0, 0] == 2.5);
        assert(y.value[1, 0] == 6.5);
    }
}

version (all) // flatten
{
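    /++
    Reshapes `x` to two dimensions, keeping the batch axis and collapsing
    the rest: [N, ...] becomes [N, elementSize(...)].

    The backward pass reshapes the incoming gradient back to the original
    shape.
    +/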
    Tensor!(T, [Shape[0], elementSize(Shape[1 .. $])], useGradient) flatten(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    {
        int err;
        auto y = x.value.reshape([x.value.shape[0], -1], err);
        assert(err == 0);

        static if (canBackward!(typeof(x)))
        {
            x.usedCount++;

            alias Value = typeof(return).Value;
            return new typeof(return)(y, (Value grad) {
                int err;
                auto reshaped = grad.reshape([
                        grad.shape[0], expandShape!(Shape[1 .. $])
                    ], err);
                assert(err == 0);
                x.backward(reshaped);
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([2, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Tensor!(double, [2, 4]) y = flatten(x);
        assert(y.staticShape == [2, 4]);
        assert(y.value[0, 0] == 1.0);
        assert(y.value[0, 1] == 2.0);
        assert(y.value[0, 2] == 3.0);
        assert(y.value[0, 3] == 4.0);
        assert(y.value[1, 0] == 5.0);
        assert(y.value[1, 1] == 6.0);
        assert(y.value[1, 2] == 7.0);
        assert(y.value[1, 3] == 8.0);
    }

    unittest
    {
        auto x = tensor!([2, 2, 2], No.gradient)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        Tensor!(double, [2, 4], No.gradient) y = flatten(x);
    }
}

version (all) // softmax
{
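    /++
    Row-wise softmax for 2-dimensional inputs:
    `y[i, j] = exp(x[i, j]) / sum_k exp(x[i, k])`.

    The backward pass applies the softmax Jacobian, written here through
    the cached exponentials; for each row it is equivalent to
    xGrads[j] = y[j] * (grads[j] - sum_k grads[k] * y[k]).

    Note the exponentials are taken directly, without the usual
    max-subtraction, so very large inputs can overflow.
    +/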
    Tensor!(T, Shape, useGrad) softmax(T, size_t[] Shape, UseGradient useGrad)(Tensor!(T, Shape, useGrad) x)
        if (Shape.length == 2)
    {
        import std.math : stdexp = exp;

        static if (Shape[0] == 0)
            const batchSize = x.shape[0];
        else
            enum batchSize = Shape[0];

        enum dim = Shape[1];
        auto y = uninitSlice!T(batchSize, dim);

        const expx = slice(x.value.map!stdexp);
        T[dim] temp;
        foreach (i; 0 .. batchSize)
        {
            auto s = T(0);
            foreach (j; 0 .. dim)
            {
                temp[j] = expx[i, j];
                s += temp[j];
            }
            foreach (j; 0 .. dim)
            {
                y[i, j] = temp[j] / s;
            }
        }

        static if (useGrad)
        {
            x.usedCount++;
            return new Tensor!(T, Shape)(y, (grads) {
                x.backward((ref xGrads) {
                    foreach (i; 0 .. batchSize)
                    {
                        import mir.math.sum : mirsum = sum;

                        const s = mirsum!"fast"(expx[i, 0 .. dim]);
                        const is2 = T(1) / (s * s);
                        foreach (j; 0 .. dim)
                        {
                            const a = grads[i, j];
                            auto d = T(0);
                            foreach (k; 0 .. dim)
                            {
                                if (k == j) continue;
                                d += (a - grads[i, k]) * expx[i, k];
                            }

                            xGrads[i, j] = is2 * expx[i, j] * d;
                        }
                    }
                });
            });
        }
        else
        {
            return new Tensor!(T, Shape, UseGradient.no)(y);
        }
    }

    unittest
    {
        auto x = tensor!([0, 3])([[1.0, 2.0, 3.0]]);
        auto y = softmax(x);
        auto z = tensor!([0, 3], UseGradient.no)([[1.0, 0.0, 0.0]]);

        import mir.math.sum : mirsum = sum;
        import std.math : isClose;

        assert(mirsum(y.value).isClose(1));

        auto t = z - y;
        auto loss = mean(t * t);

        loss.backward();
    }
}

version (all) // softmaxCrossEntropy
{
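    /++
    Per-sample softmax cross entropy between logits `x` and one-hot labels
    `y`, scaled by 1 / numClasses, yielding a tensor of shape [N, 1]:
    `z[i] = (log(sum_k exp(x[i, k])) - sum_k x[i, k] * y[i, k]) / numClasses`.

    The backward pass uses the standard result
    dz/dx = (softmax(x) - y) / numClasses, scaled by the incoming gradient.
    +/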
    Tensor!(T, [Shape1[0], 1], useGrad) softmaxCrossEntropy(T, size_t[] Shape1, size_t[] Shape2, UseGradient useGrad)(Tensor!(T, Shape1, useGrad) x, Tensor!(T, Shape2, UseGradient.no) y)
    if (Shape1.length == 2 && Shape2.length == 2 && Shape1[1] == Shape2[1])
    {
        static assert(Shape1[0] == 0 || Shape2[0] == 0 || Shape1[0] == Shape2[0]);
        assert(x.shape[0] == y.shape[0]);

        import mir.ndslice : map, zip;
        import mir.math.sum : sum;
        import std.math : exp, log;

        const c = T(1) / x.shape[1];

        auto t = x.value.ipack!1.map!(r => sum!"fast"(r.map!exp));

        int err;
        auto z = (c * (t.map!(a => cast(T) log(a)) - (x.value * y.value).ipack!1.map!(a => sum!"fast"(a))))
            .fuse()
            .reshape([x.shape[0], 1], err);

        alias Return = typeof(return);
        alias Value = Return.Value;

        static if (useGrad)
        {
            x.usedCount++;

            return new Return(z, (Value grads) {
                x.backward((ref xGrads) {
                    immutable p = T(1) / xGrads.shape[1];
                    foreach (i; 0 .. xGrads.shape[0])
                    {
                        xGrads[i][] += p * (x.value[i].map!exp / t[i] - y.value[i][]) * grads[i, 0];
                    }
                });
            });
        }
        else
        {
            return new Return(z);
        }
    }

    unittest
    {
        auto x = tensor!([0, 3])([
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
        ]);
        auto y = tensor!([0, 3], UseGradient.no)([
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
        ]);

        auto loss = softmaxCrossEntropy(x, y);
        assert(loss.shape == [4, 1]);

        import std.math : isClose, E;
        import std.format : format;

        assert(loss.value[0, 0].isClose(0.3662040962), format!"%.10f"(loss.value[0, 0]));
        assert(loss.value[1, 0].isClose(0.1838149046), format!"%.10f"(loss.value[1, 0]));
        assert(loss.value[2, 0].isClose(0.1838149046), format!"%.10f"(loss.value[2, 0]));
        assert(loss.value[3, 0].isClose(0.1838149046), format!"%.10f"(loss.value[3, 0]));

        loss.backward();

        enum double g1 = 1.0 / (6.0 + 3 * E);
        enum double g2 = -2.0 / (6.0 + 3 * E);

        import std.conv : text;

        assert(x.grads[0, 0].isClose(1.0 / 9), text(x.grads));
        assert(x.grads[0, 1].isClose(1.0 / 9), text(x.grads));
        assert(x.grads[0, 2].isClose(-2.0 / 9), text(x.grads));
        assert(x.grads[1, 0].isClose(g1), text(x.grads));
        assert(x.grads[1, 1].isClose(g1), text(x.grads));
        assert(x.grads[1, 2].isClose(g2), text(x.grads));
        assert(x.grads[2, 0].isClose(g1), text(x.grads));
        assert(x.grads[2, 1].isClose(g2), text(x.grads));
        assert(x.grads[2, 2].isClose(g1), text(x.grads));
        assert(x.grads[3, 0].isClose(g2), text(x.grads));
        assert(x.grads[3, 1].isClose(g1), text(x.grads));
        assert(x.grads[3, 2].isClose(g1), text(x.grads));
    }

    unittest
    {
        auto x = tensor!([0, 3], UseGradient.no)([
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
        ]);
        auto y = tensor!([0, 3], UseGradient.no)([
            [0.0, 0.0, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 0.0, 0.0],
        ]);

        auto loss = softmaxCrossEntropy(x, y);
        assert(loss.shape == [4, 1]);

        import std.math : isClose, E;
        import std.format : format;

        assert(loss.value[0, 0].isClose(0.3662040962), format!"%.10f"(loss.value[0, 0]));
        assert(loss.value[1, 0].isClose(0.1838149046), format!"%.10f"(loss.value[1, 0]));
        assert(loss.value[2, 0].isClose(0.1838149046), format!"%.10f"(loss.value[2, 0]));
        assert(loss.value[3, 0].isClose(0.1838149046), format!"%.10f"(loss.value[3, 0]));

        static assert(!__traits(compiles, {
            loss.backward();
        }));
    }
}

version (all) // dropout
{
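    /++
    Dropout with the given drop `rate`.

    During training, `round(size * rate)` random positions per sample are
    zeroed (positions are drawn with replacement, so fewer distinct
    elements may be hit). During inference, the input is instead scaled by
    the keep fraction.
    +/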
    Tensor!(T, Shape, useGradient) dropout(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x, float rate, bool isTrain)
    {
        import std : roundTo;
        import golem.util : elementSize;
        import mir.ndslice : flattened;

        enum size = elementSize(Shape[1 .. $]);
        const dropSize = roundTo!size_t(size * rate);

        if (isTrain)
        {
            auto filter = onesLike(x);
            foreach (i; 0 .. x.shape[0])
            {
                import std.random : uniform;

                auto row = filter.value[i].flattened;
                foreach (j; 0 .. dropSize)
                {
                    row[uniform(0, size)] = 0;
                }
            }
            return filter * x;
        }
        else
        {
            const p = T(size - dropSize) / size;
            return p * x;
        }
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = dropout(x, 0.5, true);
        auto z = dropout(x, 0.5, false);

        import std.algorithm : count;

        const t = y.value.flattened[].count(0);
        assert(t >= 2); // at least one zero per sample (duplicate draws possible)
        assert(t <= 2 * 2); // at most batchSize * round(4 * 0.5) zeros

        import std.math : round, isClose;

        const a = (4 - round(4 * 0.5)) / 4;
        assert(z.value[0, 0, 0].isClose(1.0 * a));
        assert(z.value[0, 0, 1].isClose(2.0 * a));
        assert(z.value[0, 1, 0].isClose(3.0 * a));
        assert(z.value[0, 1, 1].isClose(4.0 * a));
        assert(z.value[1, 0, 0].isClose(5.0 * a));
        assert(z.value[1, 0, 1].isClose(6.0 * a));
        assert(z.value[1, 1, 0].isClose(7.0 * a));
        assert(z.value[1, 1, 1].isClose(8.0 * a));
    }
}

version (all) // concat
{
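    /++
    Computes the static shape produced by concatenating `lhs` and `rhs`:
    the first axis on which they differ (or the last axis if they are
    equal) is summed; all other axes are kept.
    +/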
    size_t[] makeConcatShape(size_t[] lhs, size_t[] rhs)
    in (lhs.length > 0)
    in (rhs.length > 0)
    in (lhs.length == rhs.length)
    {
        size_t axis = lhs.length - 1;
        foreach (i; 0 .. lhs.length)
        {
            if (lhs[i] != rhs[i])
            {
                axis = i;
                break;
            }
        }
        auto shape = lhs.dup;
        shape[axis] += rhs[axis];
        return shape;
    }

    template ConcatTensor(TensorL, TensorR)
    if (isTensor!TensorL && isTensor!TensorR)
    {
        import std.format : format;

        // dfmt off
        static assert(TensorL.staticShape.length == TensorR.staticShape.length,
            format!"%s != %s"(TensorL.staticShape, TensorR.staticShape));
        // dfmt on

        private alias ElementType = TensorL.ElementType;
        private enum Shape = makeConcatShape(TensorL.staticShape, TensorR.staticShape);
        private enum useGradient = commonGradientType!(TensorL, TensorR);

        alias ConcatTensor = Tensor!(ElementType, Shape, useGradient);
    }

    // Dim: [N, A] + [N, B] => [N, A + B]
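    /++
    Concatenates two 2-dimensional tensors along the feature axis.

    The backward pass splits the incoming gradient at column
    `x.staticShape[1]` and routes each part to the corresponding input.
    +/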
    auto concat(T, U)(T x, U y)
    if (isTensor!T && isTensor!U)
    {
        import std.format : format;
        static assert(T.staticShape.length == 2, format!"Only 2 dimensions are supported at x (%s)"(T.staticShape));
        static assert(U.staticShape.length == 2, format!"Only 2 dimensions are supported at y (%s)"(U.staticShape));
        static if (T.staticShape[0] != 0 && U.staticShape[0] != 0)
        {
            static assert(T.staticShape[0] == U.staticShape[0], format!"mismatch batch size (%s != %s)"(T.staticShape, U.staticShape));
        }
        else
        {
            assert(x.shape[0] == y.shape[0], format!"mismatch batch size (%s != %s)"(T.staticShape, U.staticShape));
        }

        alias Return = ConcatTensor!(T, U);

        const batchSize = x.shape[0];
        auto z = uninitSlice!(T.ElementType)(batchSize, x.staticShape[1] + y.staticShape[1]);
        foreach (i; 0 .. batchSize)
        {
            z[i][0 .. x.staticShape[1]] = x.value[i][0 .. $];
            z[i][x.staticShape[1] .. $] = y.value[i][0 .. $];
        }

        static if (canBackward!(Return))
        {
            static if (canBackward!T) x.usedCount++;
            static if (canBackward!U) y.usedCount++;
            return new Return(z, (grads) {
                static if (canBackward!T)
                {
                    x.backward((ref xGrads) {
                        xGrads[] += grads[0 .. $, 0 .. x.staticShape[1]];
                    });
                }
                static if (canBackward!U)
                {
                    y.backward((ref yGrads) {
                        yGrads[] += grads[0 .. $, x.staticShape[1] .. $];
                    });
                }
            });
        }
        else
        {
            return new Return(z);
        }
    }

    unittest
    {
        auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto y = tensor!([2, 1])([10.0f, 20.0f]);

        auto z = concat(x, y);

        assert(z.value[0, 0] == 1.0f);
        assert(z.value[0, 1] == 2.0f);
        assert(z.value[0, 2] == 10.0f);
        assert(z.value[1, 0] == 3.0f);
        assert(z.value[1, 1] == 4.0f);
        assert(z.value[1, 2] == 20.0f);

        auto a = tensor!([2, 3])([[1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 6.0f]]);
        (a * z).backward();

        import std.conv : to;
        assert(x.grads[0, 0] == 1.0f, x.grads.to!string());
        assert(x.grads[0, 1] == 2.0f, x.grads.to!string());
        assert(x.grads[1, 0] == 4.0f, x.grads.to!string());
        assert(x.grads[1, 1] == 5.0f, x.grads.to!string());
        assert(y.grads[0, 0] == 3.0f, y.grads.to!string());
        assert(y.grads[1, 0] == 6.0f, y.grads.to!string());
    }

    unittest
    {
        auto x = tensor!([2, 2], UseGradient.no)([1.0f, 2.0f, 3.0f, 4.0f]);
        auto y = tensor!([2, 1])([10.0f, 20.0f]);

        auto z = concat(x, y);

        assert(z.value[0, 0] == 1.0f);
        assert(z.value[0, 1] == 2.0f);
        assert(z.value[0, 2] == 10.0f);
        assert(z.value[1, 0] == 3.0f);
        assert(z.value[1, 1] == 4.0f);
        assert(z.value[1, 2] == 20.0f);

        auto a = tensor!([2, 3])([[1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 6.0f]]);
        (a * z).backward();

        import std.conv : to;
        assert(y.grads[0, 0] == 3.0f, y.grads.to!string());
        assert(y.grads[1, 0] == 6.0f, y.grads.to!string());
    }

    unittest
    {
        auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto y = tensor!([2, 1], UseGradient.no)([10.0f, 20.0f]);

        auto z = concat(x, y);

        assert(z.value[0, 0] == 1.0f);
        assert(z.value[0, 1] == 2.0f);
        assert(z.value[0, 2] == 10.0f);
        assert(z.value[1, 0] == 3.0f);
        assert(z.value[1, 1] == 4.0f);
        assert(z.value[1, 2] == 20.0f);

        auto a = tensor!([2, 3])([[1.0f, 2.0f, 3.0f], [4.0f, 5.0f, 6.0f]]);
        (a * z).backward();

        import std.conv : to;
        assert(x.grads[0, 0] == 1.0f, x.grads.to!string());
        assert(x.grads[0, 1] == 2.0f, x.grads.to!string());
        assert(x.grads[1, 0] == 4.0f, x.grads.to!string());
        assert(x.grads[1, 1] == 5.0f, x.grads.to!string());
    }

    unittest
    {
        auto x = tensor!([2, 2], UseGradient.no)([1.0f, 2.0f, 3.0f, 4.0f]);
        auto y = tensor!([2, 1], UseGradient.no)([10.0f, 20.0f]);

        auto z = concat(x, y);
        static assert(!canBackward!(typeof(z)));

        assert(z.value[0, 0] == 1.0f);
        assert(z.value[0, 1] == 2.0f);
        assert(z.value[0, 2] == 10.0f);
        assert(z.value[1, 0] == 3.0f);
        assert(z.value[1, 1] == 4.0f);
        assert(z.value[1, 2] == 20.0f);
    }

    unittest
    {
        auto x = tensor!([3, 1])([10.0f, 20.0f, 30.0f]);
        auto y = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

        // mismatched batch size
        static assert(!__traits(compiles, concat(x, y)));
    }

    unittest
    {
        auto x = tensor!([0, 1])([10.0f, 20.0f, 30.0f]);
        auto y = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

        // mismatched batch size
        import core.exception : AssertError;
        import std.exception : assertThrown;

        assertThrown!AssertError(concat(x, y));
    }

    unittest
    {
        auto x = tensor!([3])([1.0f, 2.0f, 3.0f]);
        auto y = tensor!([3, 1])([1.0f, 2.0f, 3.0f]);
        auto z = tensor!([3, 1, 1])([1.0f, 2.0f, 3.0f]);

        static assert(!__traits(compiles, concat(x, y)));
        static assert(!__traits(compiles, concat(y, x)));
        static assert(!__traits(compiles, concat(y, z)));
        static assert(!__traits(compiles, concat(z, y)));
    }
}

version (all) // batchSum
{
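    /++
    Sums over the batch (first) dimension: [N, ...] becomes [...].

    The backward pass broadcasts the incoming gradient to every sample.
    +/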
    Tensor!(T, Shape[1 .. $], useGrad) batchSum(T, size_t[] Shape, UseGradient useGrad)(Tensor!(T, Shape, useGrad) x)
    {
        import mir.math.sum : mirsum = sum;

        auto y = x.value.bringToFront!(Shape.length - 1).pack!1.map!(a => mirsum(a)).slice();

        static if (useGrad)
        {
            x.usedCount++;
            return new typeof(return)(y, (grads) {
                x.backward((ref xGrads) {
                    foreach (i; 0 .. x.shape[0])
                    {
                        xGrads[i][] += grads[];
                    }
                });
            });
        }
        else
        {
            return new typeof(return)(y);
        }
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        Tensor!(double, [2, 2]) y = batchSum(x);
        assert(y.value == [[6.0, 8.0], [10.0, 12.0]]);

        y.backward();

        assert(x.grads == [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]);
    }

    unittest
    {
        auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);

        Tensor!(double, [2]) y = batchSum(x);
        assert(y.value == [4.0, 6.0]);

        auto z = y * tensor!([2])([-1.0, 2.0]);
        z.backward();

        assert(x.grads == [[-1.0, 2.0], [-1.0, 2.0]]);
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]);

        Tensor!(double, [2, 2], UseGradient.no) y = batchSum(x);
        assert(y.value == [[8.0, 10.0], [12.0, 14.0]]);
    }
}

version (all) // broadcastOp
{
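    /++
    Element-wise `+` or `-` with broadcasting of `y` over the leading
    dimensions of `x`, e.g. [N, C, W, H] op [C, W, H].

    The gradient of `x` passes through unchanged; the gradient of `y`
    accumulates the incoming gradient summed over the broadcast dimensions
    (negated for "-").

    A minimal usage sketch, mirroring the unittests below:
    ---
    auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
    auto y = tensor!([2])([10.0, 20.0]);
    auto z = broadcastOp!"+"(x, y); // [[11.0, 22.0], [13.0, 24.0]]
    ---
    +/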
1656     template broadcastOp(string op)
1657     if (op == "+" || op == "-")
1658     {
1659         /+
1660         [N, C, W, H] + [C, W, H]
1661         [N, C, W, H] - [C, W, H]
1662         +/
1663         auto broadcastOp(T, size_t[] Shape1, UseGradient useGrad1, size_t[] Shape2, UseGradient useGrad2)(
1664                 Tensor!(T, Shape1, useGrad1) x, Tensor!(T, Shape2, useGrad2) y)
1665             if (Shape1[$ - Shape2.length .. $] == Shape2)
1666         {
1667             enum Dim1 = Shape1.length;
1668             enum Dim2 = Shape2.length;
1669 
1670             static if (op == "+")
1671                 alias binOp = (a, b) => a + b;
1672             else static if (op == "-")
1673                 alias binOp = (a, b) => a - b;
1674             
1675             auto yv = y.value;
1676             auto z = x.value.pack!Dim2.map!(a => binOp(a, yv)).fuse();
1677 
1678             static if (useGrad1 || useGrad2)
1679             {
1680                 static if (useGrad1) x.usedCount++;
1681                 static if (useGrad2) y.usedCount++;
1682                 return new Tensor!(T, Shape1, UseGradient.yes)(z, (grads) {
1683                     static if (useGrad1)
1684                     {
1685                         x.backward(grads);
1686                     }
1687                     static if (useGrad2)
1688                     {
1689                         y.backward((ref yGrads) {
1690                             import mir.math.sum : mirsum = sum;
1691 
1692                             static if (op == "+")
1693                                 yGrads[] += grads.transposed!(expandIndex!(Dim1 - Dim2, Dim1)).ipack!Dim2.map!(a => mirsum(a));
1694                             else
1695                                 yGrads[] -= grads.transposed!(expandIndex!(Dim1 - Dim2, Dim1)).ipack!Dim2.map!(a => mirsum(a));
1696                         });
1697                     }
1698                 });
1699             }
1700             else
1701             {
1702                 return new Tensor!(T, Shape1, UseGradient.no)(z);
1703             }
1704         }
1705     }
1706 
1707     unittest
1708     {
1709         auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
1710         auto y = tensor!([2], UseGradient.no)([10.0, 20.0]);
1711 
1712         auto z1 = broadcastOp!"+"(x, y);
1713         auto z2 = broadcastOp!"-"(x, y);
1714 
1715         assert(z1.value.flattened == [11.0, 22.0, 13.0, 24.0, 15.0, 26.0, 17.0, 28.0]);
1716         assert(z2.value.flattened == [-9.0, -18.0, -7.0, -16.0, -5.0, -14.0, -3.0, -12.0]);
1717     }
1718 
1719     unittest
1720     {
1721         auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
1722         auto y = tensor!([2])([10.0, 20.0]);
1723         auto z = broadcastOp!"+"(x, y);
1724         assert(z.shape == [2, 2]);
1725         assert(z.value == [[11.0, 22.0], [13.0, 24.0]]);
1726 
1727         z.backward();
1728 
1729         import std : text;
1730 
1731         assert(x.grads == [[1.0, 1.0], [1.0, 1.0]], text("x.grads: ", x.grads, " != [[1.0, 1.0], [1.0, 1.0]]"));
1732         assert(y.grads == [2.0, 2.0], text("y.grads: ", y.grads, " != [2.0, 2.0]"));
1733     }
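
    unittest
    {
        // A minimal extra sketch: broadcasting over a [2, 2] tail instead of a
        // vector; the gradient for y is summed across the leading (batch) axis.
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = tensor!([2, 2])([10.0, 20.0, 30.0, 40.0]);

        auto z = broadcastOp!"+"(x, y);
        assert(z.value == [[[11.0, 22.0], [33.0, 44.0]], [[15.0, 26.0], [37.0, 48.0]]]);

        z.backward();

        assert(x.grads == [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]);
        assert(y.grads == [[2.0, 2.0], [2.0, 2.0]]);
    }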

    unittest
    {
        auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2])([10.0, 20.0]);
        auto z = broadcastOp!"-"(x, y);
        assert(z.shape == [2, 2]);
        assert(z.value == [[-9.0, -18.0], [-7.0, -16.0]]);

        z.backward();

        import std : text;

        assert(x.grads == [[1.0, 1.0], [1.0, 1.0]], text("x.grads: ", x.grads, " != [[1.0, 1.0], [1.0, 1.0]]"));
        assert(y.grads == [-2.0, -2.0], text("y.grads: ", y.grads, " != [-2.0, -2.0]"));
    }

    unittest
    {
        auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = (1.0 / x.shape[0]) * batchSum(x); // mean
        auto z = broadcastOp!"-"(x, y);
        assert(z.shape == [2, 2]);
        assert(z.value == [[-1.0, -1.0], [1.0, 1.0]]);

        z.backward();

        assert(x.grads == [[0.0, 0.0], [0.0, 0.0]]);
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2], UseGradient.no)([10.0, 20.0]);

        auto z = broadcastOp!"+"(x, y);
        auto w = broadcastOp!"-"(x, y);
        z.backward();
        w.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2])([10.0, 20.0]);

        auto z = broadcastOp!"+"(x, y);
        auto w = broadcastOp!"-"(x, y);
        z.backward();
        w.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2], UseGradient.no)([10.0, 20.0]);

        auto z = broadcastOp!"+"(x, y);
        auto w = broadcastOp!"-"(x, y);

        static assert(!canBackward!(typeof(z)));
        static assert(!canBackward!(typeof(w)));
    }

    template broadcastOp(string op)
    if (op == "*")
    {
        /+
        [N, C, W, H] * [C, W, H]
        +/
        auto broadcastOp(T, size_t[] Shape1, UseGradient useGrad1, size_t[] Shape2, UseGradient useGrad2)(
                Tensor!(T, Shape1, useGrad1) x, Tensor!(T, Shape2, useGrad2) y)
            if (Shape1[$ - Shape2.length .. $] == Shape2)
        {
            enum Dim2 = Shape2.length;

            alias binOp = (a, b) => a * b;

            auto yv = y.value;
            auto z = x.value.pack!Dim2.map!(a => binOp(a, yv)).fuse();

            static if (useGrad1 || useGrad2)
            {
                static if (useGrad1) x.usedCount++;
                static if (useGrad2) y.usedCount++;
                return new Tensor!(T, Shape1, UseGradient.yes)(z, (grads) {
                    static if (useGrad1)
                    {
                        x.backward((ref xGrads) {
                            foreach (ref t; zip(xGrads.pack!Dim2.flattened, grads.pack!Dim2.flattened))
                            {
                                t[0][] += t[1][] * yv[];
                            }
                        });
                    }
                    static if (useGrad2)
                    {
                        y.backward((ref yGrads) {
                            foreach (ref t; zip(grads.pack!Dim2.flattened, x.value.pack!Dim2.flattened))
                            {
                                yGrads[] += t[0] * t[1];
                            }
                        });
                    }
                });
            }
            else
            {
                return new Tensor!(T, Shape1, UseGradient.no)(z);
            }
        }
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = tensor!([2, 2])([0.2, 0.4, 0.6, 0.8]);

        auto z = broadcastOp!"*"(x, y);

        import std.math : isClose;

        assert(z.value[0, 0, 0].isClose(1.0 * 0.2));
        assert(z.value[0, 0, 1].isClose(2.0 * 0.4));
        assert(z.value[0, 1, 0].isClose(3.0 * 0.6));
        assert(z.value[0, 1, 1].isClose(4.0 * 0.8));
        assert(z.value[1, 0, 0].isClose(5.0 * 0.2));
        assert(z.value[1, 0, 1].isClose(6.0 * 0.4));
        assert(z.value[1, 1, 0].isClose(7.0 * 0.6));
        assert(z.value[1, 1, 1].isClose(8.0 * 0.8));

        z.backward();

        assert(x.grads[0, 0, 0].isClose(0.2));
        assert(x.grads[0, 0, 1].isClose(0.4));
        assert(x.grads[0, 1, 0].isClose(0.6));
        assert(x.grads[0, 1, 1].isClose(0.8));
        assert(x.grads[1, 0, 0].isClose(0.2));
        assert(x.grads[1, 0, 1].isClose(0.4));
        assert(x.grads[1, 1, 0].isClose(0.6));
        assert(x.grads[1, 1, 1].isClose(0.8));

        assert(y.grads[0, 0] == 1.0 + 5.0);
        assert(y.grads[0, 1] == 2.0 + 6.0);
        assert(y.grads[1, 0] == 3.0 + 7.0);
        assert(y.grads[1, 1] == 4.0 + 8.0);
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2], UseGradient.no)([10.0, 20.0]);

        auto z = broadcastOp!"*"(x, y);
        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2])([10.0, 20.0]);

        auto z = broadcastOp!"*"(x, y);
        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([2], UseGradient.no)([10.0, 20.0]);

        auto z = broadcastOp!"*"(x, y);

        static assert(!canBackward!(typeof(z)));
    }
}

version (all) // multicastOp
{
    template multicastOp(string op)
    if (op == "+" || op == "-")
    {
        /+
        [N, C, W, H] + [N, C]
        [N, C, W, H] - [N, C]
        +/
        auto multicastOp(T, size_t[] Shape1, UseGradient useGrad1, size_t[] Shape2, UseGradient useGrad2)(
                Tensor!(T, Shape1, useGrad1) x, Tensor!(T, Shape2, useGrad2) y)
            if (Shape1[0 .. trimRightOneDims(Shape2).length] == trimRightOneDims(Shape2))
        {
            enum Dim2 = trimRightOneDims(Shape2).length;

            auto yv = y.value;

            auto z = slice(x.value);
            foreach (t; zip(z.ipack!Dim2.flattened, yv.flattened))
            {
                static if (op == "+")
                    t[0][] += t[1];
                else static if (op == "-")
                    t[0][] -= t[1];
            }

            static if (useGrad1 || useGrad2)
            {
                static if (useGrad1) x.usedCount++;
                static if (useGrad2) y.usedCount++;
                return new Tensor!(T, Shape1, UseGradient.yes)(z, (grads) {
                    static if (useGrad1)
                    {
                        x.backward(grads);
                    }
                    static if (useGrad2)
                    {
                        y.backward((ref yGrads) {
                            import mir.math.sum : mirsum = sum;

                            static if (op == "+")
                                yGrads[] += grads.ipack!Dim2.map!(a => mirsum(a));
                            else
                                yGrads[] -= grads.ipack!Dim2.map!(a => mirsum(a));
                        });
                    }
                });
            }
            else
            {
                return new Tensor!(T, Shape1, UseGradient.no)(z);
            }
        }
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = tensor!([0], UseGradient.no)([10.0, 20.0]);

        auto z1 = multicastOp!"+"(x, y);
        auto z2 = multicastOp!"-"(x, y);

        assert(z1.value.flattened == [11.0, 12.0, 13.0, 14.0, 25.0, 26.0, 27.0, 28.0]);
        assert(z2.value.flattened == [-9.0, -8.0, -7.0, -6.0, -15.0, -14.0, -13.0, -12.0]);

        static assert(!canBackward!(typeof(z1)));
        static assert(!canBackward!(typeof(z2)));
    }

    unittest
    {
        auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0])([10.0, 20.0]);
        auto z = multicastOp!"+"(x, y);
        assert(z.shape == [2, 2]);
        assert(z.value == [[11.0, 12.0], [23.0, 24.0]]);

        z.backward();

        import std : text;

        assert(x.grads == [[1.0, 1.0], [1.0, 1.0]], text("x.grads: ", x.grads, " != [[1.0, 1.0], [1.0, 1.0]]"));
        assert(y.grads == [2.0, 2.0], text("y.grads: ", y.grads, " != [2.0, 2.0]"));
    }

    unittest
    {
        auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0])([10.0, 20.0]);
        auto z = multicastOp!"-"(x, y);
        assert(z.shape == [2, 2]);
        assert(z.value == [[-9.0, -8.0], [-17.0, -16.0]]);

        z.backward();

        import std : text;

        assert(x.grads == [[1.0, 1.0], [1.0, 1.0]], text("x.grads: ", x.grads, " != [[1.0, 1.0], [1.0, 1.0]]"));
        assert(y.grads == [-2.0, -2.0], text("y.grads: ", y.grads, " != [-2.0, -2.0]"));
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0, 2], UseGradient.no)([10.0, 20.0]);

        auto z = multicastOp!"+"(x, y);
        auto w = multicastOp!"-"(x, y);
        z.backward();
        w.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0, 2])([10.0, 20.0]);

        auto z = multicastOp!"+"(x, y);
        auto w = multicastOp!"-"(x, y);
        z.backward();
        w.backward();
    }

    unittest
    {
        // subtract the mean of each batch
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = mean(x);

        auto z = multicastOp!"-"(x, y);

        assert(z.value.flattened == [-1.5, -0.5, 0.5, 1.5, -1.5, -0.5, 0.5, 1.5]);
    }


    template multicastOp(string op)
    if (op == "*")
    {
        /+
        [N, C, W, H] * [N, C]
        +/
        auto multicastOp(T, size_t[] Shape1, UseGradient useGrad1, size_t[] Shape2, UseGradient useGrad2)(
                Tensor!(T, Shape1, useGrad1) x, Tensor!(T, Shape2, useGrad2) y)
            if (Shape1[0 .. trimRightOneDims(Shape2).length] == trimRightOneDims(Shape2))
        {
            enum Dim2 = trimRightOneDims(Shape2).length;

            auto yv = y.value;

            auto z = slice(x.value);
            foreach (t; zip(z.ipack!Dim2.flattened, yv.flattened))
            {
                t[0][] *= t[1];
            }

            static if (useGrad1 || useGrad2)
            {
                static if (useGrad1) x.usedCount++;
                static if (useGrad2) y.usedCount++;
                return new Tensor!(T, Shape1, UseGradient.yes)(z, (grads) {
                    static if (useGrad1)
                    {
                        x.backward((ref xGrads) {
                            foreach (t; zip(xGrads.ipack!Dim2.flattened, grads.ipack!Dim2.flattened, y.value.flattened))
                            {
                                t[0][] += t[1][] * t[2];
                            }
                        });
                    }
                    static if (useGrad2)
                    {
                        y.backward((ref yGrads) {
                            import mir.math.sum : mirsum = sum;

                            yGrads[] += (grads * x.value).ipack!Dim2.map!(a => mirsum(a));
                        });
                    }
                });
            }
            else
            {
                return new Tensor!(T, Shape1, UseGradient.no)(z);
            }
        }
    }

    unittest
    {
        auto x = tensor!([0, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = tensor!([0])([2.0, 3.0]);

        auto z = multicastOp!"*"(x, y);

        import std.math : isClose;

        assert(z.value[0, 0, 0] == 1.0 * 2);
        assert(z.value[0, 0, 1] == 2.0 * 2);
        assert(z.value[0, 1, 0] == 3.0 * 2);
        assert(z.value[0, 1, 1] == 4.0 * 2);
        assert(z.value[1, 0, 0] == 5.0 * 3);
        assert(z.value[1, 0, 1] == 6.0 * 3);
        assert(z.value[1, 1, 0] == 7.0 * 3);
        assert(z.value[1, 1, 1] == 8.0 * 3);

        z.backward();

        assert(x.grads[0, 0, 0] == 2.0);
        assert(x.grads[0, 0, 1] == 2.0);
        assert(x.grads[0, 1, 0] == 2.0);
        assert(x.grads[0, 1, 1] == 2.0);
        assert(x.grads[1, 0, 0] == 3.0);
        assert(x.grads[1, 0, 1] == 3.0);
        assert(x.grads[1, 1, 0] == 3.0);
        assert(x.grads[1, 1, 1] == 3.0);

        assert(y.grads[0] == 1.0 + 2.0 + 3.0 + 4.0);
        assert(y.grads[1] == 5.0 + 6.0 + 7.0 + 8.0);
    }

    unittest
    {
        auto x = tensor!([0, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0], UseGradient.no)([10.0, 20.0]);

        auto z = multicastOp!"*"(x, y);
        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0])([10.0, 20.0]);

        auto z = multicastOp!"*"(x, y);
        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0], UseGradient.no)([10.0, 20.0]);

        auto z = multicastOp!"*"(x, y);

        static assert(!canBackward!(typeof(z)));
    }
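
    unittest
    {
        // A minimal sketch, assuming trimRightOneDims strips trailing 1-sized
        // dimensions (as its use in the constraint suggests): a per-channel
        // scale stored as [N, C, 1, 1] multicasts over an [N, C, H, W] tensor.
        auto x = tensor!([0, 2, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto y = tensor!([0, 2, 1, 1], UseGradient.no)([2.0, 3.0]);

        auto z = multicastOp!"*"(x, y);

        assert(z.value.flattened == [2.0, 4.0, 6.0, 8.0, 15.0, 18.0, 21.0, 24.0]);

        static assert(!canBackward!(typeof(z)));
    }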
}

version (all) // splitEvenOdd2D
{
    /// Splits an image tensor into its even- and odd-indexed rows (axis == 2)
    /// or columns (axis == 3); the split axis of each returned tensor is halved.
    auto splitEvenOdd2D(size_t axis = 2, T, size_t[] Shape, UseGradient useGrad)(
            Tensor!(T, Shape, useGrad) images)
            if (Shape.length == 4 && (axis == 2 || axis == 3))
    {
        static if (axis == 2)
        {
            static assert(Shape[2] % 2 == 0);
            enum height = Shape[2] / 2;
            enum width = Shape[3];

            auto y1 = images.value[0 .. $, 0 .. $, 0 .. $ - 1, 0 .. $].strided!2(2).slice();
            auto y2 = images.value[0 .. $, 0 .. $, 1 .. $, 0 .. $].strided!2(2).slice();
        }
        else
        {
            static assert(Shape[3] % 2 == 0);
            enum height = Shape[2];
            enum width = Shape[3] / 2;

            auto y1 = images.value[0 .. $, 0 .. $, 0 .. $, 0 .. $ - 1].strided!3(2).slice();
            auto y2 = images.value[0 .. $, 0 .. $, 0 .. $, 1 .. $].strided!3(2).slice();
        }

        enum size_t[] ReturnShape = [Shape[0], Shape[1], height, width];
        static if (useGrad)
        {
            images.usedCount += 2;

            static if (axis == 2)
            {
                return tuple(new Tensor!(T, ReturnShape)(y1, (grads) {
                        images.backward((ref imagesGrads) {
                            imagesGrads[0 .. $, 0 .. $, 0 .. $ - 1, 0 .. $].strided!2(2)[] += grads[];
                        });
                    }), new Tensor!(T, ReturnShape)(y2, (grads) {
                        images.backward((ref imagesGrads) {
                            imagesGrads[0 .. $, 0 .. $, 1 .. $, 0 .. $].strided!2(2)[] += grads[];
                        });
                    }));
            }
            else
            {
                return tuple(new Tensor!(T, ReturnShape)(y1, (grads) {
                        images.backward((ref imagesGrads) {
                            imagesGrads[0 .. $, 0 .. $, 0 .. $, 0 .. $ - 1].strided!3(2)[] += grads[];
                        });
                    }), new Tensor!(T, ReturnShape)(y2, (grads) {
                        images.backward((ref imagesGrads) {
                            imagesGrads[0 .. $, 0 .. $, 0 .. $, 1 .. $].strided!3(2)[] += grads[];
                        });
                    }));
            }
        }
        else
        {
            // dfmt off
            return tuple(
                new Tensor!(T, ReturnShape, UseGradient.no)(y1),
                new Tensor!(T, ReturnShape, UseGradient.no)(y2)
                );
            // dfmt on
        }
    }

    /// ditto
    unittest
    {
        auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);

        auto sh = splitEvenOdd2D(x); // split by height
        assert(sh[0].shape == [1, 1, 1, 2]);
        assert(sh[0].value == [[[[1.0, 2.0]]]]);
        assert(sh[1].shape == [1, 1, 1, 2]);
        assert(sh[1].value == [[[[3.0, 4.0]]]]);

        sh[0].backward();
        sh[1].backward();

        auto sw = splitEvenOdd2D!3(x); // split by width
        assert(sw[0].shape == [1, 1, 2, 1]);
        assert(sw[0].value == [[[[1.0], [3.0]]]]);
        assert(sw[1].shape == [1, 1, 2, 1]);
        assert(sw[1].value == [[[[2.0], [4.0]]]]);

        sw[0].backward();
        sw[1].backward();
    }
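
    /// ditto
    unittest
    {
        // A minimal extra sketch: each input element feeds exactly one of the
        // two outputs, so after both backward passes every gradient equals 1.
        auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);

        auto s = splitEvenOdd2D(x);
        s[0].backward();
        s[1].backward();

        assert(x.grads == [[[[1.0, 1.0], [1.0, 1.0]]]]);
    }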

    /// ditto
    unittest
    {
        auto x = tensor!([0, 2, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

        auto sh = splitEvenOdd2D!2(x); // split by height
        assert(sh[0].shape == [1, 2, 1, 2]);
        assert(sh[0].value == [[[[1.0, 2.0]], [[5.0, 6.0]]]]);
        assert(sh[1].shape == [1, 2, 1, 2]);
        assert(sh[1].value == [[[[3.0, 4.0]], [[7.0, 8.0]]]]);

        static assert(!canBackward!(typeof(sh)));

        auto sw = splitEvenOdd2D!3(x); // split by width
        assert(sw[0].shape == [1, 2, 2, 1]);
        assert(sw[0].value == [[[[1.0], [3.0]], [[5.0], [7.0]]]]);
        assert(sw[1].shape == [1, 2, 2, 1]);
        assert(sw[1].value == [[[[2.0], [4.0]], [[6.0], [8.0]]]]);

        static assert(!canBackward!(typeof(sw)));
    }
}

version (all) // mergeEvenOdd2D
{
    auto mergeEvenOdd2D(size_t axis = 2, T, size_t[] Shape, UseGradient useGrad1, UseGradient useGrad2)(
            Tensor!(T, Shape, useGrad1) even, Tensor!(T, Shape, useGrad2) odd)
            if (Shape.length == 4 && (axis == 2 || axis == 3))
    {
        static if (axis == 2)
        {
            enum height = Shape[2] * 2;
            enum width = Shape[3];
        }
        else
        {
            enum height = Shape[2];
            enum width = Shape[3] * 2;
        }
        enum size_t[] ReturnShape = [Shape[0], Shape[1], height, width];

        static if (Shape[0] == 0)
            const batchSize = even.shape[0];
        else
            enum batchSize = Shape[0];

        auto y = slice!T(batchSize, Shape[1], height, width);
        static if (axis == 2)
        {
            y[0 .. $, 0 .. $, 0 .. $ - 1, 0 .. $].strided!2(2)[] = even.value[];
            y[0 .. $, 0 .. $, 1 .. $, 0 .. $].strided!2(2)[] = odd.value[];
        }
        else
        {
            y[0 .. $, 0 .. $, 0 .. $, 0 .. $ - 1].strided!3(2)[] = even.value[];
            y[0 .. $, 0 .. $, 0 .. $, 1 .. $].strided!3(2)[] = odd.value[];
        }

        static if (useGrad1 || useGrad2)
        {
            static if (useGrad1)
                even.usedCount++;
            static if (useGrad2)
                odd.usedCount++;

            return new Tensor!(T, ReturnShape)(y, (grads) {
                static if (useGrad1)
                {
                    even.backward((ref evenGrads) {
                        static if (axis == 2)
                            evenGrads[] = grads[0 .. $, 0 .. $, 0 .. $ - 1, 0 .. $].strided!2(2);
                        else
                            evenGrads[] = grads[0 .. $, 0 .. $, 0 .. $, 0 .. $ - 1].strided!3(2);
                    });
                }
                static if (useGrad2)
                {
                    odd.backward((ref oddGrads) {
                        static if (axis == 2)
                            oddGrads[] = grads[0 .. $, 0 .. $, 1 .. $, 0 .. $].strided!2(2);
                        else
                            oddGrads[] = grads[0 .. $, 0 .. $, 0 .. $, 1 .. $].strided!3(2);
                    });
                }
            });
        }
        else
        {
            return new Tensor!(T, ReturnShape, UseGradient.no)(y);
        }
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto s = splitEvenOdd2D(x);
        auto m = mergeEvenOdd2D(s.expand);

        assert(x.value == m.value);

        m.backward();
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto s = splitEvenOdd2D!3(x);
        auto m = mergeEvenOdd2D!3(s.expand);

        assert(x.value == m.value);

        m.backward();
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto s = splitEvenOdd2D!2(x);
        auto m = mergeEvenOdd2D!2(s.expand);

        assert(x.value == m.value);

        static assert(!canBackward!(typeof(m)));
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto s = splitEvenOdd2D!3(x);
        auto m = mergeEvenOdd2D!3(s.expand);

        assert(x.value == m.value);

        static assert(!canBackward!(typeof(m)));
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0, 1, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto m = mergeEvenOdd2D(x, y);

        assert(m.shape == [1, 1, 4, 2]);
        assert(m.value[0, 0, 0, 0] == 1.0);
        assert(m.value[0, 0, 0, 1] == 2.0);
        assert(m.value[0, 0, 1, 0] == 1.0);
        assert(m.value[0, 0, 1, 1] == 2.0);
        assert(m.value[0, 0, 2, 0] == 3.0);
        assert(m.value[0, 0, 2, 1] == 4.0);
        assert(m.value[0, 0, 3, 0] == 3.0);
        assert(m.value[0, 0, 3, 1] == 4.0);

        m.backward();
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0, 1, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0]);
        auto m = mergeEvenOdd2D!3(x, y);

        assert(m.shape == [1, 1, 2, 4]);
        assert(m.value[0, 0, 0, 0] == 1.0);
        assert(m.value[0, 0, 0, 1] == 1.0);
        assert(m.value[0, 0, 0, 2] == 2.0);
        assert(m.value[0, 0, 0, 3] == 2.0);
        assert(m.value[0, 0, 1, 0] == 3.0);
        assert(m.value[0, 0, 1, 1] == 3.0);
        assert(m.value[0, 0, 1, 2] == 4.0);
        assert(m.value[0, 0, 1, 3] == 4.0);
    }
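
    unittest
    {
        // A minimal extra sketch: the merged gradient is routed back to the even
        // and odd inputs through the same strided views used in the forward pass.
        auto even = tensor!([0, 1, 1, 1])([1.0]);
        auto odd = tensor!([0, 1, 1, 1])([2.0]);
        auto m = mergeEvenOdd2D(even, odd);

        assert(m.shape == [1, 1, 2, 1]);
        assert(m.value[0, 0, 0, 0] == 1.0);
        assert(m.value[0, 0, 1, 0] == 2.0);

        m.backward();

        assert(even.grads[0, 0, 0, 0] == 1.0);
        assert(odd.grads[0, 0, 0, 0] == 1.0);
    }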
}

version (all) // concat2D
{
    auto concat2D(size_t axis = 1, T, U)(T x, U y)
    if (isTensor!T && isTensor!U)
    {
        static assert(axis == 1, "not implemented");

        enum S1 = T.staticShape;
        enum S2 = U.staticShape;
        static assert(S1.length == 4);
        static assert(S2.length == 4);
        static assert(S1[2 .. 4] == S2[2 .. 4]);
        assert(x.shape[0] == y.shape[0]);

        static if (is(T : Tensor!(E, T.staticShape), E))
        {
            alias ElementType = E;
        }
        else static if (is(T : Tensor!(E, T.staticShape, UseGradient.no), E))
        {
            alias ElementType = E;
        }
        else
        {
            static assert(false);
        }

        enum size_t[4] ReturnShape = [S1[0], S1[1] + S2[1], S1[2], S1[3]];

        auto z = uninitSlice!ElementType(x.shape[0], S1[1] + S2[1], S1[2], S1[3]);
        z[0 .. $, 0 .. S1[1], 0 .. $, 0 .. $] = x.value;
        z[0 .. $, S1[1] .. $, 0 .. $, 0 .. $] = y.value;

        static if (canBackward!T || canBackward!U)
        {
            static if (canBackward!T)
                x.usedCount++;
            static if (canBackward!U)
                y.usedCount++;

            return new Tensor!(ElementType, ReturnShape)(z, (grads) {
                static if (canBackward!T)
                {
                    x.backward((ref xGrads) {
                        xGrads[] += grads[0 .. $, 0 .. S1[1], 0 .. $, 0 .. $];
                    });
                }
                static if (canBackward!U)
                {
                    y.backward((ref yGrads) {
                        yGrads[] += grads[0 .. $, S1[1] .. $, 0 .. $, 0 .. $];
                    });
                }
            });
        }
        else
        {
            return new Tensor!(ElementType, ReturnShape, UseGradient.no)(z);
        }
    }

    unittest
    {
        auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto y = tensor!([0, 2, 2, 2])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        auto z = concat2D(x, y);

        assert(z.shape == [1, 3, 2, 2]);
        assert(z.value[0, 0, 0, 0] == 1.0);
        assert(z.value[0, 0, 0, 1] == 2.0);
        assert(z.value[0, 0, 1, 0] == 3.0);
        assert(z.value[0, 0, 1, 1] == 4.0);
        assert(z.value[0, 1, 0, 0] == 1.0);
        assert(z.value[0, 1, 0, 1] == 2.0);
        assert(z.value[0, 1, 1, 0] == 3.0);
        assert(z.value[0, 1, 1, 1] == 4.0);
        assert(z.value[0, 2, 0, 0] == 5.0);
        assert(z.value[0, 2, 0, 1] == 6.0);
        assert(z.value[0, 2, 1, 0] == 7.0);
        assert(z.value[0, 2, 1, 1] == 8.0);

        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 1, 3, 1], UseGradient.no)([1.0, 2.0, 3.0]);
        auto y = tensor!([0, 1, 3, 1])([1.0, 2.0, 3.0]);
        auto z = concat2D(x, y);

        assert(z.shape == [1, 2, 3, 1]);
        assert(z.value[0, 0, 0, 0] == 1.0);
        assert(z.value[0, 0, 1, 0] == 2.0);
        assert(z.value[0, 0, 2, 0] == 3.0);
        assert(z.value[0, 1, 0, 0] == 1.0);
        assert(z.value[0, 1, 1, 0] == 2.0);
        assert(z.value[0, 1, 2, 0] == 3.0);

        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 1, 3, 1])([1.0, 2.0, 3.0]);
        auto y = tensor!([0, 1, 3, 1], UseGradient.no)([1.0, 2.0, 3.0]);
        auto z = concat2D(x, y);

        assert(z.shape == [1, 2, 3, 1]);
        assert(z.value[0, 0, 0, 0] == 1.0);
        assert(z.value[0, 0, 1, 0] == 2.0);
        assert(z.value[0, 0, 2, 0] == 3.0);
        assert(z.value[0, 1, 0, 0] == 1.0);
        assert(z.value[0, 1, 1, 0] == 2.0);
        assert(z.value[0, 1, 2, 0] == 3.0);

        z.backward();
    }

    unittest
    {
        auto x = tensor!([0, 1, 3, 1], UseGradient.no)([1.0, 2.0, 3.0]);
        auto y = tensor!([0, 1, 3, 1], UseGradient.no)([1.0, 2.0, 3.0]);
        auto z = concat2D(x, y);

        assert(z.shape == [1, 2, 3, 1]);
        assert(z.value[0, 0, 0, 0] == 1.0);
        assert(z.value[0, 0, 1, 0] == 2.0);
        assert(z.value[0, 0, 2, 0] == 3.0);
        assert(z.value[0, 1, 0, 0] == 1.0);
        assert(z.value[0, 1, 1, 0] == 2.0);
        assert(z.value[0, 1, 2, 0] == 3.0);

        static assert(!canBackward!(typeof(z)));
    }
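
    unittest
    {
        // A minimal extra sketch: the gradient of a concatenation is split back
        // along the channel axis, one slice per input.
        auto x = tensor!([0, 1, 1, 1])([1.0]);
        auto y = tensor!([0, 1, 1, 1])([2.0]);
        auto z = concat2D(x, y);

        assert(z.shape == [1, 2, 1, 1]);

        z.backward();

        assert(x.grads[0, 0, 0, 0] == 1.0);
        assert(y.grads[0, 0, 0, 0] == 1.0);
    }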
}

version (all) // projection1D
{
    auto projection1D(size_t axis, T, size_t[] ShapeW, UseGradient useGradW, size_t[] ShapeX, UseGradient useGradX)(Tensor!(T, ShapeX, useGradX) x, Tensor!(T, ShapeW, useGradW) w)
    if ((axis == 2 || axis == 3) && ShapeX.length == 4 && ShapeW.length == 2 && ShapeX[axis] == ShapeW[0])
    {
        enum H = axis == 2 ? ShapeW[1] : ShapeX[2];
        enum W = axis == 3 ? ShapeW[1] : ShapeX[3];
        auto y = uninitSlice!T(x.shape[0], x.shape[1], H, W);

        import mir.blas : gemm;

        auto tx = x.value.ipack!2.flattened;
        auto ty = y.ipack!2.flattened;
        static if (axis == 2)
        {
            auto tw = w.value.transposed;
            foreach (t; zip(tx, ty))
            {
                gemm(T(1), tw, t[0], T(0), t[1]);
            }
        }
        else static if (axis == 3)
        {
            foreach (t; zip(tx, ty))
            {
                gemm(T(1), t[0], w.value, T(0), t[1]);
            }
        }

        enum size_t[4] ReturnShape = [ShapeX[0], ShapeX[1], H, W];
        static if (useGradW || useGradX)
        {
            static if (useGradW)
                w.usedCount++;
            static if (useGradX)
                x.usedCount++;

            return new Tensor!(T, ReturnShape)(y, (grads) {
                static if (useGradW)
                {
                    w.backward((ref wGrads) {
                        auto tx = x.value.ipack!2.flattened;
                        auto tg = grads.ipack!2.flattened;
                        foreach (t; zip(tx, tg))
                        {
                            static if (axis == 2)
                            {
                                gemm(T(1), t[0], t[1].transposed, T(1), wGrads);
                            }
                            else static if (axis == 3)
                            {
                                gemm(T(1), t[0].transposed, t[1], T(1), wGrads);
                            }
                        }
                    });
                }
                static if (useGradX)
                {
                    x.backward((ref xGrads) {
                        auto txg = xGrads.ipack!2.flattened;
                        auto tg = grads.ipack!2.flattened;
                        foreach (t; zip(txg, tg))
                        {
                            static if (axis == 2)
                            {
                                gemm(T(1), w.value, t[1], T(1), t[0]);
                            }
                            else static if (axis == 3)
                            {
                                gemm(T(1), w.value, t[1].transposed, T(1), t[0].transposed);
                            }
                        }
                    });
                }
            });
        }
        else
        {
            return new Tensor!(T, ReturnShape, UseGradient.no)(y);
        }
    }

    unittest
    {
        auto x = tensor!([1, 1, 2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

        auto y = projection1D!2(x, w);
        assert(y.shape == [1, 1, 3, 2]);
        assert(y.value[0, 0, 0, 0] == 13);
        assert(y.value[0, 0, 0, 1] == 18);
        assert(y.value[0, 0, 1, 0] == 17);
        assert(y.value[0, 0, 1, 1] == 24);
        assert(y.value[0, 0, 2, 0] == 21);
        assert(y.value[0, 0, 2, 1] == 30);

        y.backward();

        assert(x.grads[0, 0, 0, 0] == 6);
        assert(x.grads[0, 0, 0, 1] == 6);
        assert(x.grads[0, 0, 1, 0] == 15);
        assert(x.grads[0, 0, 1, 1] == 15);

        assert(w.grads[0, 0] == 3);
        assert(w.grads[0, 1] == 3);
        assert(w.grads[0, 2] == 3);
        assert(w.grads[1, 0] == 7);
        assert(w.grads[1, 1] == 7);
        assert(w.grads[1, 2] == 7);
    }

    unittest
    {
        auto x = tensor!([1, 1, 2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

        auto y = projection1D!2(x, w);
        assert(y.shape == [1, 1, 3, 3]);
        assert(y.value[0, 0, 0, 0] == 17);
        assert(y.value[0, 0, 0, 1] == 22);
        assert(y.value[0, 0, 0, 2] == 27);
        assert(y.value[0, 0, 1, 0] == 22);
        assert(y.value[0, 0, 1, 1] == 29);
        assert(y.value[0, 0, 1, 2] == 36);
        assert(y.value[0, 0, 2, 0] == 27);
        assert(y.value[0, 0, 2, 1] == 36);
        assert(y.value[0, 0, 2, 2] == 45);
    }

    unittest
    {
        auto x = tensor!([1, 1, 2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

        auto y = projection1D!3(x, w);
        assert(y.shape == [1, 1, 2, 3]);
        assert(y.value[0, 0, 0, 0] == 9);
        assert(y.value[0, 0, 0, 1] == 12);
        assert(y.value[0, 0, 0, 2] == 15);
        assert(y.value[0, 0, 1, 0] == 19);
        assert(y.value[0, 0, 1, 1] == 26);
        assert(y.value[0, 0, 1, 2] == 33);

        y.backward();

        assert(x.grads[0, 0, 0, 0] == 6);
        assert(x.grads[0, 0, 0, 1] == 15);
        assert(x.grads[0, 0, 1, 0] == 6);
        assert(x.grads[0, 0, 1, 1] == 15);

        assert(w.grads[0, 0] == 4);
        assert(w.grads[0, 1] == 4);
        assert(w.grads[0, 2] == 4);
        assert(w.grads[1, 0] == 6);
        assert(w.grads[1, 1] == 6);
        assert(w.grads[1, 2] == 6);
    }

    unittest
    {
        auto x = tensor!([1, 1, 3, 2])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
        auto w = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);

        auto y = projection1D!3(x, w);
        assert(y.shape == [1, 1, 3, 3]);
        assert(y.value[0, 0, 0, 0] == 9);
        assert(y.value[0, 0, 0, 1] == 12);
        assert(y.value[0, 0, 0, 2] == 15);
        assert(y.value[0, 0, 1, 0] == 19);
        assert(y.value[0, 0, 1, 1] == 26);
        assert(y.value[0, 0, 1, 2] == 33);
        assert(y.value[0, 0, 2, 0] == 29);
        assert(y.value[0, 0, 2, 1] == 40);
        assert(y.value[0, 0, 2, 2] == 51);
    }
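
    unittest
    {
        // A minimal worked sketch of the axis semantics: for axis == 2 each
        // [H, W] plane is multiplied on the left by w.transposed, so with
        // w = [[1, 2], [3, 4]] and the column [1, 2] the result is [7, 10].
        auto x = tensor!([1, 1, 2, 1])([1.0f, 2.0f]);
        auto w = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

        auto y = projection1D!2(x, w);
        assert(y.shape == [1, 1, 2, 1]);
        assert(y.value[0, 0, 0, 0] == 7);  // 1*1 + 3*2
        assert(y.value[0, 0, 1, 0] == 10); // 2*1 + 4*2
    }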
}

version (all) // conv2D
{
    size_t[] conv2DShape(size_t[] Shape, size_t channel_out, size_t[] kernel_size, size_t[] padding, size_t[] stride, size_t[] dilation)
    {
        import std.math : floor;

        const H_in = Shape[2];
        const W_in = Shape[3];

        const H_out = cast(size_t) floor(real(H_in + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1);
        const W_out = cast(size_t) floor(real(W_in + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1);

        return [Shape[0], channel_out, H_out, W_out];
    }

    unittest
    {
        auto shape = conv2DShape([0, 1, 28, 28], 4, [3, 3], [0, 0], [1, 1], [1, 1]);
        assert(shape == [0, 4, 26, 26]);
    }
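
    unittest
    {
        // Worked examples of the size formula
        // H_out = floor((H_in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1):
        // a 5x5 kernel with padding 2 preserves the spatial size,
        assert(conv2DShape([0, 3, 32, 32], 8, [5, 5], [2, 2], [1, 1], [1, 1]) == [0, 8, 32, 32]);
        // and stride 2 (supported by this helper, though not yet by conv2D itself)
        // reduces 28 to floor(27 / 2 + 1) = 14.
        assert(conv2DShape([0, 1, 28, 28], 4, [3, 3], [1, 1], [2, 2], [1, 1]) == [0, 4, 14, 14]);
    }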

    auto conv2D(
        size_t[] padding = [0, 0],
        size_t[] stride = [1, 1],
        size_t[] dilation = [1, 1],
        T,
        size_t[] ShapeX, UseGradient useGradX,
        size_t[] ShapeW, UseGradient useGradW,
        size_t[] ShapeB, UseGradient useGradB
    )(Tensor!(T, ShapeX, useGradX) x, Tensor!(T, ShapeW, useGradW) weights, Tensor!(T, ShapeB, useGradB) bias)
    {
        static assert(padding.length == 2);
        static assert(stride.length == 2);
        static assert(dilation.length == 2);
        static assert(stride == [1, 1], "conv2D : stride is not implemented");
        static assert(dilation == [1, 1], "conv2D : dilation is not implemented");

        static assert(ShapeX.length == 4);
        static assert(ShapeW.length == 4);
        static assert(ShapeB.length == 1);
        static assert(ShapeX[1] == ShapeW[1]);
        static assert(ShapeW[0] == ShapeB[0]);

        enum ReturnShape = conv2DShape(ShapeX, ShapeB[0], ShapeW[2 .. 4], padding, stride, dilation);

        enum C = ShapeX[1];
        enum C_out = ReturnShape[1];
        enum TempH = ShapeX[2] + 2 * padding[0];
        enum TempW = ShapeX[3] + 2 * padding[1];
        enum usePadding = padding[0] != 0 || padding[1] != 0;
        static if (usePadding)
        {
            auto temp = slice!T([C, TempH, TempW], 0);
        }
        auto y = uninitSlice!T(x.shape[0], ReturnShape[1], ReturnShape[2], ReturnShape[3]);

        // prepare im2col: each output pixel becomes one row of v, holding its
        // flattened receptive field plus a trailing 1 for the bias term; w packs
        // one flattened kernel per output channel with its bias appended, so the
        // whole convolution reduces to a single gemm per batch element.
        int err;
        auto ty = y.reshape([x.shape[0], ReturnShape[1], ReturnShape[2] * ReturnShape[3]], err);
        assert(err == 0);
        auto v = uninitSlice!T(ReturnShape[2] * ReturnShape[3], C * ShapeW[2] * ShapeW[3] + 1);
        v[0 .. $, $ - 1 .. $] = 1;
        auto w = uninitSlice!T(C_out, C * ShapeW[2] * ShapeW[3] + 1);
        foreach (i; 0 .. C_out)
        {
            w[i].flattened[0 .. $ - 1] = weights.value[i].flattened;
            w[i].back = bias.value[i];
        }

        foreach (i; 0 .. x.shape[0])
        {
            static if (usePadding)
            {
                temp[0 .. $, padding[0] .. $ - padding[0], padding[1] .. $ - padding[1]] = x.value[i];
                auto wins = temp.windows(C, ShapeW[2], ShapeW[3]);
            }
            else
            {
                auto wins = x.value[i].windows(C, ShapeW[2], ShapeW[3]);
            }
            foreach (t; zip(v.ipack!1, wins.flattened))
            {
                t[0].flattened[0 .. $ - 1] = t[1].flattened;
            }

            import mir.blas : gemm;

            gemm(T(1), v, w.transposed, T(0), ty[i].transposed);
        }

        static if (useGradX || useGradW || useGradB)
        {
            static if (useGradX)
                x.usedCount++;
            static if (useGradW)
                weights.usedCount++;
            static if (useGradB)
                bias.usedCount++;

            return new Tensor!(T, ReturnShape)(y, (grads) {
                static if (useGradX)
                {
                    x.backward((ref xGrads) {
                        foreach (i; 0 .. grads.shape[0])
                        {
                            static if (usePadding)
                            {
                                temp.flattened[] = 0;
                                auto wins = temp.windows(ShapeX[1], ShapeW[2], ShapeW[3]);
                            }
                            else
                            {
                                auto wins = xGrads[i].windows(ShapeX[1], ShapeW[2], ShapeW[3]);
                            }
                            foreach (h; 0 .. ReturnShape[2])
                            {
                                foreach (w; 0 .. ReturnShape[3])
                                {
                                    auto tw = wins[0, h, w];
                                    auto tg = grads.transposed!(0, 2, 3, 1)[i, h, w];
                                    foreach (c; 0 .. ReturnShape[1])
                                    {
                                        tw[] += weights.value[c] * tg[c];
                                    }
                                }
                            }
                            static if (usePadding)
                            {
                                xGrads[i][] += temp[0 .. $, padding[0] .. $ - padding[0], padding[1] .. $ - padding[1]];
                            }
                        }
                    });
                }
                static if (useGradW)
                {
                    weights.backward((ref wGrads) {
                        static if (usePadding)
                        {
                            temp.flattened[] = 0;
                        }
                        foreach (i; 0 .. grads.shape[0])
                        {
                            static if (usePadding)
                            {
                                temp[0 .. $, padding[0] .. $ - padding[0], padding[1] .. $ - padding[1]] = x.value[i];
                                auto wins = temp.windows(C, ShapeW[2], ShapeW[3]);
                            }
                            else
                            {
                                auto wins = x.value[i].windows(C, ShapeW[2], ShapeW[3]);
                            }
                            foreach (h; 0 .. ReturnShape[2])
                            {
                                foreach (w; 0 .. ReturnShape[3])
                                {
                                    auto tw = wins[0, h, w];
                                    auto tg = grads.transposed!(0, 2, 3)[i, h, w];
                                    foreach (c; 0 .. ShapeW[0])
                                    {
                                        wGrads[c][] += tg[c] * tw;
                                    }
                                }
                            }
                        }
                    });
                }
                static if (useGradB)
                {
                    bias.backward((ref bGrads) {
                        import mir.math.sum : sum;

                        bGrads[] += grads.transposed!1.ipack!1.map!sum;
                    });
                }
            });
        }
        else
        {
            return new Tensor!(T, ReturnShape, UseGradient.no)(y);
        }
    }

    unittest
    {
        // dfmt off
        auto images = tensor!([0, 1, 5, 5])([
             1.0,  2.0,  3.0,  4.0,  5.0,
             6.0,  7.0,  8.0,  9.0, 10.0,
            11.0, 12.0, 13.0, 14.0, 15.0,
            16.0, 17.0, 18.0, 19.0, 20.0,
            21.0, 22.0, 23.0, 24.0, 25.0]);
        // dfmt on

        // dfmt off
        auto weights = tensor!([2, 1, 3, 3])([
             1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,
            10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0
            ]);
        // dfmt on

        auto bias = tensor!([2])([1.0, 2.0]);

        auto y = conv2D(images, weights, bias);

        assert(y.shape == [1, 2, 3, 3]);
        assert(y.value[0, 0, 0, 0] == 412);
        assert(y.value[0, 0, 0, 1] == 457);
        assert(y.value[0, 0, 0, 2] == 502);
        assert(y.value[0, 0, 1, 0] == 637);
        assert(y.value[0, 0, 1, 1] == 682);
        assert(y.value[0, 0, 1, 2] == 727);
        assert(y.value[0, 0, 2, 0] == 862);
        assert(y.value[0, 0, 2, 1] == 907);
        assert(y.value[0, 0, 2, 2] == 952);

        assert(y.value[0, 1, 0, 0] == 980);
        assert(y.value[0, 1, 0, 1] == 1106);
        assert(y.value[0, 1, 0, 2] == 1232);
        assert(y.value[0, 1, 1, 0] == 1610);
        assert(y.value[0, 1, 1, 1] == 1736);
        assert(y.value[0, 1, 1, 2] == 1862);
        assert(y.value[0, 1, 2, 0] == 2240);
        assert(y.value[0, 1, 2, 1] == 2366);
        assert(y.value[0, 1, 2, 2] == 2492);

        y.backward();

        assert(images.grads[0, 0, 0, 0] == 11);
        assert(images.grads[0, 0, 0, 1] == 24);
        assert(images.grads[0, 0, 0, 2] == 39);
        assert(images.grads[0, 0, 0, 3] == 28);
        assert(images.grads[0, 0, 0, 4] == 15);
        assert(images.grads[0, 0, 1, 0] == 28);
        assert(images.grads[0, 0, 1, 1] == 60);
        assert(images.grads[0, 0, 1, 2] == 96);
        assert(images.grads[0, 0, 1, 3] == 68);
        assert(images.grads[0, 0, 1, 4] == 36);
        assert(images.grads[0, 0, 2, 0] == 51);
        assert(images.grads[0, 0, 2, 1] == 108);
        assert(images.grads[0, 0, 2, 2] == 171);
        assert(images.grads[0, 0, 2, 3] == 120);
        assert(images.grads[0, 0, 2, 4] == 63);
        assert(images.grads[0, 0, 3, 0] == 40);
        assert(images.grads[0, 0, 3, 1] == 84);
        assert(images.grads[0, 0, 3, 2] == 132);
        assert(images.grads[0, 0, 3, 3] == 92);
        assert(images.grads[0, 0, 3, 4] == 48);
        assert(images.grads[0, 0, 4, 0] == 23);
        assert(images.grads[0, 0, 4, 1] == 48);
        assert(images.grads[0, 0, 4, 2] == 75);
        assert(images.grads[0, 0, 4, 3] == 52);
        assert(images.grads[0, 0, 4, 4] == 27);

        assert(weights.grads[0, 0, 0, 0] == 63);
        assert(weights.grads[0, 0, 0, 1] == 72);
        assert(weights.grads[0, 0, 0, 2] == 81);
        assert(weights.grads[0, 0, 1, 0] == 108);
        assert(weights.grads[0, 0, 1, 1] == 117);
        assert(weights.grads[0, 0, 1, 2] == 126);
        assert(weights.grads[0, 0, 2, 0] == 153);
        assert(weights.grads[0, 0, 2, 1] == 162);
        assert(weights.grads[0, 0, 2, 2] == 171);
        assert(weights.grads[1, 0, 0, 0] == 63);
        assert(weights.grads[1, 0, 0, 1] == 72);
        assert(weights.grads[1, 0, 0, 2] == 81);
        assert(weights.grads[1, 0, 1, 0] == 108);
        assert(weights.grads[1, 0, 1, 1] == 117);
        assert(weights.grads[1, 0, 1, 2] == 126);
        assert(weights.grads[1, 0, 2, 0] == 153);
        assert(weights.grads[1, 0, 2, 1] == 162);
        assert(weights.grads[1, 0, 2, 2] == 171);

        assert(bias.grads[0] == 9);
        assert(bias.grads[1] == 9);
    }

    unittest
    {
        // dfmt off
        auto x = tensor!([2, 1, 3, 3])([
            -1.0,  0.0,  1.0,
             0.0,  1.0,  0.0,
             1.0,  0.0, -1.0,
             1.0, -1.0, -0.5,
            -1.0,  1.0, -1.0,
            -0.5, -1.0,  1.0,
        ]);

        auto w = tensor!([1, 1, 3, 3])([
            -0.5,  -0.5,  0.75,
            -0.5,   1.0, -0.5,
            0.75, -0.5, -0.5,
        ]);

        auto b = tensor!([1])([0.0]);
        // dfmt on

        auto y = conv2D!([1, 1])(x, w, b);

        assert(y.shape == [2, 1, 3, 3]);

        assert(y.value[0, 0, 0, 0] == -1.5);
        assert(y.value[0, 0, 0, 1] == -0.5);
        assert(y.value[0, 0, 0, 2] == 1.75);
        assert(y.value[0, 0, 1, 0] == -0.5);
        assert(y.value[0, 0, 1, 1] == 3.5);
        assert(y.value[0, 0, 1, 2] == -0.5);
        assert(y.value[0, 0, 2, 0] == 1.75);
        assert(y.value[0, 0, 2, 1] == -0.5);
        assert(y.value[0, 0, 2, 2] == -1.5);

        assert(y.value[1, 0, 0, 0] == 1.5);
        assert(y.value[1, 0, 0, 1] == -2);
        assert(y.value[1, 0, 0, 2] == 1.25);
        assert(y.value[1, 0, 1, 0] == -2);
        assert(y.value[1, 0, 1, 1] == 1.25);
        assert(y.value[1, 0, 1, 2] == -2);
        assert(y.value[1, 0, 2, 0] == 1.25);
        assert(y.value[1, 0, 2, 1] == -2);
        assert(y.value[1, 0, 2, 2] == 1.5);

        y.backward();
    }
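
    unittest
    {
        // A minimal extra sketch: a 1x1 kernel reduces conv2D to the per-pixel
        // affine map y = w * x + b, which also makes the gradients easy to verify.
        auto x = tensor!([1, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);
        auto w = tensor!([1, 1, 1, 1])([2.0]);
        auto b = tensor!([1])([1.0]);

        auto y = conv2D(x, w, b);

        assert(y.shape == [1, 1, 2, 2]);
        assert(y.value.flattened == [3.0, 5.0, 7.0, 9.0]);

        y.backward();

        assert(x.grads.flattened == [2.0, 2.0, 2.0, 2.0]);
        assert(w.grads[0, 0, 0, 0] == 1.0 + 2.0 + 3.0 + 4.0);
        assert(b.grads[0] == 4.0);
    }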
}