module golem.random;

import golem.tensor;
import golem.util;

import mir.ndslice;

/++
Creates a tensor of static shape `Shape` whose elements are drawn from the
closed interval `[-q, q]`, where `q = 0.5 / sqrt(elementSize(Shape))`.

Samples come from `std.random.uniform` with its default generator.

Params:
    T = element type
    Shape = static tensor shape; must be non-empty
    useGradient = whether the returned tensor is created with gradient support
Returns: a newly allocated tensor of shape `Shape`.
+/
Tensor!(T, Shape, useGradient) uniform(T, size_t[] Shape, UseGradient useGradient = UseGradient.yes)()
if (Shape.length > 0)
{
    import std.random : stduniform = uniform;
    import std.math : sqrt;

    enum size = elementSize(Shape);
    // The sampling bound shrinks as the element count grows.
    enum q = T(0.5) / sqrt(T(size));

    auto t = new T[size];
    foreach (ref x; t)
    {
        x = stduniform!"[]"(-q, q);
    }

    static if (useGradient)
    {
        return new Tensor!(T, Shape)(t.sliced(Shape), null);
    }
    else
    {
        return new Tensor!(T, Shape, UseGradient.no)(t.sliced(Shape));
    }
}

unittest
{
    auto x = uniform!(float, [2, 3])();
    static assert(canBackward!(typeof(x)));

    assert(x.shape == [2, 3]);
}

unittest
{
    auto x = uniform!(float, [2, 3], UseGradient.no)();
    static assert(!canBackward!(typeof(x)));

    assert(x.shape == [2, 3]);
}

/++
Creates a batched tensor whose leading dimension is `size`. The remaining
dimensions are `Shape[1 .. $]`. Elements are drawn from the closed interval
`[-q, q]`, where `q = 0.5 / sqrt(size * elementSize(Shape))`.

Params:
    T = element type
    Shape = static tensor shape whose first dimension is 0 (dynamic batch)
    useGradient = whether the returned tensor is created with gradient support
    size = runtime length of the leading (batch) dimension
Returns: a newly allocated tensor of runtime shape `[size, Shape[1 .. $]...]`.
+/
Tensor!(T, Shape, useGradient) uniform(T, size_t[] Shape, UseGradient useGradient = UseGradient.yes)(size_t size)
if (Shape.length > 0 && Shape[0] == 0)
{
    import std.random : stduniform = uniform;
    import std.math : sqrt;

    enum esize = elementSize(Shape);
    // NOTE(review): assumes elementSize skips the 0 (batch) dimension so that
    // esize is the per-sample element count — TODO confirm in golem.util.
    const totalSize = size * esize;
    const q = T(0.5) / sqrt(T(totalSize));

    auto t = new T[totalSize];
    foreach (ref x; t)
    {
        x = stduniform!"[]"(-q, q);
    }

    static if (useGradient)
    {
        return new Tensor!(T, Shape)(t.sliced([size, expandShape!(Shape[1 .. $])]), null);
    }
    else
    {
        return new Tensor!(T, Shape, UseGradient.no)(t.sliced([size, expandShape!(Shape[1 .. $])]));
    }
}

unittest
{
    auto x = uniform!(float, [0, 4])(3);
    static assert(canBackward!(typeof(x)));

    assert(x.shape == [3, 4]);
}

unittest
{
    auto x = uniform!(float, [0, 4], UseGradient.no)(3);
    static assert(!canBackward!(typeof(x)));

    assert(x.shape == [3, 4]);
}

/// Alias of `uniform`.
/// NOTE: despite the conventional meaning of the name, this samples from the
/// scaled *uniform* distribution above, not a standard normal. Use `normal`
/// for Gaussian samples. Kept unchanged for backward compatibility.
alias randn = uniform;


/++
Creates a tensor of static shape `Shape` (no gradient) whose elements are
sampled from a normal distribution via mir's `normalVar` and the `rne` engine.

Params:
    T = element type
    Shape = static tensor shape; must be non-empty with a non-zero first dimension
    location = mean of the distribution
    scale = standard deviation passed to `normalVar`
