module golem.tensor;

import golem.util;

import mir.ndslice;
static import numir;

import std.typecons : Flag, Yes, No;

/// Compile-time flag selecting whether a tensor records gradients for backpropagation.
alias UseGradient = Flag!"gradient";

/// True when `T` is a `Tensor` with exactly the given element type, static shape and gradient flag.
enum isTensor(T, U, size_t[] Shape, UseGradient hasGradient) = is(T == Tensor!(U, Shape, hasGradient));

/// True when `T` is any `Tensor` instantiation.
enum isTensor(T) = is(T == Tensor!(U, Shape, flag), U, size_t[] Shape, UseGradient flag);

/// True when `T` is a `Tensor` with the given gradient flag (any element type / shape).
enum isTensor(T, UseGradient hasGradient) = is(T == Tensor!(U, Shape, hasGradient), U, size_t[] Shape);

/// True when `T` is a `Tensor` with the given element type and shape (either gradient flag).
enum isTensor(T, U, size_t[] Shape) = isTensor!(T, U, Shape, Yes.gradient) || isTensor!(T, U, Shape, No.gradient);

/// True when `T` is a gradient-tracking tensor, i.e. participates in backpropagation.
enum canBackward(T) = isTensor!(T, Yes.gradient);

/++
Checks whether two static shapes are compatible for element-wise operations.

A leading dimension of `0` denotes a runtime-determined batch size and matches
any value in the other shape's leading dimension. All remaining dimensions
must match exactly, and both shapes must have the same rank.
+/
bool testCompatibleStaticShape(size_t[] lhsShape, size_t[] rhsShape) @safe @nogc pure nothrow
{
    assert(lhsShape.length > 0);
    assert(rhsShape.length > 0);

    if (lhsShape.length != rhsShape.length)
        return false;

    foreach (i; 0 .. lhsShape.length)
    {
        // dimension 0 is the batch axis; 0 means "any batch size"
        if (i == 0 && (lhsShape[0] == 0 || rhsShape[0] == 0))
            continue;

        if (lhsShape[i] != rhsShape[i])
            return false;
    }
    return true;
}

unittest
{
    enum testShape = testCompatibleStaticShape([2, 2], [2, 2]);
    static assert(testShape);

    static assert(testCompatibleStaticShape([2, 3], [2, 3]));
    static assert(testCompatibleStaticShape([2, 3], [0, 3]));
    static assert(testCompatibleStaticShape([0, 3], [0, 3]));
    static assert(testCompatibleStaticShape([0, 3], [2, 3]));
    static assert(testCompatibleStaticShape([2, 3, 28, 28], [2, 3, 28, 28]));

    static assert(!testCompatibleStaticShape([2, 3], [2, 3, 3]));
    static assert(!testCompatibleStaticShape([2, 3, 3], [2, 3]));
    static assert(!testCompatibleStaticShape([1, 3], [2, 3]));
    static assert(!testCompatibleStaticShape([2, 4], [2, 3]));
}

/++
Gradient flag for the result of a binary operation between two tensors:
`UseGradient.yes` if either operand tracks gradients, otherwise `UseGradient.no`.
+/
template commonGradientType(T1, T2)
if (isTensor!T1 && isTensor!T2)
{
    static if (canBackward!T1 || canBackward!T2)
    {
        enum commonGradientType = UseGradient.yes;
    }
    else
    {
        enum commonGradientType = UseGradient.no;
    }
}

/++
A tensor with a compile-time shape and optional reverse-mode automatic differentiation.

Params:
    T           = element type
    Shape       = static shape; a leading `0` means the batch size is chosen at runtime
    hasGradient = whether this tensor records gradients (default: yes)

Gradient bookkeeping: each time a tensor is consumed by an operation, `usedCount`
is incremented; `backward` accumulates incoming gradients and only propagates to
`backwardFn` once all consumers have reported (`backwardCount == usedCount`).
+/
class Tensor(T, size_t[] Shape, UseGradient hasGradient = UseGradient.yes)
{
    alias ElementType = T;

    enum staticShape = Shape;

    static if (Shape[0] != 0)
    {
        alias shape = staticShape;
    }
    else
    {
        alias shape = runtimeShape;
    }

    alias Value = Slice!(T*, Shape.length);

