1 module golem.tensor;
2 
3 import golem.util;
4 
5 import mir.ndslice;
6 static import numir;
7 
8 import std.typecons : Flag, Yes, No;
9 
10 alias UseGradient = Flag!"gradient";
11 
/// Evaluates to `true` when `T` is a `Tensor` with the given element type,
/// static shape, and gradient flag.
enum isTensor(T, U, size_t[] Shape, UseGradient hasGradient) = is(T == Tensor!(U, Shape, hasGradient));

/// Evaluates to `true` when `T` is any `Tensor` instantiation.
enum isTensor(T) = is(T == Tensor!(U, Shape, flag), U, size_t[] Shape, UseGradient flag);

/// Evaluates to `true` when `T` is a `Tensor` with the given gradient flag
/// (any element type and shape).
enum isTensor(T, UseGradient hasGradient) = is(T == Tensor!(U, Shape, hasGradient), U, size_t[] Shape);

/// Evaluates to `true` when `T` is a `Tensor` with the given element type and
/// static shape, regardless of whether it tracks gradients.
enum isTensor(T, U, size_t[] Shape) = isTensor!(T, U, Shape, Yes.gradient) || isTensor!(T, U, Shape, No.gradient);

/// Evaluates to `true` when `T` is a gradient-tracking `Tensor`.
enum canBackward(T) = isTensor!(T, Yes.gradient);
21 
/// Checks whether two static shapes are compatible for elementwise operations.
///
/// Shapes are compatible when they have the same rank and all dimensions
/// match, except that a leading dimension of 0 on either side denotes a
/// dynamic batch dimension and matches any batch size.
///
/// Params:
///     lhsShape = static shape of the left operand (non-empty)
///     rhsShape = static shape of the right operand (non-empty)
/// Returns: `true` when the shapes are compatible.
bool testCompatibleStaticShape(size_t[] lhsShape, size_t[] rhsShape) @safe @nogc pure nothrow
{
    assert(lhsShape.length > 0);
    assert(rhsShape.length > 0);

    // Differing ranks can never match.
    if (lhsShape.length != rhsShape.length)
        return false;

    // Skip the leading (batch) dimension when either side marks it dynamic.
    const size_t firstFixedDim = (lhsShape[0] == 0 || rhsShape[0] == 0) ? 1 : 0;

    foreach (dim; firstFixedDim .. lhsShape.length)
    {
        if (lhsShape[dim] != rhsShape[dim])
            return false;
    }
    return true;
}
40 
// Compile-time checks for testCompatibleStaticShape: exact matches succeed,
// a dynamic (0) batch dimension matches any batch size, and differing ranks
// or differing non-batch dimensions are rejected.
unittest
{
    enum testShape = testCompatibleStaticShape([2, 2], [2, 2]);
    static assert(testShape);

    static assert(testCompatibleStaticShape([2, 3], [2, 3]));
    static assert(testCompatibleStaticShape([2, 3], [0, 3]));
    static assert(testCompatibleStaticShape([0, 3], [0, 3]));
    static assert(testCompatibleStaticShape([0, 3], [2, 3]));
    static assert(testCompatibleStaticShape([2, 3, 28, 28], [2, 3, 28, 28]));

    static assert(!testCompatibleStaticShape([2, 3], [2, 3, 3]));
    static assert(!testCompatibleStaticShape([2, 3, 3], [2, 3]));
    static assert(!testCompatibleStaticShape([1, 3], [2, 3]));
    static assert(!testCompatibleStaticShape([2, 4], [2, 3]));
}
57 
58 
/// Resolves the gradient flag of the result of a binary operation:
/// the result tracks gradients when either operand does.
template commonGradientType(T1, T2)
    if (isTensor!T1 && isTensor!T2)
{
    enum UseGradient commonGradientType = (canBackward!T1 || canBackward!T2)
        ? UseGradient.yes
        : UseGradient.no;
}
71 
/// A fixed-rank numeric array with optional reverse-mode automatic
/// differentiation.
///
/// Params:
///     T = element type
///     Shape = static shape; `Shape[0] == 0` marks a dynamic (batch) leading
///             dimension whose extent is determined at runtime
///     hasGradient = when `Yes.gradient`, the tensor records a gradient slice
///                   and a backward function for backpropagation
class Tensor(T, size_t[] Shape, UseGradient hasGradient = UseGradient.yes)
{
    alias ElementType = T;