Returns: a newly allocated tensor of shape `Shape`.
+/
Tensor!(T, Shape, UseGradient.no) normal(T, size_t[] Shape)(in T location = 0.0, in T scale = 1.0)
if (Shape.length > 0 && Shape[0] != 0)
{
    import mir.random.variable : normalVar;
    import mir.random.engine : rne;

    auto result = uninitSlice!T(Shape);
    auto ngen = normalVar!T(location, scale);
    // Fill every element regardless of rank by walking the flattened view.
    foreach (ref x; result.flattened[])
    {
        x = ngen(rne);
    }

    return new Tensor!(T, Shape, UseGradient.no)(result);
}

/++
Creates a batched tensor (no gradient) of runtime shape
`[batchSize, Shape[1 .. $]...]` with normally distributed elements.

Params:
    T = element type
    Shape = static tensor shape whose first dimension is 0 (dynamic batch)
    batchSize = runtime length of the leading dimension; must be > 0
    location = mean of the distribution
    scale = standard deviation passed to `normalVar`
Returns: a newly allocated tensor.
+/
Tensor!(T, Shape, UseGradient.no) normal(T, size_t[] Shape)(size_t batchSize, in T location = 0.0, in T scale = 1.0)
if (Shape.length > 0 && Shape[0] == 0)
{
    assert(batchSize > 0);

    import mir.random.variable : normalVar;
    import mir.random.engine : rne;

    auto result = uninitSlice!T([batchSize, expandShape!(Shape[1 .. $])]);
    auto ngen = normalVar!T(location, scale);
    foreach (ref x; result.flattened[])
    {
        x = ngen(rne);
    }

    return new Tensor!(T, Shape, UseGradient.no)(result);
}

unittest
{
    auto m = normal!(float, [2, 3]);
    static assert(!canBackward!(typeof(m)));
    assert(m.shape == [2, 3]);

    auto n = normal!(float, [0, 4])(4);
    static assert(!canBackward!(typeof(n)));
    assert(n.shape == [4, 4]);

    auto u = normal!(float, [2, 2])(1.0, 2.0);
    assert(u.shape == [2, 2]);

    auto t = normal!(float, [0, 3])(3, 1.0, 2.0);
    assert(t.shape == [3, 3]);
}

/++
Creates a normally distributed tensor (no gradient) with the same static
shape as `x`. Only the type of `x` is used; its values are ignored.

Params:
    x = tensor whose shape is copied
    location = mean of the distribution
    scale = standard deviation passed to `normalVar`
Returns: a newly allocated tensor of shape `Shape`.
+/
Tensor!(T, Shape, UseGradient.no) normalLike(T, size_t[] Shape, UseGradient useGrad)(Tensor!(T, Shape, useGrad) x, in T location = 0.0, in T scale = 1.0)
if (Shape.length > 0 && Shape[0] != 0)
{
    return normal!(T, Shape)(location, scale);
}

unittest
{
    auto x = tensor!([2, 2])([0.1, 0.2, 0.3, 0.4]);
    auto z = normalLike(x, 0, 1);

    static assert(x.staticShape == z.staticShape);
    assert(x.shape == z.shape);
}

/++
Creates a normally distributed batched tensor (no gradient) matching the
runtime batch size `x.shape[0]` and static shape of `x`.

Params:
    x = batched tensor whose shape is copied
    location = mean of the distribution
    scale = standard deviation passed to `normalVar`
Returns: a newly allocated tensor with the same runtime shape as `x`.
+/
Tensor!(T, Shape, UseGradient.no) normalLike(T, size_t[] Shape, UseGradient useGrad)(Tensor!(T, Shape, useGrad) x, in T location = 0.0, in T scale = 1.0)
if (Shape.length > 0 && Shape[0] == 0)
{
    return normal!(T, Shape)(x.shape[0], location, scale);
}

unittest
{
    auto x = tensor!([0, 2])([0.1, 0.2, 0.3, 0.4]);
    auto z = normalLike(x, 0, 1);

    static assert(x.staticShape == z.staticShape);
    assert(x.shape == z.shape);
}