module golem.nn;

import golem.tensor;
import golem.random;

import mir.ndslice;

import std.meta;

/// Evaluates to `true` when `T` exposes a `parameters` member
/// (the AliasSeq of trainable tensors used by optimizers).
enum hasParameters(T) = __traits(compiles, { auto ps = T.init.parameters; });


/// Fully connected (affine) layer: `y = x * weights + bias`.
class Linear(T, size_t InputDim, size_t OutputDim, UseGradient useGradient = UseGradient.yes)
{
    Tensor!(T, [InputDim, OutputDim], useGradient) weights;
    Tensor!(T, [OutputDim], useGradient) bias;

    alias parameters = AliasSeq!(weights, bias);

    /// Random (uniform) initialization of weights and bias.
    this()
    {
        weights = uniform!(T, [InputDim, OutputDim], useGradient);
        bias = uniform!(T, [OutputDim], useGradient);
    }

    /// Initialize both weights and bias with the same constant value.
    this(T initial)
    {
        weights = new Tensor!(T, [InputDim, OutputDim], useGradient)(initial);
        bias = new Tensor!(T, [OutputDim], useGradient)(initial);
    }

    /// Initialize weights and bias with separate constant values.
    this(T initialWeight, T initialBias)
    {
        weights = new Tensor!(T, [InputDim, OutputDim], useGradient)(initialWeight);
        bias = new Tensor!(T, [OutputDim], useGradient)(initialBias);
    }

    /// Applies the affine transform to a batch `x` whose row width matches `InputDim`.
    auto opCall(size_t[2] Shape, UseGradient useGradient)(Tensor!(T, Shape, useGradient) x)
    if (Shape[1] == InputDim)
    {
        import golem.math : linear;

        return linear(x, this.weights, this.bias);
    }

    static if (useGradient)
    {
        /// Clears accumulated gradients of all parameters.
        void resetGrads()
        {
            weights.resetGrads();
            bias.resetGrads();
        }
    }
}

unittest
{
    import golem.math : flatten;

    auto fc1 = new Linear!(float, 2 * 2, 1);

    Tensor!(float, [3, 2, 2]) x = new Tensor!(float, [3, 2, 2])(0.0f);
    Tensor!(float, [3, 4]) b = flatten(x);
    Tensor!(float, [3, 1]) y = fc1(b);
    assert(y.value.shape == [3, 1]);
}

unittest
{
    auto fc = new Linear!(float, 2, 1);
    static assert(hasParameters!(typeof(fc)));
}

unittest
{
    auto fc0 = new Linear!(float, 2, 1)(-1);
    assert(fc0.weights.value[0, 0] == -1);
    assert(fc0.weights.value[1, 0] == -1);
    assert(fc0.bias.value[0] == -1);

    auto fc1 = new Linear!(float, 2, 1)(0, 1);
    assert(fc1.weights.value[0, 0] == 0);
    assert(fc1.weights.value[1, 0] == 0);
    assert(fc1.bias.value[0] == 1);

    auto fc2 = new Linear!(float, 2, 1)(1, 0);
    assert(fc2.weights.value[0, 0] == 1);
    assert(fc2.weights.value[1, 0] == 1);
    assert(fc2.bias.value[0] == 0);
}

unittest
{
    auto fc1 = new Linear!(float, 2, 1, UseGradient.yes);
    auto fc2 = new Linear!(float, 2, 1, UseGradient.no);

    auto x = new Tensor!(float, [2, 2], UseGradient.yes)(1.0f);
    auto y = new Tensor!(float, [2, 2], UseGradient.no)(1.0f);

    auto a = fc1(x);
    auto b = fc1(y);
    auto c = fc2(x);
    auto d = fc2(y);

    // Gradients propagate if either the layer or the input requires them.
    static assert(canBackward!(typeof(a)));
    static assert(canBackward!(typeof(b)));
    static assert(canBackward!(typeof(c)));
    static assert(!canBackward!(typeof(d)));
}

