module golem.trainer;

import golem.tensor;

/**
Early-stopping helper that watches a scalar loss and reports when it has
stopped improving.

Every call to `shouldStop` compares the incoming loss against the smallest
loss recorded so far (`minLoss`). A strictly smaller loss is recorded and
clears the counter; any other value bumps `patienceStep`, and `true` is
returned as soon as that counter reaches `patience`. Further non-improving
calls keep returning `true` until a new minimum is seen.

Params:
    T = element type of the loss. Presumably floating-point: the first call
        is recognised by `minLoss` still holding its NaN `T.init`, and
        `std.math.isNaN` only accepts floating-point arguments.
*/
class EarlyStopping(T)
{
    /// Number of consecutive non-improving calls after which `shouldStop`
    /// returns `true` (3 unless overridden by the constructor).
    immutable size_t patience = 3;

    /// Consecutive non-improving calls since the last new minimum.
    size_t patienceStep = 0;

    /// Smallest loss seen so far; left at `T.init` until the first call.
    T minLoss;

    /// Creates a stopper with the default `patience` of 3.
    this() @safe pure nothrow
    {
    }

    /// Creates a stopper that triggers after `patience` non-improving calls.
    this(size_t patience) @safe pure nothrow
    {
        this.patience = patience;
    }

    /// Overload taking a one-element tensor; forwards its `value[0]`.
    bool shouldStop(Tensor!(T, [1]) loss) @safe @nogc nothrow
    {
        return this.shouldStop(loss.value[0]);
    }

    /// ditto
    bool shouldStop(Tensor!(T, [1], UseGradient.no) loss) @safe @nogc nothrow
    {
        return this.shouldStop(loss.value[0]);
    }

    /**
    Records `loss` and decides whether training should stop.

    Returns:
        `false` when `loss` is a new minimum (or this is the first call),
        otherwise whether the count of consecutive non-improving calls has
        reached `patience`.
    */
    bool shouldStop(T loss) @safe @nogc nothrow
    {
        import std.math : isNaN;

        // A NaN minLoss means nothing has been recorded yet, so any loss wins.
        immutable improved = isNaN(minLoss) || loss < minLoss;

        if (!improved)
            return ++patienceStep >= patience;

        minLoss = loss;
        patienceStep = 0;
        return false;
    }
}

/// ditto
unittest
{
    auto stopper = new EarlyStopping!float;

    // Each of these is a new minimum, so none of them stops training.
    foreach (improving; [1.0f, 0.9f, 0.6f])
        assert(!stopper.shouldStop(improving));

    // Default patience is 3: the third consecutive non-improvement stops.
    assert(!stopper.shouldStop(0.7f));
    assert(!stopper.shouldStop(0.7f));
    assert(stopper.shouldStop(0.7f));
}

/// ditto
unittest
{
    auto stopper = new EarlyStopping!float(2);

    assert(!stopper.shouldStop(tensor!([1])([1.0f])));
    assert(!stopper.shouldStop(tensor!([1])([0.8f])));

    // patience = 2: the second consecutive loss not below 0.8 stops training.
    assert(!stopper.shouldStop(tensor!([1])([1.0f])));
    assert(stopper.shouldStop(tensor!([1])([1.0f])));
}