module golem.model;

import golem.tensor;
import golem.nn;
import golem.util;

import std.meta;

/**
Serializes parameters into a MessagePack byte buffer.

Arguments that satisfy `hasParameters` are replaced in place by their
`parameters` sequence (recursively) until only tensors remain. The tensors are
then written as one MessagePack array with one entry per tensor, each entry
being the tensor's value flattened to a one-dimensional array.

Params:
    params = tensors and/or objects satisfying `hasParameters`

Returns: the packed bytes
*/
ubyte[] packParameters(Params...)(Params params)
{
    import golem.util : staticIndexOf;

    enum firstPos = staticIndexOf!(hasParameters, Params);

    static if (firstPos != -1)
    {
        // Expand the first container into its parameters and recurse.
        // dfmt off
        return packParameters(
            params[0 .. firstPos],
            params[firstPos].parameters,
            params[firstPos + 1 .. $]
        );
        // dfmt on
    }
    else
    {
        static if (allSatisfy!(isTensor, Params))
        {
            import msgpack : Packer;
            import mir.ndslice : flattened, ndarray;

            Packer packer;
            packer.beginArray(params.length);
            foreach (p; params)
            {
                packer.pack(p.value.flattened[].ndarray());
            }
            return packer.stream.data;
        }
        else
        {
            static assert(false,
                    "packParameters: every argument must satisfy isTensor or hasParameters");
        }
    }
}

/**
Restores parameters from a buffer produced by `packParameters`.

Arguments that satisfy `hasParameters` are expanded the same way as in
`packParameters`. Each tensor then receives the next array entry from the
buffer, reshaped to the tensor's own shape. Only the element count is checked,
so data packed from a tensor of another shape with the same number of elements
is accepted.

Params:
    data = bytes produced by `packParameters`
    params = tensors and/or objects satisfying `hasParameters` to fill

Throws: `Exception` when the buffer has no entry left for a tensor or when an
entry's element count differs from the element count of the tensor's shape.
*/
void unpackParameters(Params...)(ubyte[] data, ref Params params)
{
    import golem.util : staticIndexOf;

    enum firstPos = staticIndexOf!(hasParameters, Params);

    static if (firstPos != -1)
    {
        // Expand the first container into its parameters and recurse.
        // dfmt off
        unpackParameters(
            data,
            params[0 .. firstPos],
            params[firstPos].parameters,
            params[firstPos + 1 .. $]
        );
        // dfmt on
    }
    else
    {
        static if (allSatisfy!(isTensor, Params))
        {
            import msgpack : unpack;
            import mir.ndslice : flattened, ndarray, sliced;
            import std.exception : enforce;

            auto unpacked = unpack(data);
            foreach (p; params)
            {
                // The buffer is external input, so validate it with enforce:
                // assert would be compiled out in release builds.
                enforce(!unpacked.empty,
                        "unpackParameters: no serialized entry left for a parameter");
                auto temp = unpacked.front.as!(typeof(p).ElementType[]);
                enforce(elementSize(p.shape) == temp.length,
                        "unpackParameters: element count does not match the tensor shape");
                p.value = temp.sliced(p.shape);
                unpacked.popFront();
            }
        }
        else
        {
            static assert(false,
                    "unpackParameters: every argument must satisfy isTensor or hasParameters");
        }
    }
}

// Round trip of a single tensor.
unittest
{
    auto x = tensor!([2, 3])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
    auto serializedData = packParameters(x);

    auto y = tensor!([2, 3])([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    unpackParameters(serializedData, y);

    assert(x.value == y.value);
}

// A mismatching element count must be rejected.
unittest
{
    auto x = tensor!([2, 3])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6]);
    auto y = tensor!([2, 2])([0.0, 0.0, 0.0, 0.0]);

    auto serializedData = packParameters(x);

    try
        unpackParameters(serializedData, y);
    catch (Throwable t)
    {
        return;
    }
    assert(false);
}