/// Batch normalization over the leading (batch) axis.
///
/// In training mode (with a batch size > 1) the current batch statistics are
/// used for normalization and folded into the running `mean`/`var` with an
/// exponential moving average controlled by `momentum`. In inference mode the
/// running statistics are used.
class BatchNorm(T, size_t[] Shape, UseGradient useGrad = UseGradient.yes)
{
    Tensor!(T, Shape, UseGradient.no) mean; // running mean (EMA)
    Tensor!(T, Shape, UseGradient.no) var;  // running variance (EMA)
    Tensor!(T, Shape, useGrad) factor;      // learnable scale (gamma)
    Tensor!(T, Shape, useGrad) offset;      // learnable shift (beta)

    // Scratch buffers reused across calls to avoid re-allocation.
    Tensor!(T, Shape, UseGradient.no) tempMean;
    Tensor!(T, Shape, UseGradient.no) tempVar;
    Tensor!(T, Shape, UseGradient.no) temps;
    T momentum = 0.9;

    alias parameters = AliasSeq!(mean, var, factor, offset);

    this()
    {
        mean = zeros!(T, Shape);
        var = zeros!(T, Shape);
        factor = ones!(T, Shape, useGrad);
        offset = zeros!(T, Shape, useGrad);
        tempMean = zeros!(T, Shape);
        tempVar = zeros!(T, Shape);
        temps = zeros!(T, Shape);
    }

    /// Normalizes `x`; `isTrain` selects batch statistics (and updates the
    /// running averages) versus the stored running statistics.
    auto opCall(size_t[] ShapeX, UseGradient useGradX)(Tensor!(T, ShapeX, useGradX) x, bool isTrain)
    {
        static assert(ShapeX[1 .. $] == Shape);

        assert(x.shape[0] != 0);

        import std.math : sqrt;
        import golem.math : broadcastOp;

        enum eps = 1e-7;

        if (isTrain && x.shape[0] > 1)
        {
            // Batch mean over the leading axis.
            auto tm = tempMean.value;
            tm.flattened[] = 0;
            foreach (t; x.value.ipack!1)
            {
                tm[] += t[];
            }
            tm[] /= x.shape[0];

            // Biased batch variance over the leading axis.
            auto tv = tempVar.value;
            tv.flattened[] = 0;
            foreach (t; x.value.ipack!1)
            {
                tv[] += (t[] - tm[]).map!(a => a * a);
            }
            tv[] /= x.shape[0];

            // Fold batch statistics into the running averages.
            this.mean.value[] = momentum * this.mean.value[] + (1 - momentum) * tm[];
            this.var.value[] = momentum * this.var.value[] + (1 - momentum) * tv[];

            // Reuse tempVar in place as the batch standard deviation.
            tempVar.value[] = tempVar.value.map!(a => sqrt(a + eps));
            return broadcastOp!"+"(broadcastOp!"*"(broadcastOp!"-"(x, this.tempMean), factor / this.tempVar), offset);
        }

        // Inference path: normalize with the running statistics.
        this.temps.value[] = this.var.value.map!(a => sqrt(a + eps));
        return broadcastOp!"+"(broadcastOp!"*"(broadcastOp!"-"(x, this.mean), factor / this.temps), offset);
    }
}

unittest
{
    auto x = tensor!([0, 2, 2])([
        1.0f, 2.0f, 3.0f, 4.0f,
        2.0f, 3.0f, 4.0f, 5.0f,
        3.0f, 4.0f, 5.0f, 6.0f,
        4.0f, 5.0f, 6.0f, 7.0f,
    ]);

    auto bn = new BatchNorm!(float, [2, 2]);
    auto y = bn(x, true);

    import std.math : isClose;

    assert(bn.mean.value[0, 0].isClose(0.25f));
    assert(bn.mean.value[0, 1].isClose(0.35f));
    assert(bn.mean.value[1, 0].isClose(0.45f));
    assert(bn.mean.value[1, 1].isClose(0.55f));

    assert(bn.var.value[0, 0].isClose(0.125f));
    assert(bn.var.value[0, 1].isClose(0.125f));
    assert(bn.var.value[1, 0].isClose(0.125f));
    assert(bn.var.value[1, 1].isClose(0.125f));

    import std.math : sqrt;
    import std.conv : text;

    assert(y.value[0, 0, 0].isClose((1.0f - 2.5f) / sqrt(1.25f)), text(y.value[0, 0, 0]));
    assert(y.value[0, 0, 1].isClose((2.0f - 3.5f) / sqrt(1.25f)), text(y.value[0, 0, 1]));
    assert(y.value[0, 1, 0].isClose((3.0f - 4.5f) / sqrt(1.25f)), text(y.value[0, 1, 0]));
    assert(y.value[0, 1, 1].isClose((4.0f - 5.5f) / sqrt(1.25f)), text(y.value[0, 1, 1]));
}