    Value value;
    static if (hasGradient)
    {
        Value grads;

        bool requireGrad = true;
        size_t usedCount;      // number of operations that consumed this tensor
        size_t backwardCount;  // number of consumers that have reported gradients
        void delegate(Value grads) backwardFn;
    }

    static if (Shape[0] != 0)
    {
        /// Constructs a tensor with every element set to `init`.
        this(T init)
        {
            this(slice!T(Shape, init));
        }
    }

    /// Constructs a tensor from a (possibly nested) range of ranges.
    this(RoR)(RoR data)
    {
        this(fuse(data));
    }

    /// Constructs a tensor from a flat array; infers the batch size when Shape[0] == 0.
    this(T[] data)
    {
        static if (Shape[0] == 0)
        {
            static if (Shape.length == 1)
            {
                const batchSize = data.length;
                auto value = data.sliced(batchSize);
            }
            else
            {
                const batchSize = data.length / elementSize(Shape[1 .. $]);
                import std.format : format;
                assert(batchSize * elementSize(Shape[1 .. $]) == data.length, format!"The number of elements in the data must match the shape of the tensor. Shape = %s, length=%s)"(Shape, data.length));
                auto value = data.sliced([
                    batchSize, expandShape!(Shape[1 .. $])
                ]);
            }
        }
        else
        {
            auto value = data.sliced(Shape);
        }

        this(value);
    }

    static if (hasGradient)
    {
        this(Value value)
        {
            this(value, null);
        }

        this(Value value, void delegate(Value grad) gradFn)
        {
            this(value, numir.zeros_like(value), gradFn);
        }

        private this(Value value, Value grads, void delegate(Value grad) gradFn)
        {
            this.value = value;
            this.grads = grads;
            this.backwardFn = gradFn;
        }
    }
    else
    {
        this(Value value)
        {
            this.value = value;
        }
    }

    /// Returns the actual shape of the underlying slice.
    size_t[Shape.length] runtimeShape() const pure nothrow @safe @nogc
    {
        return this.value.shape;
    }

    static if (hasGradient)
    {
        /// Zeroes the accumulated gradients.
        void resetGrads()
        {
            grads[] = T(0);
        }

        /++
        Backward step where the caller updates this tensor's gradients in place.
        Propagation to `backwardFn` happens once every consumer has reported.
        +/
        void backward()(void delegate(ref Value grads) update)
        {
            if (requireGrad)
            {
                update(this.grads);
                ++backwardCount;

                if (backwardCount == usedCount)
                {
                    if (this.backwardFn)
                        this.backwardFn(this.grads);
                    this.usedCount = 0;
                    this.backwardCount = 0;
                }
            }
        }

        /// Backward step accumulating `grads` into this tensor's gradients.
        void backward(U)(U grads)
        {
            if (requireGrad)
            {
                import std.format : format;

                assert(this.grads.shape == grads.shape,
                    "%s != %s".format(this.grads.shape, grads.shape));
                this.grads[] += grads;
                ++backwardCount;

                if (backwardCount == usedCount)
                {
                    if (this.backwardFn)
                        this.backwardFn(this.grads);
                    this.usedCount = 0;
                    this.backwardCount = 0;
                }
            }
        }

        /// Starts backpropagation from this tensor with a gradient of all ones.
        void backward()
        {
            if (requireGrad)
            {
                this.grads[] = T(1);
                if (this.backwardFn)
                    this.backwardFn(this.grads);
                this.usedCount = 0;
                this.backwardCount = 0;
            }
        }
    }

