1 module golem.trainer;
2 
3 import golem.tensor;
4 
/**
Tracks a loss value over successive calls and reports when it has stopped improving.

A loss counts as an improvement only when it is strictly smaller than the
smallest loss recorded so far (`minLoss`). Once `patience` consecutive calls
fail to improve, `shouldStop` returns `true`.

Params:
	T = floating-point loss type (`std.math.isNaN` is applied to it, and
		`minLoss` relies on `T.init` being NaN to mean "nothing recorded yet").
*/
class EarlyStopping(T)
{
	/// Number of consecutive non-improving losses that triggers a stop (default 3).
	immutable size_t patience = 3;

	/// Consecutive non-improving losses seen since the last improvement.
	size_t patienceStep = 0;

	/// Smallest loss recorded so far; NaN until the first non-NaN loss arrives.
	T minLoss;

	/// Creates an instance using the default patience of 3.
	this() @safe pure nothrow
	{
	}

	/// Creates an instance that stops after `patience` consecutive non-improving losses.
	this(size_t patience) @safe pure nothrow
	{
		this.patience = patience;
	}

	/// Forwards the single element `loss.value[0]` of a shape-`[1]` tensor to `shouldStop(T)`.
	bool shouldStop(Tensor!(T, [1]) loss) @safe @nogc nothrow
	{
		return shouldStop(loss.value[0]);
	}

	/// ditto
	bool shouldStop(Tensor!(T, [1], UseGradient.no) loss) @safe @nogc nothrow
	{
		return shouldStop(loss.value[0]);
	}

	/**
	Records `loss` and reports whether training should stop.

	An improvement (a non-NaN loss strictly below `minLoss`, or the first
	non-NaN loss) updates `minLoss` and resets the counter. Any other loss,
	including NaN, increments the counter.

	Params:
		loss = the latest loss value.

	Returns:
		`true` once `patience` consecutive calls have failed to improve;
		`false` otherwise.
	*/
	bool shouldStop(T loss) @safe @nogc nothrow
	{
		import std.math : isNaN;

		// Reject NaN explicitly. If a NaN were stored in minLoss, the
		// isNaN(minLoss) test would accept every following loss as an
		// improvement. A stream of NaN losses (e.g. a diverged run) would
		// then reset the counter forever and never stop.
		const improved = !isNaN(loss) && (isNaN(minLoss) || loss < minLoss);
		if (improved)
		{
			minLoss = loss;
			patienceStep = 0;
			return false;
		}

		return ++patienceStep >= patience;
	}
}
51 
52 /// ditto
unittest
{
	auto stopper = new EarlyStopping!float();

	// Default patience (3): 1.0, 0.9 and 0.6 each lower the best loss; the
	// repeated 0.7 never beats 0.6, so its third occurrence requests a stop.
	static immutable float[] losses = [1.0f, 0.9f, 0.6f, 0.7f, 0.7f, 0.7f];
	static immutable bool[] expected = [false, false, false, false, false, true];

	foreach (i, loss; losses)
	{
		assert(stopper.shouldStop(loss) == expected[i]);
	}
}
64 
65 /// ditto
unittest
{
	auto stopper = new EarlyStopping!float(2);

	// With patience 2, the losses after 0.8 do not improve; the second of
	// them (index 3) is the first call that should request a stop.
	foreach (i, value; [1.0f, 0.8f, 1.0f, 1.0f])
	{
		const stopExpected = i == 3;
		assert(stopper.shouldStop(tensor!([1])([value])) == stopExpected);
	}
}