unittest
{
    auto x = tensor!([0, 2, 2])([
        1.0f, 2.0f, 3.0f, 4.0f,
        2.0f, 3.0f, 4.0f, 5.0f,
    ]);

    auto bn = new BatchNorm!(float, [2, 2]);
    auto y = bn(x, true);

    import std.conv : text;

    assert(x.grads[] == [[[0.0f, 0.0f], [0.0f, 0.0f]], [[0.0f, 0.0f], [0.0f, 0.0f]]], text(x.grads));
    y.backward();

    import std.math : isClose;

    assert(x.grads[0, 0, 0].isClose(2.0f));
    assert(x.grads[0, 0, 1].isClose(2.0f));
    assert(x.grads[0, 1, 0].isClose(2.0f));
    assert(x.grads[0, 1, 1].isClose(2.0f));
    assert(x.grads[1, 0, 0].isClose(2.0f));
    assert(x.grads[1, 0, 1].isClose(2.0f));
    assert(x.grads[1, 1, 0].isClose(2.0f));
    assert(x.grads[1, 1, 1].isClose(2.0f));
}

/// Wraps a unary function (e.g. an activation from `golem.math`) as a
/// stateless, parameter-free layer usable inside `Sequence`.
struct Activation(alias f)
{
    import std.functional : unaryFun;

    alias fun = unaryFun!f;

    auto opCall(T)(T x)
    {
        return fun(x);
    }
}

unittest
{
    import golem.math : sigmoid, tanh;

    Activation!sigmoid f1;
    Activation!tanh f2;

    auto x = tensor!([2, 2])([1.0f, 2.0f, 3.0f, 4.0f]);

    auto a = f1(x);
    auto b = f2(x);
    auto c = sigmoid(x);
    auto d = tanh(x);

    assert(a.value == c.value);
    assert(b.value == d.value);
}


/// Chains the layer types `Ts` so that `opCall` feeds each layer's output
/// into the next. Only layers that themselves have parameters contribute to
/// this class's `parameters` alias (and only then is it defined).
class Sequence(Ts...)
{
    Ts layers;

    private alias isNetModule(alias m) = hasParameters!(typeof(m));

    static if (Filter!(isNetModule, AliasSeq!(layers)).length > 0)
    {
        alias parameters = Filter!(isNetModule, AliasSeq!(layers));

        this()
        {
            // Instantiate only the layers that carry parameters; the
            // stateless ones (e.g. Activation) need no construction.
            foreach (ref p; parameters)
                p = new typeof(p);
        }
    }


    auto opCall(T)(T x)
    {
        return opCall!0(x);
    }

    // Compile-time recursion over the layer tuple.
    private auto opCall(size_t n, T)(T x)
    {
        static if (n == Ts.length)
        {
            return x;
        }
        else
        {
            return opCall!(n + 1)(layers[n](x));
        }
    }
}

unittest
{
    import golem.math : sigmoid;

    auto net = new Sequence!(
        Linear!(float, 2, 2),
        Activation!sigmoid,
        Linear!(float, 2, 2),
        Activation!sigmoid,
        Linear!(float, 2, 1),
        Activation!sigmoid,
    );

    static assert(hasParameters!(typeof(net)));

    auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto y = net(x);
}