// A different shape with the same element count is accepted.
unittest
{
    auto x = tensor!([2, 2, 2])([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]);
    auto serializedData = packParameters(x);

    auto y = tensor!([4, 2])([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
    unpackParameters(serializedData, y);

    import mir.ndslice : flattened;

    assert(x.value.flattened[] == y.value.flattened[]);
}

// Round trip of a model that exposes nested parameters.
unittest
{
    import golem.nn : Linear;
    import std.meta : AliasSeq;

    class Model
    {
        Linear!(float, 2, 2) fc1;
        Linear!(float, 2, 1) fc2;

        alias parameters = AliasSeq!(fc1, fc2);

        this()
        {
            foreach (ref p; parameters)
                p = new typeof(p);
        }
    }

    auto m1 = new Model;
    auto serializedData = packParameters(m1);

    auto m2 = new Model;
    unpackParameters(serializedData, m2);

    assert(m1.fc1.weights.value == m2.fc1.weights.value);
    assert(m1.fc1.bias.value == m2.fc1.bias.value);
    assert(m1.fc2.weights.value == m2.fc2.weights.value);
    assert(m1.fc2.bias.value == m2.fc2.bias.value);
}

/**
Stores model parameters in timestamped files inside a directory and restores
them from the most recent of those files.

File names have the form `<prefix>YYYYMMDD-hhmmss.dat`.
*/
class ModelArchiver
{
    /// Directory that holds the model files.
    string dirPath;
    /// File name prefix placed before the timestamp.
    string prefix;

    /**
    Params:
        dirPath = directory that holds the model files
        prefix = file name prefix placed before the timestamp
    */
    this(string dirPath = "model_data", string prefix = "model_")
    {
        this.dirPath = dirPath;
        this.prefix = prefix;
    }

    /**
    Packs `model` with `packParameters` and writes the result to a new file
    named after the current time. The directory is created when missing.
    */
    void save(T)(T model)
    {
        static import std.file;

        prepare();
        std.file.write(makeCurrentPath(), packParameters(model));
    }

    /**
    Restores `model` from the most recent model file in `dirPath`.

    Does nothing when `dirPath` does not exist or contains no file matching
    the naming pattern.
    */
    void load(T)(T model)
    {
        static import std.file;

        if (!std.file.exists(dirPath))
            return;

        auto recentPath = findRecentModelPath();
        if (std.file.exists(recentPath))
            unpackParameters(cast(ubyte[]) std.file.read(recentPath), model);
    }

    /// Creates `dirPath` (including parent directories) when it does not exist.
    protected void prepare()
    {
        import std.file : exists, mkdirRecurse;

        if (!exists(dirPath))
            mkdirRecurse(dirPath);
    }

    /// Builds the path of a model file named after the current time.
    protected string makeCurrentPath()
    {
        import std.path : buildNormalizedPath;
        import std.format : format;
        import std.datetime : Clock, DateTime;

        const DateTime now = cast(DateTime) Clock.currTime;
        const name = format!"%s%04d%02d%02d-%02d%02d%02d.dat"(prefix, now.year,
                now.month, now.day, now.hour, now.minute, now.second);

        return buildNormalizedPath(dirPath, name);
    }

    /**
    Builds the regular expression matching model file names.

    Capture group 1 is the eight-digit date, group 2 the six-digit time.
    */
    protected auto makePattern()
    {
        import std.regex : escaper, regex;
        import std.conv : to;

        const escapedPrefix = escaper(this.prefix).to!string();
        // The dot before the extension is escaped so that only a literal
        // ".dat" suffix matches.
        const pattern = "^" ~ escapedPrefix ~ `(\d{8})-(\d{6})\.dat$`;
        return regex(pattern);
    }

    /**
    Scans `dirPath` (non-recursively) for model files and returns the path of
    the one with the greatest (date, time) stamp in its name.

    Returns: the path of the most recent model file, or an empty string when
    no entry matches the pattern.
    */
    protected string findRecentModelPath()
    {
        import std.path : baseName;
        import std.file : dirEntries, DirEntry, SpanMode;
        import std.regex : matchFirst;
        import std.typecons : Tuple, tuple;

        string recentPath;
        Tuple!(string, string) latest;

        const pattern = makePattern();
        foreach (DirEntry entry; dirEntries(dirPath, SpanMode.shallow))
        {
            auto name = baseName(entry.name);
            auto m = matchFirst(name, pattern);
            if (m)
            {
                // Order by (date, time); both are fixed-width digit strings,
                // so lexicographic comparison is chronological.
                auto temp = tuple(m[1], m[2]);
                if (recentPath.length == 0 || latest < temp)
                {
                    recentPath = entry.name;
                    latest = temp;
                }
            }
        }

        return recentPath;
    }
}


/**
Mixin that gives a class a `parameters` alias sequence listing every field
whose type satisfies `hasParameters`, and a default constructor that
instantiates each of those fields with `new`.
*/
mixin template NetModule()
{
    mixin(parametersAliasSeqCode!(typeof(this)));

    this()
    {
        foreach (ref p; parameters)
        {
            p = new typeof(p);
        }
    }
}

/// Names of the fields of `T` whose type satisfies `hasParameters`.
private template AllParameterMembersOf(T)
{
    private template isParameterMember(string name)
    {
        import golem.nn : hasParameters;

        alias MemberType = typeof(__traits(getMember, T.init, name));

        enum isParameterMember = hasParameters!(MemberType);
    }

    import std.traits : FieldNameTuple;

    alias AllParameterMembersOf = Filter!(isParameterMember, FieldNameTuple!T);
}

/**
Generates the source code declaring `alias parameters = AliasSeq!(...)` for
the parameter fields of `T`, intended to be mixed into `T`.

When `T` has no such field the generated sequence is empty.
*/
string parametersAliasSeqCode(T)()
{
    import std.array : join;

    // Explicit element type: with no parameter fields the literal is `[]`,
    // which would otherwise be typed `void[]`.
    enum string[] names = [AllParameterMembersOf!T];

    return "import std.meta : AliasSeq;\nalias parameters = AliasSeq!("
        ~ names.join(",") ~ ");";
}

unittest
{
    static class Test
    {
        Linear!(float, 16, 8) fc1;
        BatchNorm!(float, [8]) bn1;

        mixin NetModule;
        // The mixin is equivalent to writing:
        // alias parameters = AliasSeq!(fc1, bn1);
        // this() {
        //     foreach (ref p; parameters)
        //         p = new typeof(p);
        // }
    }

    auto t = new Test;
    static assert(t.parameters.length == 2);
    assert(t.parameters[0] == t.fc1);
    assert(t.parameters[1] == t.bn1);
    assert(t.fc1 !is null);
    assert(t.bn1 !is null);
}