    /// Compile-time shape (a leading 0 means "any batch size").
    enum staticShape = Shape;

    static if (Shape[0] != 0)
    {
        alias shape = staticShape;
    }
    else
    {
        alias shape = runtimeShape;
    }

    /// Storage type: a mir slice of rank `Shape.length`.
    alias Value = Slice!(T*, Shape.length);

    Value value;
    static if (hasGradient)
    {
        Value grads;

        bool requireGrad = true;
        // Consumer accounting for the backward pass: `usedCount` is how many
        // downstream operations consumed this tensor during forward,
        // `backwardCount` how many of them have pushed gradients back.
        // `backwardFn` (the producer's backward step) fires only once every
        // consumer has reported, so gradients are fully accumulated first.
        size_t usedCount;
        size_t backwardCount;
        void delegate(Value grads) backwardFn;
    }

    static if (Shape[0] != 0)
    {
        /// Fill-constructs a statically shaped tensor with `init`.
        this(T init)
        {
            this(slice!T(Shape, init));
        }
    }

    /// Constructs from a nested (range-of-ranges) literal.
    this(RoR)(RoR data)
    {
        this(fuse(data));
    }

    /// Constructs from a flat array. For a dynamic batch dimension the batch
    /// size is derived from `data.length` and the element count must divide
    /// evenly into the remaining dimensions.
    this(T[] data)
    {
        static if (Shape[0] == 0)
        {
            static if (Shape.length == 1)
            {
                const batchSize = data.length;
                auto value = data.sliced(batchSize);
            }
            else
            {
                const batchSize = data.length / elementSize(Shape[1 .. $]);
                import std.format : format;
                assert(batchSize * elementSize(Shape[1 .. $]) == data.length, format!"The number of elements in the data must match the shape of the tensor. Shape = %s, length=%s)"(Shape, data.length));
                auto value = data.sliced([
                        batchSize, expandShape!(Shape[1 .. $])
                        ]);
            }
        }
        else
        {
            auto value = data.sliced(Shape);
        }

        this(value);
    }

    static if (hasGradient)
    {
        /// Wraps an existing slice; gradients start at zero, no backward fn.
        this(Value value)
        {
            this(value, null);
        }

        /// Wraps an existing slice with a backward function; gradients start
        /// at zero.
        this(Value value, void delegate(Value grad) gradFn)
        {
            this(value, numir.zeros_like(value), gradFn);
        }

        private this(Value value, Value grads, void delegate(Value grad) gradFn)
        {
            this.value = value;
            this.grads = grads;
            this.backwardFn = gradFn;
        }
    }
    else
    {
        this(Value value)
        {
            this.value = value;
        }
    }

    /// Shape observed at runtime (equals `staticShape` except for a dynamic
    /// batch dimension).
    size_t[Shape.length] runtimeShape() const pure nothrow @safe @nogc
    {
        return this.value.shape;
    }

    static if (hasGradient)
    {
        /// Clears accumulated gradients to zero.
        void resetGrads()
        {
            grads[] = T(0);
        }

        /// Accumulates gradients in place via `update`; once all consumers
        /// have reported, propagates to the producing operation and resets
        /// the counters.
        void backward()(void delegate(ref Value grads) update)
        {
            if (requireGrad)
            {
                update(this.grads);
                ++backwardCount;

                if (backwardCount == usedCount)
                {
                    if (this.backwardFn)
                        this.backwardFn(this.grads);
                    this.usedCount = 0;
                    this.backwardCount = 0;
                }
            }
        }

        /// Adds externally supplied gradients (shapes must match) and
        /// propagates once all consumers have reported.
        void backward(U)(U grads)
        {
            if (requireGrad)
            {
                import std.format : format;

                assert(this.grads.shape == grads.shape,
                        "%s != %s".format(this.grads.shape, grads.shape));
                this.grads[] += grads;
                ++backwardCount;

                if (backwardCount == usedCount)
                {
                    if (this.backwardFn)
                        this.backwardFn(this.grads);
                    this.usedCount = 0;
                    this.backwardCount = 0;
                }
            }
        }

        /// Starts backpropagation from this tensor with a gradient of ones
        /// (typically called on the loss).
        void backward()
        {
            if (requireGrad)
            {
                this.grads[] = T(1);
                if (this.backwardFn)
                    this.backwardFn(this.grads);
                this.usedCount = 0;
                this.backwardCount = 0;
            }
        }
    }