    /// Element-wise addition of two tensors. d(x+y)/dx = 1, d(x+y)/dy = 1.
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "+", RTensor)(RTensor rhs)
    if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        auto y = slice(this.value + rhs.value);

        static if (canBackward!(typeof(this))) this.usedCount++;
        static if (canBackward!(typeof(rhs))) rhs.usedCount++;

        static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
        {
            return new Tensor!(T, Shape)(y, (Value grads) {
                static if (canBackward!(typeof(this))) this.backward(grads);
                static if (canBackward!(typeof(rhs))) rhs.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Adds a scalar to every element.
    Tensor!(T, Shape, hasGradient) opBinary(string op : "+")(T rhs)
    {
        auto y = slice(this.value[] + rhs);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Scalar + tensor.
    Tensor!(T, Shape, hasGradient) opBinaryRight(string op : "+")(T lhs)
    {
        auto y = slice(lhs + this.value[]);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Element-wise subtraction. d(x-y)/dx = 1, d(x-y)/dy = -1.
    /// NOTE: the `isTensor!RTensor` constraint was added for consistency with `+`, `*` and `/`.
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "-", RTensor)(RTensor rhs)
    if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        auto y = slice(this.value - rhs.value);

        static if (canBackward!(typeof(this))) this.usedCount++;
        static if (canBackward!(typeof(rhs))) rhs.usedCount++;

        static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
        {
            return new Tensor!(T, Shape)(y, (Value grads) {
                static if (canBackward!(typeof(this))) this.backward((ref xGrads) { xGrads[] += grads[]; });
                static if (canBackward!(typeof(rhs))) rhs.backward((ref yGrads) { yGrads[] -= grads[]; });
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Subtracts a scalar from every element.
    Tensor!(T, Shape, hasGradient) opBinary(string op : "-")(T rhs)
    {
        auto y = slice(this.value[] - rhs);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(grads);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Scalar - tensor (the parameter is the left operand). d(c-x)/dx = -1.
    Tensor!(T, Shape, hasGradient) opBinaryRight(string op : "-")(T rhs)
    {
        auto y = slice(rhs - this.value[]);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;
            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward((ref yGrads) { yGrads[] -= grads[]; });
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /++
    Element-wise multiplication. d(x*y)/dx = y, d(x*y)/dy = x.
    When both operands are the same object (`x * x`) the gradient is
    accumulated once as 2*x*grads to keep the usage count consistent.
    +/
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "*", RTensor)(RTensor rhs)
    if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        static if (is(typeof(this) == typeof(rhs)))
        {
            if (this is rhs)
            {
                auto y = slice(this.value * this.value);
                static if (canBackward!(typeof(this)))
                {
                    this.usedCount++;
                    return new Tensor!(T, Shape)(y, (Value grads) {
                        this.backward(2 * this.value * grads);
                    });
                }
                else
                {
                    return new Tensor!(T, Shape, No.gradient)(y);
                }
            }
            else
            {
                auto y = slice(this.value * rhs.value);
                static if (canBackward!(typeof(this))) this.usedCount++;
                static if (canBackward!(typeof(rhs))) rhs.usedCount++;

                static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
                {
                    return new Tensor!(T, Shape)(y, (Value grads) {
                        static if (canBackward!(typeof(this))) this.backward(rhs.value * grads);
                        static if (canBackward!(typeof(rhs))) rhs.backward(this.value * grads);
                    });
                }
                else
                {
                    return new Tensor!(T, Shape, No.gradient)(y);
                }
            }
        }
        else
        {
            auto y = slice(this.value * rhs.value);
            static if (canBackward!(typeof(this))) this.usedCount++;
            static if (canBackward!(typeof(rhs))) rhs.usedCount++;

            static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
            {
                return new Tensor!(T, Shape)(y, (Value grads) {
                    static if (canBackward!(typeof(this))) this.backward(rhs.value * grads);
                    static if (canBackward!(typeof(rhs))) rhs.backward(this.value * grads);
                });
            }
            else
            {
                return new Tensor!(T, Shape, No.gradient)(y);
            }
        }
    }

    /// Multiplies every element by a scalar.
    Tensor!(T, Shape, hasGradient) opBinary(string op : "*")(T rhs)
    {
        auto y = slice(this.value[] * rhs);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;

            return new typeof(this)(y, (grads) {
                this.backward(grads[] * rhs);
            });
        }
        else
        {
            return new typeof(this)(y);
        }
    }

    /// Scalar * tensor.
    Tensor!(T, Shape, hasGradient) opBinaryRight(string op : "*")(T lhs)
    {
        auto y = slice(lhs * this.value[]);

        static if (canBackward!(typeof(this)))
        {
            this.usedCount++;

            return new typeof(this)(y, (grads) {
                this.backward(lhs * grads[]);
            });
        }
        else
        {
            return new typeof(this)(y);
        }
    }

    /// Element-wise division. d(x/y)/dx = 1/y, d(x/y)/dy = -x/y^2.
    Tensor!(T, Shape, commonGradientType!(typeof(this), RTensor)) opBinary(string op : "/", RTensor)(RTensor rhs)
    if (isTensor!RTensor)
    {
        import std.format : format;
        static assert(testCompatibleStaticShape(Shape, RTensor.staticShape), format!`Mismatch static shape %s != %s`(Shape, RTensor.staticShape));
        assert(testCompatibleStaticShape(shape, rhs.shape), format!"Mismatch runtime shape %s != %s"(shape, rhs.shape));

        auto y = slice(this.value / rhs.value);

        static if (canBackward!(typeof(this))) this.usedCount++;
        static if (canBackward!(typeof(rhs))) rhs.usedCount++;

        static if (canBackward!(typeof(this)) || canBackward!(typeof(rhs)))
        {
            return new Tensor!(T, Shape)(y, (Value grads) {
                static if (canBackward!(typeof(this))) this.backward(grads[] / rhs.value[]);
                static if (canBackward!(typeof(rhs))) rhs.backward(-grads[] * this.value[] / (rhs.value[] * rhs.value[]));
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    /// Element-wise negation. d(-x)/dx = -1.
    Tensor!(T, Shape, hasGradient) opUnary(string op : "-")()
    {
        auto y = slice(-this.value[]);

        static if (hasGradient)
        {
            this.usedCount++;

            return new Tensor!(T, Shape)(y, (Value grads) {
                this.backward(-grads[]);
            });
        }
        else
        {
            return new Tensor!(T, Shape, No.gradient)(y);
        }
    }

    invariant()
    {
        import std.format : format;
        foreach (i; 0 .. Shape.length)
        {
            if (Shape[i] != 0)
            {
                assert(Shape[i] == value.shape[i], format!"size mismatched at shape[%d] (%s and %s)"(i, Shape, value.shape));
                static if (hasGradient)
                    // fixed: report grads.shape (was value.shape) in the grads mismatch message
                    assert(Shape[i] == grads.shape[i], format!"size mismatched at shape[%d] (%s and %s)"(i, Shape, grads.shape));
            }
            else
            {
                assert(value.shape[i] > 0);
                static if (hasGradient)
                    assert(grads.shape[i] > 0);
            }
        }
    }
}

/// Convenience factory: builds a `Tensor` with the given static shape from flat or nested data.
template tensor(size_t[] Shape, UseGradient useGradient = Yes.gradient)
{
    import std.traits : isNumeric;

    Tensor!(T, Shape, useGradient) tensor(T)(T[] data)
    if (isNumeric!T)
    in(data.length > 0)
    {
        return new Tensor!(T, Shape, useGradient)(data);
    }

    Tensor!(DeepElementType!T, Shape, useGradient) tensor(T)(T[] data)
    if (!isNumeric!T)
    in(data.length > 0)
    {
        return new Tensor!(DeepElementType!T, Shape, useGradient)(data);
    }
}

unittest
{
    Tensor!(float, [2, 2]) t = tensor!([2, 2])([0.0f, 0.1f, 0.2f, 0.3f]);

    assert(t !is null);
    assert(t.staticShape == [2, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);

    static assert(isTensor!(typeof(t)));
}

unittest
{
    Tensor!(float, [0, 2]) t = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    assert(t !is null);
    assert(t.staticShape == [0, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);
}

unittest
{
    Tensor!(double, [2, 2]) t = tensor!([2, 2])([[1.0, 2.0], [3.0, 4.0]]);

    assert(t !is null);
    assert(t.staticShape == [2, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);
}

unittest
{
    Tensor!(double, [0, 2]) t = tensor!([0, 2])([[1.0, 2.0], [3.0, 4.0]]);

    assert(t !is null);
    assert(t.staticShape == [0, 2]);
    assert(t.runtimeShape == [2, 2]);
    assert(t.shape == [2, 2]);
}

unittest
{
    auto x = tensor!([2, 2])([0.0f, 1.0f, 2.0f, 3.0f]);
    auto y = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    auto z = x + y;
    assert(z.value[0, 0] == 1.0f);
    assert(z.value[0, 1] == 3.0f);
    assert(z.value[1, 0] == 5.0f);
    assert(z.value[1, 1] == 7.0f);
}

unittest
{
    auto a = tensor!([2, 2])([1, 2, 3, 4]);
    auto b = tensor!([3, 2])([1, 2, 3, 4, 5, 6]);
    auto c = tensor!([0, 2])([1, 2]);
    auto d = tensor!([0, 3])([1, 2, 3]);

    // dfmt off
    static assert(!__traits(compiles, { auto z = a + b; }));
    static assert( __traits(compiles, { auto z = a + c; }));
    static assert(!__traits(compiles, { auto z = a + d; }));
    static assert(!__traits(compiles, { auto z = c + d; }));

    static assert(!__traits(compiles, { auto z = a - b; }));
    static assert( __traits(compiles, { auto z = a - c; }));
    static assert(!__traits(compiles, { auto z = a - d; }));
    static assert(!__traits(compiles, { auto z = c - d; }));

    static assert(!__traits(compiles, { auto z = a * b; }));
    static assert( __traits(compiles, { auto z = a * c; }));
    static assert(!__traits(compiles, { auto z = a * d; }));
    static assert(!__traits(compiles, { auto z = c * d; }));

    static assert(!__traits(compiles, { auto z = a / b; }));
    static assert( __traits(compiles, { auto z = a / c; }));
    static assert(!__traits(compiles, { auto z = a / d; }));
    static assert(!__traits(compiles, { auto z = c / d; }));
    // dfmt on

    import core.exception : AssertError;
    import std.exception : assertThrown;

    assertThrown!AssertError(a + c, "Mismatch runtime shape [2, 2] != [1, 2]");
    assertThrown!AssertError(a - c, "Mismatch runtime shape [2, 2] != [1, 2]");
    assertThrown!AssertError(a * c, "Mismatch runtime shape [2, 2] != [1, 2]");
    assertThrown!AssertError(a / c, "Mismatch runtime shape [2, 2] != [1, 2]");
}

unittest
{
    auto x = tensor!([2, 2])([0.0f, 1.0f, 2.0f, 3.0f]);
    auto y = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    auto z = x - y;
    assert(z.value[0, 0] == -1.0f);
    assert(z.value[0, 1] == -1.0f);
    assert(z.value[1, 0] == -1.0f);
    assert(z.value[1, 1] == -1.0f);
}

unittest
{
    auto x = tensor!([2, 2])([
        [1.0, 2.0],
        [3.0, 4.0],
    ]);
    auto y = -x;

    assert(y.value[0, 0] == -1.0);
    assert(y.value[0, 1] == -2.0);
    assert(y.value[1, 0] == -3.0);
    assert(y.value[1, 1] == -4.0);
}

unittest
{
    auto x = tensor!([0, 2], No.gradient)([
        [1.0, 2.0],
        [3.0, 4.0],
    ]);
    auto y = -x;

    assert(y.value[0, 0] == -1.0);
    assert(y.value[0, 1] == -2.0);
    assert(y.value[1, 0] == -3.0);
    assert(y.value[1, 1] == -4.0);
}

unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = a + 0.5;
    auto y = b + 0.25f;
    auto z = c + 2;

    assert(x.value[] == [[1.5, 2.5], [3.5, 4.5]]);
    assert(y.value[] == [[1.25f, 2.25f], [3.25f, 4.25f]]);
    assert(z.value[] == [[12, 22], [32, 42]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = 0.5 + a;
    auto y = 0.25f + b;
    auto z = 2 + c;

    assert(x.value[] == [[1.5, 2.5], [3.5, 4.5]]);
    assert(y.value[] == [[1.25f, 2.25f], [3.25f, 4.25f]]);
    assert(z.value[] == [[12, 22], [32, 42]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = a - 0.5;
    auto y = b - 0.25f;
    auto z = c - 2;

    assert(x.value[] == [[0.5, 1.5], [2.5, 3.5]]);
    assert(y.value[] == [[0.75f, 1.75f], [2.75f, 3.75f]]);
    assert(z.value[] == [[8, 18], [28, 38]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

unittest
{
    auto a = tensor!([2, 2])([1.0, 2, 3, 4]);
    auto b = tensor!([0, 2])([1.0f, 2, 3, 4]);
    auto c = tensor!([2, 2], UseGradient.no)([10, 20, 30, 40]);

    auto x = 0.5 - a;
    auto y = 0.25f - b;
    auto z = 2 - c;

    assert(x.value[] == [[-0.5, -1.5], [-2.5, -3.5]]);
    assert(y.value[] == [[-0.75f, -1.75f], [-2.75f, -3.75f]]);
    assert(z.value[] == [[-8, -18], [-28, -38]]);

    assert(x.grads[] == [[0.0, 0.0], [0.0, 0.0]]);
    assert(y.grads[] == [[0.0f, 0.0f], [0.0f, 0.0f]]);
    x.backward();
    y.backward();
    assert(x.grads[] == [[1.0, 1.0], [1.0, 1.0]]);
    assert(a.grads[] == [[-1.0, -1.0], [-1.0, -1.0]]);
    assert(y.grads[] == [[1.0f, 1.0f], [1.0f, 1.0f]]);
    assert(b.grads[] == [[-1.0f, -1.0f], [-1.0f, -1.0f]]);

    static assert(!__traits(compiles, z.backward()));
}

unittest
{
    auto x = tensor!([2, 2, 2])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
    auto y = 3.0 * x;
    auto z = x * 0.5;

    import std.math : isClose;

    assert(y.value[0, 0, 0].isClose(0.1 * 3.0));
    assert(y.value[0, 0, 1].isClose(0.2 * 3.0));
    assert(y.value[0, 1, 0].isClose(0.3 * 3.0));
    assert(y.value[0, 1, 1].isClose(0.4 * 3.0));
    assert(y.value[1, 0, 0].isClose(0.5 * 3.0));
    assert(y.value[1, 0, 1].isClose(0.6 * 3.0));
    assert(y.value[1, 1, 0].isClose(0.7 * 3.0));
    assert(y.value[1, 1, 1].isClose(0.8 * 3.0));

    assert(z.value[0, 0, 0].isClose(0.1 * 0.5));
    assert(z.value[0, 0, 1].isClose(0.2 * 0.5));
    assert(z.value[0, 1, 0].isClose(0.3 * 0.5));
    assert(z.value[0, 1, 1].isClose(0.4 * 0.5));
    assert(z.value[1, 0, 0].isClose(0.5 * 0.5));
    assert(z.value[1, 0, 1].isClose(0.6 * 0.5));
    assert(z.value[1, 1, 0].isClose(0.7 * 0.5));
    assert(z.value[1, 1, 1].isClose(0.8 * 0.5));

    y.backward();
    z.backward();

    assert(x.grads.flattened[] == [3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 3.5]);
}

unittest
{
    auto x = tensor!([2, 2])([-1.0, 0.0, 1.0, 2.0]);
    auto y = tensor!([2, 2])([-2.0, 3.0, 4.0, 5.0]);
    auto z = x / y;

    assert(z.value[0, 0] == 0.5);
    assert(z.value[0, 1] == 0.0);
    assert(z.value[1, 0] == 0.25);
    assert(z.value[1, 1] == 0.4);

    z.backward();

    assert(x.grads[0, 0] == 1.0 / -2.0);
    assert(x.grads[0, 1] == 1.0 / 3.0);
    assert(x.grads[1, 0] == 1.0 / 4.0);
    assert(x.grads[1, 1] == 1.0 / 5.0);

    assert(y.grads[0, 0] == 1.0 / -2.0 / -2.0);
    assert(y.grads[0, 1] == -0.0 / 3.0 / 3.0);
    assert(y.grads[1, 0] == -1.0 / 4.0 / 4.0);
    assert(y.grads[1, 1] == -2.0 / 5.0 / 5.0);
}

unittest
{
    auto x = tensor!([2, 2])([-0.5f, 0.5f, 0.0f, 1.0f]);
    auto y = tensor!([2, 2])([0.5f, 0.5f, 0.5f, 0.5f]);

    auto t = tensor!([2, 2])([0.2f, 0.2f, 0.2f, 0.2f]);

    // forward
    auto z = x * y;

    // loss
    auto h = t - z;
    auto loss = h * h;

    // backward
    loss.resetGrads();
    loss.backward();

    // train
    x.value[] -= 0.1 * x.grads[];
    y.value[] -= 0.1 * y.grads[];

    auto z2 = x * y;
    auto h2 = t - z2;
    auto loss2 = h2 * h2;
    auto s = slice(loss2.value.flattened[] - loss.value.flattened[]);
    foreach (i; 0 .. 4)
    {
        assert(s[i] < 0);
    }
}

unittest
{
    Tensor!(int, [2, 2], Yes.gradient) a = tensor!([2, 2])([1, 2, 3, 4]);
    Tensor!(int, [2, 2], No.gradient) b = tensor!([2, 2], No.gradient)([1, 2, 3, 4]);

    auto x = a + b;
    auto y = a - b;
    auto z = a * b;
    auto w = a / b;
}

unittest
{
    Tensor!(int, [2, 2], No.gradient) a = tensor!([2, 2], No.gradient)([1, 2, 3, 4]);
    Tensor!(int, [2, 2], No.gradient) b = tensor!([2, 2], No.gradient)([1, 2, 3, 4]);

    auto x = a + b;
    auto y = a - b;
    auto z = a * b;
    auto w = a / b;
}

/// Creates a zero-filled tensor with a fully static shape.
Tensor!(T, Shape, useGrad) zeros(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)()
if (Shape[0] != 0)
{
    return new typeof(return)(numir.zeros!T(Shape));
}

///ditto
unittest
{
    auto z = zeros!(float, [2, 2]);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}

///ditto
unittest
{
    auto z = zeros!(float, [2, 2], UseGradient.yes);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}

/// Creates a zero-filled tensor with a runtime batch size.
Tensor!(T, Shape, useGrad) zeros(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)(size_t batchSize)
if (Shape[0] == 0)
{
    return new typeof(return)(numir.zeros!T([batchSize, expandShape!(Shape[1 .. $])]));
}

///ditto
unittest
{
    auto z = zeros!(float, [0, 2])(2);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}

///ditto
unittest
{
    auto z = zeros!(float, [0, 2], UseGradient.yes)(2);
    assert(z.shape == [2, 2]);
    assert(z.value[0, 0] == 0);
    assert(z.value[0, 1] == 0);
    assert(z.value[1, 0] == 0);
    assert(z.value[1, 1] == 0);
}

/// Creates a gradient-free zero tensor with the same shape as `x`.
Tensor!(T, Shape, UseGradient.no) zerosLike(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (x.staticShape[0] == 0)
    {
        return zeros!(T, Shape)(x.shape[0]);
    }
    else
    {
        return zeros!(T, Shape)();
    }
}

///ditto
Tensor!(T, Shape, useGrad) zerosLike(UseGradient useGrad, T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (x.staticShape[0] == 0)
    {
        return zeros!(T, Shape, useGrad)(x.shape[0]);
    }
    else
    {
        return zeros!(T, Shape, useGrad)();
    }
}

///ditto
unittest
{
    auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto x1 = zerosLike(x);

    assert(x.shape == x1.shape);
    assert(x1.value == zeros!(float, [2, 2]).value);
    static assert(!canBackward!(typeof(x1)));
}

///ditto
unittest
{
    auto x = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto x1 = zerosLike(x);

    assert(x.shape == x1.shape);
    assert(x1.value == zeros!(float, [2, 3]).value);
    static assert(!canBackward!(typeof(x1)));
}

///ditto
unittest
{
    auto x = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto x1 = zerosLike!(UseGradient.yes)(x);

    static assert(canBackward!(typeof(x1)));
}

/// Creates a one-filled tensor with a fully static shape.
Tensor!(T, Shape, useGrad) ones(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)()
if (Shape[0] != 0)
{
    return new typeof(return)(numir.ones!T(Shape));
}

///ditto
Tensor!(T, Shape, useGrad) ones(T, size_t[] Shape, UseGradient useGrad = UseGradient.no)(size_t batchSize)
if (Shape[0] == 0)
{
    return new typeof(return)(numir.ones!T([batchSize, expandShape!(Shape[1 .. $])]));
}

///ditto
unittest
{
    auto o = ones!(float, [2, 2]);
    assert(!canBackward!(typeof(o)));
    assert(o.shape == [2, 2]);
    assert(o.value[0, 0] == 1);
    assert(o.value[0, 1] == 1);
    assert(o.value[1, 0] == 1);
    assert(o.value[1, 1] == 1);
}

///ditto
unittest
{
    auto o = ones!(float, [0, 2])(2);
    assert(!canBackward!(typeof(o)));
    assert(o.shape == [2, 2]);
    assert(o.value[0, 0] == 1);
    assert(o.value[0, 1] == 1);
    assert(o.value[1, 0] == 1);
    assert(o.value[1, 1] == 1);
}

///ditto
unittest
{
    auto o = ones!(float, [2, 3], UseGradient.yes);
    static assert(canBackward!(typeof(o)));
}

/// Creates a gradient-free one-filled tensor with the same shape as `x`.
Tensor!(T, Shape, UseGradient.no) onesLike(T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (x.staticShape[0] == 0)
    {
        return ones!(T, Shape)(x.shape[0]);
    }
    else
    {
        return ones!(T, Shape)();
    }
}

///ditto
Tensor!(T, Shape, useGrad) onesLike(UseGradient useGrad, T, size_t[] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
{
    static if (x.staticShape[0] == 0)
    {
        return ones!(T, Shape, useGrad)(x.shape[0]);
    }
    else
    {
        return ones!(T, Shape, useGrad)();
    }
}

///ditto
unittest
{
    auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto x1 = onesLike(x);

    assert(x.shape == x1.shape);
    assert(x1.value == ones!(float, [2, 2]).value);
}

///ditto
unittest
{
    auto x = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto x1 = onesLike(x);

    assert(x.shape == x1.shape);
    assert(x1.value == ones!(float, [2, 3]).value);
}

///ditto
unittest
{
    auto x = tensor!([2, 3])([1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f]);
    auto x1 = onesLike!(UseGradient.yes)(x);

    static assert(canBackward!(typeof(x1)));
}

/++
Reshapes `x` to `TargetShape` (element count must match).
The gradient of a reshape is the incoming gradient reshaped back, implemented
here by accumulating over the flattened views.
+/
Tensor!(T, TargetShape, useGrad) reshape(size_t[] TargetShape, T, size_t[] Shape, UseGradient useGrad)(Tensor!(T, Shape, useGrad) x)
{
    static if (Shape[0] == 0 && TargetShape[0] == 0)
    {
        const batchSize = x.shape[0];
        const ptrdiff_t[TargetShape.length] runtimeTargetShape = [batchSize, expandShape!(TargetShape[1 .. $])];
    }
    else
    {
        static if (Shape[0] != 0)
            enum batchSize = Shape[0];
        else
            enum batchSize = TargetShape[0];
        enum ptrdiff_t[TargetShape.length] runtimeTargetShape = [batchSize, expandShape!(TargetShape[1 .. $])];
    }

    import mir.ndslice : reshape;

    int err;
    auto yValue = reshape(x.value, runtimeTargetShape, err);

    static if (useGrad)
    {
        x.usedCount++;
        return new Tensor!(T, TargetShape)(yValue, (grads) {
            x.backward((ref xGrads) {
                xGrads.flattened[] += grads.flattened[];
            });
        });
    }
    else
    {
        return new Tensor!(T, TargetShape, useGrad)(yValue);
    }
}

/// ditto
unittest
{
    auto x = tensor!([0, 2, 2])([1, 2, 3, 4]);
    auto y = x.reshape!([0, 4]);

    assert(y.shape == [1, 4]);
    assert(y.value[0, 0] == 1);
    assert(y.value[0, 1] == 2);
    assert(y.value[0, 2] == 3);
    assert(y.value[0, 3] == 4);

    assert(x.grads == [[[0, 0], [0, 0]]]);
    y.backward();
    assert(x.grads == [[[1, 1], [1, 1]]]);
}

/// ditto
unittest
{
    auto x = tensor!([0, 4], UseGradient.no)([1, 2, 3, 4]);
    auto y = x.reshape!([1, 2, 2]);

    assert(y.shape == [1, 2, 2]);
    assert(y.value[0, 0, 0] == 1);
    assert(y.value[0, 0, 1] == 2);
    assert(y.value[0, 1, 0] == 3);
    assert(y.value[0, 1, 1] == 4);

    static assert(!canBackward!(typeof(y)));
}