unittest
{
    import golem.math : sigmoid;

    auto net = new Sequence!(
        Activation!sigmoid,
        Activation!sigmoid,
        Activation!sigmoid,
    );

    static assert(!hasParameters!(typeof(net)));
}

/// Learnable 2-D lifting-based pooling (wavelet-style predict/update steps
/// along the width, then the height, axis). Parameters are initialized to
/// the Haar wavelet.
class LiftPool2D(T, size_t Height, size_t Width, UseGradient useGradient = UseGradient.yes)
{
    enum HalfH = Height / 2;
    enum HalfW = Width / 2;

    Tensor!(T, [HalfW, HalfW], useGradient) predictW;
    Tensor!(T, [HalfW, HalfW], useGradient) updateW;
    Tensor!(T, [HalfH, HalfH], useGradient) predictH;
    Tensor!(T, [HalfH, HalfH], useGradient) updateH;

    alias parameters = AliasSeq!(predictW, updateW, predictH, updateH);

    this()
    {
        import mir.ndslice : diagonal;

        // Haar wavelet: predict = identity, update = identity / 2.
        predictW = zeros!(T, [HalfW, HalfW], useGradient)();
        predictW.value.diagonal[] = T(1);
        updateW = zeros!(T, [HalfW, HalfW], useGradient)();
        updateW.value.diagonal[] = T(0.5);

        predictH = zeros!(T, [HalfH, HalfH], useGradient)();
        predictH.value.diagonal[] = T(1);
        updateH = zeros!(T, [HalfH, HalfH], useGradient)();
        updateH.value.diagonal[] = T(0.5);
    }

    /// Applies the lifting scheme to a batch of images (N, C, Height, Width).
    /// Returns the pooled output plus the intermediate (detail, prediction)
    /// pairs of the width and height passes.
    auto liftUp(U)(U x)
    if (isTensor!U && U.staticShape.length == 4 && U.staticShape[2] == Height && U.staticShape[3] == Width)
    {
        import std.typecons : tuple;
        import golem.math : splitEvenOdd2D, concat2D, projection1D;

        // Width pass: split even/odd columns, predict odd from even,
        // keep the residual as detail, update the coarse part with it.
        auto xw = splitEvenOdd2D!3(x);
        auto xw_predict = projection1D!3(xw[0], predictW);
        auto xw_d = xw[1] - xw_predict;
        auto xw_c = xw[0] + projection1D!3(xw_d, updateW);
        auto hidden = concat2D(xw_c, xw_d);

        // Height pass: the same scheme along rows.
        auto xh = splitEvenOdd2D!2(hidden);
        auto xh_predict = projection1D!2(xh[0], predictH);
        auto xh_d = xh[1] - xh_predict;
        auto xh_c = xh[0] + projection1D!2(xh_d, updateH);
        auto output = concat2D(xh_c, xh_d);

        return tuple(output, tuple(xw[1], xw_predict), tuple(xh[1], xh_predict));
    }

}

unittest
{
    auto lift = new LiftPool2D!(double, 4, 4);
    auto images = tensor!([1, 1, 4, 4])([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]);

    auto y = lift.liftUp(images);

    assert(y[0].shape == [1, 4, 2, 2]);

    assert(y[0].value[0, 0, 0, 0] == 3.5);
    assert(y[0].value[0, 0, 0, 1] == 5.5);
    assert(y[0].value[0, 0, 1, 0] == 11.5);
    assert(y[0].value[0, 0, 1, 1] == 13.5);
    assert(y[0].value[0, 1, 0, 0] == 1);
    assert(y[0].value[0, 1, 0, 1] == 1);
    assert(y[0].value[0, 1, 1, 0] == 1);
    assert(y[0].value[0, 1, 1, 1] == 1);
    assert(y[0].value[0, 2, 0, 0] == 4);
    assert(y[0].value[0, 2, 0, 1] == 4);
    assert(y[0].value[0, 2, 1, 0] == 4);
    assert(y[0].value[0, 2, 1, 1] == 4);
    assert(y[0].value[0, 3, 0, 0] == 0);
    assert(y[0].value[0, 3, 0, 1] == 0);
    assert(y[0].value[0, 3, 1, 0] == 0);
    assert(y[0].value[0, 3, 1, 1] == 0);

    y[0].backward();
}