    /// Elementwise addition of two shape-compatible tensors.
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "+", RTensor)(RTensor rhs)
        if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        auto y = slice(this.value + rhs.value);

        static if (canBackward!(typeof(this))) this.usedCount++;
        static if (canBackward!(typeof(rhs))) rhs.usedCount++;

        static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
        {
            return new Tensor!(T, Shape)(y, (Value grads) {
                static if (canBackward!(typeof(this))) this.backward(grads);
                static if (canBackward!(typeof(rhs))) rhs.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise addition of a scalar.
    Tensor!(T, Shape, hasGradient) opBinary(string op : "+")(T rhs)
    {
        auto y = slice(this.value[] + rhs);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise addition with the scalar on the left (`lhs + tensor`).
    Tensor!(T, Shape, hasGradient) opBinaryRight(string op : "+")(T lhs)
    {
        auto y = slice(lhs + this.value[]);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise subtraction of two shape-compatible tensors.
    /// (The `isTensor` constraint keeps non-tensor operands out of this
    /// overload, consistent with the `+`, `*`, and `/` overloads.)
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "-", RTensor)(RTensor rhs)
        if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        auto y = slice(this.value - rhs.value);

        static if (canBackward!(typeof(this))) this.usedCount++;
        static if (canBackward!(typeof(rhs))) rhs.usedCount++;

        static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
        {
            return new Tensor!(T, Shape)(y, (Value grads) {
                static if (canBackward!(typeof(this))) this.backward((ref xGrads) { xGrads[] += grads[]; });
                static if (canBackward!(typeof(rhs))) rhs.backward((ref yGrads) { yGrads[] -= grads[]; });
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise subtraction of a scalar.
    Tensor!(T, Shape, hasGradient) opBinary(string op : "-")(T rhs)
    {
        auto y = slice(this.value[] - rhs);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise subtraction with the scalar on the left (`rhs` is the
    /// left operand: `rhs - tensor`); the gradient is negated accordingly.
    Tensor!(T, Shape, hasGradient) opBinaryRight(string op : "-")(T rhs)
    {
        auto y = slice(rhs - this.value[]);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward((ref yGrads) { yGrads[] -= grads[]; });
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise (Hadamard) product of two shape-compatible tensors.
    /// When both operands are the same object (`x * x`), the gradient
    /// `2 * x * grads` is applied once instead of double-counting.
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "*", RTensor)(RTensor rhs)
    if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        static if (is(typeof(this) == typeof(rhs)))
        {
            if (this is rhs)
            {
                auto y = slice(this.value * this.value);
                static if (canBackward!(typeof(this)))
                {
                    this.usedCount++;
                    return new Tensor!(T, Shape)(y, (Value grads) {
                        this.backward(2 * this.value * grads);
                    });
                }
                else
                {
                    return new Tensor!(T, Shape, No.gradient)(y);
                }
            }
            else
            {
                auto y = slice(this.value * rhs.value);
                static if (canBackward!(typeof(this))) this.usedCount++;
                static if (canBackward!(typeof(rhs))) rhs.usedCount++;

                static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
                {
                    return new Tensor!(T, Shape)(y, (Value grads) {
                        static if (canBackward!(typeof(this))) this.backward(rhs.value * grads);
                        static if (canBackward!(typeof(rhs))) rhs.backward(this.value * grads);
                    });
                }
                else
                {
                    return new Tensor!(T, Shape, No.gradient)(y);
                }
            }
        }
        else
        {
            auto y = slice(this.value * rhs.value);
            static if (canBackward!(typeof(this))) this.usedCount++;
            static if (canBackward!(typeof(rhs))) rhs.usedCount++;

            static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
            {
                return new Tensor!(T, Shape)(y, (Value grads) {
                    static if (canBackward!(typeof(this))) this.backward(rhs.value * grads);
                    static if (canBackward!(typeof(rhs))) rhs.backward(this.value * grads);
                });
            }
            else
            {
                return new Tensor!(T, Shape, No.gradient)(y);
            }
        }
    }

    /// Elementwise multiplication by a scalar.
    Tensor!(T, Shape, hasGradient) opBinary(string op : "*")(T rhs)
    {
        auto y = slice(this.value[] * rhs);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;

            return new typeof(this)(y, (grads) {
                this.backward(grads[] * rhs);
            });
        }
        else
        {
            return new typeof(this)(y);
        }
    }

    /// Elementwise multiplication with the scalar on the left.
    Tensor!(T, Shape, hasGradient) opBinaryRight(string op : "*")(T lhs)
    {
        auto y = slice(lhs * this.value[]);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;

            return new typeof(this)(y, (grads) {
                this.backward(lhs * grads[]);
            });
        }
        else
        {
            return new typeof(this)(y);
        }
    }

    /// Elementwise division of two shape-compatible tensors.
    /// Gradients: d(x/y)/dx = 1/y, d(x/y)/dy = -x/y^2.
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "/", RTensor)(RTensor rhs)
        if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        auto y = slice(this.value / rhs.value);

        static if (canBackward!(typeof(this))) this.usedCount++;
        static if (canBackward!(typeof(rhs))) rhs.usedCount++;

        static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
        {
            return new Tensor!(T, Shape)(y, (Value grads) {
                static if (canBackward!(typeof(this))) this.backward(grads[] / rhs.value[]);
                static if (canBackward!(typeof(rhs))) rhs.backward(-grads[] * this.value[] / (rhs.value[] * rhs.value[]));
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Elementwise negation.
    Tensor!(T, Shape, hasGradient) opUnary(string op : "-")()
    {
        auto y = slice(-this.value[]);

        static if (hasGradient)
        {
            this.usedCount++;

            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(-grads[]);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    invariant()
    {
        import std.format : format;
        foreach (i; 0 .. Shape.length)
        {
            if (Shape[i] != 0)
            {
                assert(Shape[i] == value.shape[i], format!"size mismatched at shape[%d] (%s and %s)"(i, Shape, value.shape));
                static if (hasGradient)
                    // Report grads.shape here (previously printed value.shape).
                    assert(Shape[i] == grads.shape[i], format!"size mismatched at shape[%d] (%s and %s)"(i, Shape, grads.shape));
            }
            else
            {
                assert(value.shape[i] > 0);
                static if (hasGradient)
                    assert(grads.shape[i] > 0);
            }
        }
    }
}
511 
/// Creates a new tensor with the given static `Shape` from a flat numeric
/// array or a nested (row-major) array of values. Gradient tracking is on
/// by default.
template tensor(size_t[] Shape, UseGradient useGradient = Yes.gradient)
{
    import std.traits : isNumeric;

    /// Constructs from a flat array of numeric elements.
    Tensor!(T, Shape, useGradient) tensor(T)(T[] data)
    if (isNumeric!T)
    in(data.length > 0)
    {
        return new Tensor!(T, Shape, useGradient)(data);
    }

    /// Constructs from a nested array; the element type is inferred from
    /// the innermost elements.
    Tensor!(DeepElementType!T, Shape, useGradient) tensor(T)(T[] data)
    if (!isNumeric!T)
    in(data.length > 0)
    {
        return new Tensor!(DeepElementType!T, Shape, useGradient)(data);
    }
}
530 
// Construction from a flat array with a fully static shape.
unittest
{
    Tensor!(float, [2, 2]) t = tensor!([2, 2])([0.0f, 0.1f, 0.2f, 0.3f]);

    assert(t !is null);
    assert(t.staticShape == [2, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);

    static assert(isTensor!(typeof(t)));
}

// Construction from a flat array with a dynamic batch dimension.
unittest
{
    Tensor!(float, [0, 2]) t = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    assert(t !is null);
    assert(t.staticShape == [0, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);
}

// Construction from a nested array with a fully static shape.
unittest
{
    Tensor!(double, [2, 2]) t = tensor!([2, 2])([[1.0, 2.0], [3.0, 4.0]]);

    assert(t !is null);
    assert(t.staticShape == [2, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);
}

// Construction from a nested array with a dynamic batch dimension.
unittest
{
    Tensor!(double, [0, 2]) t = tensor!([0, 2])([[1.0, 2.0], [3.0, 4.0]]);

    assert(t !is null);
    assert(t.staticShape == [0, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);
}
572 
573 
// Elementwise addition of two tensors.
unittest
{
    auto x = tensor!([2, 2])([0.0f, 1.0f, 2.0f, 3.0f]);
    auto y = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    auto z = x + y;
    assert(z.value[0, 0] == 1.0f);
    assert(z.value[0, 1] == 3.0f);
    assert(z.value[1, 0] == 5.0f);
    assert(z.value[1, 1] == 7.0f);
}

// Static shape mismatches are rejected at compile time; dynamic batch-size
// mismatches are rejected at runtime by the shape assertion.
unittest
{
    auto a = tensor!([2, 2])([1, 2, 3, 4]);
    auto b = tensor!([3, 2])([1, 2, 3, 4, 5, 6]);
    auto c = tensor!([0, 2])([1, 2]);
    auto d = tensor!([0, 3])([1, 2, 3]);

    // dfmt off
    static assert(!__traits(compiles, { auto z = a + b; }));
    static assert( __traits(compiles, { auto z = a + c; }));
    static assert(!__traits(compiles, { auto z = a + d; }));
    static assert(!__traits(compiles, { auto z = c + d; }));
    
    static assert(!__traits(compiles, { auto z = a - b; }));
    static assert( __traits(compiles, { auto z = a - c; }));
    static assert(!__traits(compiles, { auto z = a - d; }));
    static assert(!__traits(compiles, { auto z = c - d; }));
    
    static assert(!__traits(compiles, { auto z = a * b; }));
    static assert( __traits(compiles, { auto z = a * c; }));
    static assert(!__traits(compiles, { auto z = a * d; }));
    static assert(!__traits(compiles, { auto z = c * d; }));
    
    static assert(!__traits(compiles, { auto z = a / b; }));
    static assert( __traits(compiles, { auto z = a / c; }));
    static assert(!__traits(compiles, { auto z = a / d; }));
    static assert(!__traits(compiles, { auto z = c / d; }));
    // dfmt on

    import core.exception : AssertError;
    import std.exception : assertThrown;

    assertThrown!AssertError(a + c, "Mismatch runtime shape [2, 2] != [1, 2]");
    assertThrown!AssertError(a - c, "Mismatch runtime shape [2, 2] != [1, 2]");
    assertThrown!AssertError(a * c, "Mismatch runtime shape [2, 2] != [1, 2]");
    assertThrown!AssertError(a / c, "Mismatch runtime shape [2, 2] != [1, 2]");
}
623 
// Elementwise subtraction of two tensors.
unittest
{
    auto x = tensor!([2, 2])([0.0f, 1.0f, 2.0f, 3.0f]);
    auto y = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    auto z = x - y;
    assert(z.value[0, 0] == -1.0f);
    assert(z.value[0, 1] == -1.0f);
    assert(z.value[1, 0] == -1.0f);
    assert(z.value[1, 1] == -1.0f);
}

// Unary negation with gradient tracking.
unittest
{
    auto x = tensor!([2, 2])([
        [1.0, 2.0],
        [3.0, 4.0],
    ]);
    auto y = -x;

    assert(y.value[0, 0] == -1.0);
    assert(y.value[0, 1] == -2.0);
    assert(y.value[1, 0] == -3.0);
    assert(y.value[1, 1] == -4.0);
}

// Unary negation without gradient tracking.
unittest
{
    auto x = tensor!([0, 2], No.gradient)([
        [1.0, 2.0],
        [3.0, 4.0],
    ]);
    auto y = -x;

    assert(y.value[0, 0] == -1.0);
    assert(y.value[0, 1] == -2.0);
    assert(y.value[1, 0] == -3.0);
    assert(y.value[1, 1] == -4.0);
}
663 
// tensor + scalar: values shift; the gradient passes through unchanged.
// A no-gradient tensor cannot call backward().
unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = a + 0.5;
    auto y = b + 0.25f;
    auto z = c + 2;

    assert(x.value[] == [[1.5, 2.5], [3.5, 4.5]]);
    assert(y.value[] == [[1.25f, 2.25f], [3.25f, 4.25f]]);
    assert(z.value[] == [[12, 22], [32, 42]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

// scalar + tensor: same behavior via opBinaryRight.
unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = 0.5 + a;
    auto y = 0.25f + b;
    auto z = 2 + c;

    assert(x.value[] == [[1.5, 2.5], [3.5, 4.5]]);
    assert(y.value[] == [[1.25f, 2.25f], [3.25f, 4.25f]]);
    assert(z.value[] == [[12, 22], [32, 42]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

// tensor - scalar: values shift down; gradient passes through unchanged.
unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = a - 0.5;
    auto y = b - 0.25f;
    auto z = c - 2;

    assert(x.value[] == [[0.5, 1.5], [2.5, 3.5]]);
    assert(y.value[] == [[0.75f, 1.75f], [2.75f, 3.75f]]);
    assert(z.value[] == [[8, 18], [28, 38]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

// scalar - tensor: the gradient flowing into the tensor operand is negated.
unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = 0.5 - a;
    auto y = 0.25f - b;
    auto z = 2 - c;

    assert(x.value[] == [[-0.5, -1.5], [-2.5, -3.5]]);
    assert(y.value[] == [[-0.75f, -1.75f], [-2.75f, -3.75f]]);
    assert(z.value[] == [[-8, -18], [-28, -38]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(a.grads[] == [[-1.0, -1.0], [-1.0, -1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);
    assert(b.grads[] == [[-1.0f, -1.0f], [-1.0f, -1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}
761 
// Scalar multiplication on both sides; gradients from two consumers of the
// same tensor accumulate (3.0 + 0.5 = 3.5 per element).
unittest
{
    auto x = tensor!([2, 2, 2])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
    auto y = 3.0 * x;
    auto z = x * 0.5;

    import std.math : isClose;

    assert(y.value[0, 0, 0].isClose(0.1 * 3.0));
    assert(y.value[0, 0, 1].isClose(0.2 * 3.0));
    assert(y.value[0, 1, 0].isClose(0.3 * 3.0));
    assert(y.value[0, 1, 1].isClose(0.4 * 3.0));
    assert(y.value[1, 0, 0].isClose(0.5 * 3.0));
    assert(y.value[1, 0, 1].isClose(0.6 * 3.0));
    assert(y.value[1, 1, 0].isClose(0.7 * 3.0));
    assert(y.value[1, 1, 1].isClose(0.8 * 3.0));
    
    assert(z.value[0, 0, 0].isClose(0.1 * 0.5));
    assert(z.value[0, 0, 1].isClose(0.2 * 0.5));
    assert(z.value[0, 1, 0].isClose(0.3 * 0.5));
    assert(z.value[0, 1, 1].isClose(0.4 * 0.5));
    assert(z.value[1, 0, 0].isClose(0.5 * 0.5));
    assert(z.value[1, 0, 1].isClose(0.6 * 0.5));
    assert(z.value[1, 1, 0].isClose(0.7 * 0.5));
    assert(z.value[1, 1, 1].isClose(0.8 * 0.5));

    y.backward();
    z.backward();

    assert(x.grads.flattened[] == [3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5]);
}

// Elementwise division: d(x/y)/dx = 1/y and d(x/y)/dy = -x/y^2.
unittest
{
    auto x = tensor!([2, 2])([-1.0, 0.0, 1.0, 2.0]);
    auto y = tensor!([2, 2])([-2.0, 3.0, 4.0, 5.0]);
    auto z = x / y;

    assert(z.value[0, 0] == 0.5);
    assert(z.value[0, 1] == 0.0);
    assert(z.value[1, 0] == 0.25);
    assert(z.value[1, 1] == 0.4);

    z.backward();

    assert(x.grads[0, 0] == 1.0 / -2.0);
    assert(x.grads[0, 1] == 1.0 / 3.0);
    assert(x.grads[1, 0] == 1.0 / 4.0);
    assert(x.grads[1, 1] == 1.0 / 5.0);
    
    assert(y.grads[0, 0] == 1.0 / -2.0 / -2.0);
    assert(y.grads[0, 1] == -0.0 / 3.0 / 3.0);
    assert(y.grads[1, 0] == -1.0 / 4.0 / 4.0);
    assert(y.grads[1, 1] == -2.0 / 5.0 / 5.0);
}
817 
// End-to-end sanity check: one gradient-descent step on a squared-error
// loss strictly decreases the loss for every element.
unittest
{
    auto x = tensor!([2, 2])([-0.5f, 0.5f, 0.0f, 1.0f]);
    auto y = tensor!([2, 2])([0.5f, 0.5f, 0.5f, 0.5f]);

    auto t = tensor!([2, 2])([0.2f, 0.2f, 0.2f, 0.2f]);

    // forward
    auto z = x * y;

    // loss
    auto h = t - z;
    auto loss = h * h;

    // backward
    loss.resetGrads();
    loss.backward();

    // train
    x.value[] -= 0.1 * x.grads[];
    y.value[] -= 0.1 * y.grads[];

    auto z2 = x * y;
    auto h2 = t - z2;
    auto loss2 = h2 * h2;
    auto s = slice(loss2.value.flattened[] - loss.value.flattened[]);
    foreach (i; 0 .. 4)
    {
        assert(s[i] < 0);
    }
}

// Mixing gradient and no-gradient operands compiles for all four operators.
unittest
{
    Tensor!(int, [2, 2], Yes.gradient) a = tensor!([2, 2])([1, 2, 3, 4]);
    Tensor!(int, [2, 2], No.gradient) b = tensor!([2, 2], No.gradient)([1, 2, 3, 4]);

    auto x = a + b;
    auto y = a - b;
    auto z = a * b;
    auto w = a / b;
}

// Two no-gradient operands also compile for all four operators.
unittest
{
    Tensor!(int, [2, 2], No.gradient) a = tensor!([2, 2], No.gradient)([1, 2, 3, 4]);
    Tensor!(int, [2, 2], No.gradient) b = tensor!([2, 2], No.gradient)([1, 2, 3, 4]);

    auto x = a + b;
    auto y = a - b;
    auto z = a * b;
    auto w = a / b;
}
872 
/// Creates a zero-filled tensor with a fully static shape.
/// Gradient tracking is off by default.
Tensor!(T, Shape, useGrad) zeros(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)()
if (Shape[0] != 0)
{
    return new typeof(return)(numir.zeros!T(Shape));
}
879 
///ditto
unittest
{
    // Default: no gradient tracking.
    auto z = zeros!(float, [2, 2]);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}

///ditto
unittest
{
    // Gradient tracking can be requested explicitly.
    auto z = zeros!(float, [2, 2], UseGradient.yes);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}
901 
/// Creates a zero-filled tensor whose dynamic batch dimension is set to
/// `batchSize`. Gradient tracking is off by default.
Tensor!(T, Shape, useGrad) zeros(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)(size_t batchSize)
if (Shape[0] == 0)
{
    return new typeof(return)(numir.zeros!T([batchSize, expandShape!(Shape[1 .. $])]));
}
908 
///ditto
unittest
{
    // Dynamic batch dimension fixed at runtime.
    auto z = zeros!(float, [0, 2])(2);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}

///ditto
unittest
{
    // Dynamic batch dimension with gradient tracking.
    auto z = zeros!(float, [0, 2], UseGradient.yes)(2);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}
930 
/// Creates a zero-filled, gradient-free tensor with the same element type and
/// shape as `x`, including its runtime batch size for dynamic shapes.
Tensor!(T, Shape, UseGradient.no) zerosLike(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (x.staticShape[0] == 0)
    {
        return zeros!(T, Shape)(x.shape[0]);
    }
    else
    {
        return zeros!(T, Shape)();
    }
}

///ditto
Tensor!(T, Shape, useGrad) zerosLike(UseGradient useGrad, T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (x.staticShape[0] == 0)
    {
        return zeros!(T, Shape, useGrad)(x.shape[0]);
    }
    else
    {
        return zeros!(T, Shape, useGrad)();
    }
}
956 
///ditto
unittest
{
    // Result matches the source shape and carries no gradient by default.
    auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto x1 = zerosLike(x);

    assert(x.shape == x1.shape);
    assert(x1.value == zeros!(float, [2, 2]).value);
    static assert(!canBackward!(typeof(x1)));
}

///ditto
unittest
{
    // Works for non-square static shapes as well.
    auto x = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto x1 = zerosLike(x);

    assert(x.shape == x1.shape);
    assert(x1.value == zeros!(float, [2, 3]).value);
    static assert(!canBackward!(typeof(x1)));
}

///ditto
unittest
{
    // Gradient tracking can be requested via the first template argument.
    auto x = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto x1 = zerosLike!(UseGradient.yes)(x);

    static assert(canBackward!(typeof(x1)));
}
987 
988 
/// Creates a one-filled tensor with a fully static shape.
/// Gradient tracking is off by default.
Tensor!(T, Shape, useGrad) ones(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)()
if (Shape[0] != 0)
{
    return new typeof(return)(numir.ones!T(Shape));
}

///ditto
Tensor!(T, Shape, useGrad) ones(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)(size_t batchSize)
if (Shape[0] == 0)
{
    return new typeof(return)(numir.ones!T([batchSize, expandShape!(Shape[1 .. $])]));
}
1002 
///ditto
unittest
{
    auto o = ones!(float, [2, 2]);
    assert(!canBackward!(typeof(o)));
    assert(o.shape == [2, 2]);
    // Every element of a 2x2 ones tensor is 1.
    foreach (row; 0 .. 2)
        foreach (col; 0 .. 2)
            assert(o.value[row, col] == 1);
}
1014 
///ditto
unittest
{
    auto o = ones!(float, [0, 2])(2);
    assert(!canBackward!(typeof(o)));
    assert(o.shape == [2, 2]);
    // Every element of the batched ones tensor is 1.
    foreach (row; 0 .. 2)
        foreach (col; 0 .. 2)
            assert(o.value[row, col] == 1);
}
1026 
///ditto
unittest
{
    auto withGrad = ones!(float, [2, 3], UseGradient.yes)();
    static assert(canBackward!(typeof(withGrad)));
}
1033 
1034 ///
1035 Tensor!(T, Shape, UseGradient.no) onesLike(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
1036 {
1037     static if (x.staticShape[0] == 0)
1038     {
1039         return ones!(T, Shape)(x.shape[0]);
1040     }
1041     else
1042     {
1043         return ones!(T, Shape)();
1044     }
1045 }
1046 
///ditto
Tensor!(T, Shape, useGrad) onesLike(UseGradient useGrad, T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (Shape[0] == 0)
        return ones!(T, Shape, useGrad)(x.shape[0]);
    else
        return ones!(T, Shape, useGrad)();
}
1059 
///ditto
unittest
{
    auto src = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto o = onesLike(src);

    assert(o.shape == src.shape);
    assert(o.value == ones!(float, [2, 2]).value);
}
1069 
///ditto
unittest
{
    auto src = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto o = onesLike(src);

    assert(o.shape == src.shape);
    assert(o.value == ones!(float, [2, 3]).value);
}
1079 
///ditto
unittest
{
    auto src = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto withGrad = onesLike!(UseGradient.yes)(src);

    static assert(canBackward!(typeof(withGrad)));
}
1088 
1089 ///
1090 Tensor!(T, TargetShape, useGrad) reshape(size_t[] TargetShape, T, size_t[] Shape, UseGradient useGrad)(Tensor!(T, Shape, useGrad) x)
1091 {
1092     static if (Shape[0] == 0 && TargetShape[0] == 0)
1093     {
1094         const batchSize = x.shape[0];
1095         const ptrdiff_t[TargetShape.length] runtimeTargetShape = [batchSize, expandShape!(TargetShape[1 .. $])];
1096     }
1097     else
1098     {
1099         static if (Shape[0] != 0)
1100             enum batchSize = Shape[0];
1101         else
1102             enum batchSize = TargetShape[0];
1103         enum ptrdiff_t[TargetShape.length] runtimeTargetShape = [batchSize, expandShape!(TargetShape[1 .. $])];
1104     }
1105 
1106     import mir.ndslice : reshape;
1107 
1108 	int err;
1109 	auto yValue = reshape(x.value, runtimeTargetShape, err);
1110 
1111     static if (useGrad)
1112     {
1113         x.usedCount++;
1114         return new Tensor!(T, TargetShape)(yValue, (grads) {
1115             x.backward((ref xGrads) {
1116                 xGrads.flattened[] += grads.flattened[];
1117             });
1118         });
1119     }
1120     else
1121     {
1122         return new Tensor!(T, TargetShape, useGrad)(yValue);
1123     }
1124 }
1125 
/// ditto
unittest
{
    auto x = tensor!([0, 2, 2])([1, 2, 3, 4]);
    auto y = x.reshape!([0, 4]);

    assert(y.shape == [1, 4]);
    // The flattened values keep their original order.
    foreach (i; 0 .. 4)
        assert(y.value[0, i] == i + 1);

    assert(x.grads == [[[0, 0], [0, 0]]]);
    y.backward();
    assert(x.grads == [[[1, 1], [1, 1]]]);
}
1142 
/// ditto
unittest
{
    auto src = tensor!([0, 4], UseGradient.no)([1, 2, 3, 4]);
    auto reshaped = src.reshape!([1, 2, 2]);

    static assert(!canBackward!(typeof(reshaped)));

    assert(reshaped.shape == [1, 2, 2]);
    assert(reshaped.value[0, 0, 0] == 1);
    assert(reshaped.value[0, 0, 1] == 2);
    assert(reshaped.value[0, 1, 0] == 3);
    assert(reshaped.value[0, 1, 1] == 4);
}