/// 2-D convolution layer without padding.
class Conv2D(T, size_t C_in, size_t C_out, size_t[] kernelSize, UseGradient useGrad = UseGradient.yes)
{
    mixin Conv2DImpl!(T, C_in, C_out, kernelSize, [0, 0], useGrad);
}

/// 2-D convolution layer with explicit zero padding.
class Conv2D(T, size_t C_in, size_t C_out, size_t[] kernelSize, size_t[] padding, UseGradient useGrad = UseGradient.yes)
{
    mixin Conv2DImpl!(T, C_in, C_out, kernelSize, padding, useGrad);
}

unittest
{
    import golem.random : uniform;

    auto images = uniform!(float, [1, 1, 28, 28]);
    auto conv1 = new Conv2D!(float, 1, 2, [3, 3]);
    auto y = conv1(images);
    assert(y.shape == [1, 2, 26, 26]);
    y.backward();
}

unittest
{
    import golem.random : uniform;

    auto images = uniform!(float, [1, 1, 28, 28]);
    auto conv1 = new Conv2D!(float, 1, 2, [3, 3], [1, 1]);
    auto y = conv1(images);
    assert(y.shape == [1, 2, 28, 28]);
    y.backward();
}

/// Shared implementation for both `Conv2D` overloads.
private mixin template Conv2DImpl(T, size_t C_in, size_t C_out, size_t[] kernelSize, size_t[] padding, UseGradient useGrad)
{
    enum size_t[] WeightShape = [C_out, C_in, kernelSize[0], kernelSize[1]];
    enum size_t[] BiasShape = [C_out];

    Tensor!(T, WeightShape, useGrad) weights;
    Tensor!(T, BiasShape, useGrad) bias;

    alias parameters = AliasSeq!(weights, bias);

    this()
    {
        import golem.random : uniform;

        weights = uniform!(T, WeightShape, useGrad)();
        bias = uniform!(T, BiasShape, useGrad)();
    }

    /// Convolves a batch of images (N, C_in, H, W).
    auto opCall(U)(U x)
    if (isTensor!U && U.staticShape.length == 4 && U.staticShape[1] == C_in)
    {
        import golem.math : conv2D;

        return conv2D!(padding)(x, weights, bias);
    }
}


/// Two-layer perceptron: `fc2(activateFn(fc1(x)))`.
class Perceptron(T, alias activateFn, size_t InputDim, size_t HiddenDim, size_t OutputDim, UseGradient useGrad = UseGradient.yes)
{
    Linear!(T, InputDim, HiddenDim, useGrad) fc1;
    Linear!(T, HiddenDim, OutputDim, useGrad) fc2;

    alias parameters = AliasSeq!(fc1, fc2);

    this()
    {
        foreach (ref p; parameters)
            p = new typeof(p);
    }

    auto opCall(U)(U x)
    {
        return fc2(activateFn(fc1(x)));
    }
}

unittest
{
    import golem.math : sigmoid;

    auto model = new Perceptron!(float, sigmoid, 2, 2, 1);

    auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto y = model(x);
    static assert(isTensor!(typeof(y), float, [0, 1]));

    auto z = tensor!([0, 2], UseGradient.no)([1.0f, 2.0f, 3.0f, 4.0f]);
    auto w = model(z);
    static assert(isTensor!(typeof(w), float, [0, 1]));
}

unittest
{
    import golem.math : sigmoid;

    auto model = new Perceptron!(float, sigmoid, 2, 2, 1, UseGradient.no);

    auto x = tensor!([0, 2])([1.0f, 2.0f, 3.0f, 4.0f]);
    auto y = model(x);
    static assert(isTensor!(typeof(y), float, [0, 1]));

    auto z = tensor!([0, 2], UseGradient.no)([1.0f, 2.0f, 3.0f, 4.0f]);
    auto w = model(z);
    static assert(isTensor!(typeof(w), float, [0, 1], UseGradient.no));
}