1 module golem.models.linear;
2 
3 import golem.math : sigmoid;
4 import golem.tensor : Tensor, UseGradient;
5 import golem.nn : Linear;
6 import golem.optimizer : Adam, SGD, AdaBelief, createOptimizer;
7 import golem.trainer :  EarlyStopping;
8 import std.typecons : Tuple, tuple;
9 import std.meta : AliasSeq;
10 import std.array : array;
11 import mir.ndslice : ndarray;
12 
/// Options accepted by `LogisticRegression.fit`.
struct LogisticFitOptions
{
    /// Upper bound on the number of training epochs run by `fit`
    /// (the loop may end earlier when early stopping triggers).
    size_t maxEpoch = 500;
    /// Value that `fit` assigns to the optimizer's `config.weightDecay`.
    float penaltyWeightsDecay = 1e-3;
}
18 
/**
 * Logistic-regression model: one `Linear` layer whose output is passed
 * through `sigmoid`.
 *
 * Params:
 *   T         = element type of inputs, labels and predictions
 *   InputDim  = number of features in one sample
 *   OutputDim = number of output values produced for one sample
 *   useGrad   = forwarded unchanged to the `Linear` layer
 */
class LogisticRegression(T, size_t InputDim, size_t OutputDim,
        UseGradient useGrad = UseGradient.yes)
{
    private alias InputTensor = Tensor!(T, [0, InputDim]);
    private alias OutputTensor = Tensor!(T, [0, OutputDim]);

    /// The linear layer holding the trainable parameters of the model.
    Linear!(T, InputDim, OutputDim, useGrad) weights;

    /// Sequence of members exposed as the model's parameters.
    alias parameters = AliasSeq!(weights);

    /// Creates the model, constructing the linear layer with `T(0)`.
    this()
    {
        weights = new typeof(weights)(T(0));
    }

    /**
     * Trains the model on `train`, using the loss on `test` for early stopping.
     *
     * Every epoch performs one optimizer step (AdaBelief) on the whole
     * training set, then evaluates the loss on the whole test set and asks an
     * `EarlyStopping` instance whether training should end.
     *
     * Params:
     *   train   = training samples as (features, labels) pairs; features must
     *             have `InputDim` elements and labels `OutputDim` elements
     *   test    = evaluation samples, same layout as `train`
     *   options = epoch limit and weight-decay setting
     */
    void fit(in Tuple!(T[], T[])[] train, in Tuple!(T[], T[])[] test,
            LogisticFitOptions options = LogisticFitOptions.init)
    {
        auto optimizer = createOptimizer!AdaBelief(weights);
        optimizer.config.weightDecay = options.penaltyWeightsDecay;
        auto stopper = new EarlyStopping!T();

        // Pack both data sets into tensors once, outside the epoch loop.
        auto dataset_train = makeTensors(train);
        auto dataset_test = makeTensors(test);
        foreach (epoch; 0 .. options.maxEpoch)
        {
            auto y_train = forward(dataset_train[0]);
            auto loss_train = calculateLoss(y_train, dataset_train[1]);

            optimizer.resetGrads();
            loss_train.backward();
            optimizer.trainStep();

            auto y_test = forward(dataset_test[0]);
            auto loss_test = calculateLoss(y_test, dataset_test[1]);

            if (stopper.shouldStop(loss_test))
                break;
        }
    }

    /// Saves the linear layer through a `ModelArchiver` created for `modelDirPath`.
    void save(string modelDirPath)
    {
        import golem : ModelArchiver;

        auto archiver = new ModelArchiver(modelDirPath);
        archiver.save(weights);
    }

    /// Loads the linear layer through a `ModelArchiver` created for `modelDirPath`.
    void load(string modelDirPath)
    {
        import golem : ModelArchiver;

        auto archiver = new ModelArchiver(modelDirPath);
        archiver.load(weights);
    }

    /**
     * Predicts the outputs for a single sample.
     *
     * Params:
     *   input = feature vector; must contain exactly `InputDim` elements
     * Returns: the model outputs for that sample as an array.
     */
    T[] predict(T[] input)
    in
    {
        assert(input.length == InputDim);
    }
    do
    {
        auto inputTensor = new Tensor!(T, [1, InputDim], UseGradient.no)(input);
        auto output = forward(inputTensor);

        return output.value[0].array();
    }

    /**
     * Predicts the outputs for a batch of samples.
     *
     * Params:
     *   inputs = one feature vector per sample; every vector must contain
     *            exactly `InputDim` elements
     * Returns: one output row per input sample.
     */
    T[][] predict(T[][] inputs)
    in
    {
        import std.algorithm.searching : all;

        // The number of samples is unconstrained; what must match the model
        // is the length of each individual sample.
        assert(inputs.all!(row => row.length == InputDim),
                "every input row must contain exactly InputDim elements");
    }
    do
    {
        auto inputTensor = new Tensor!(T, [0, InputDim], UseGradient.no)(inputs);
        auto output = forward(inputTensor);

        return output.value.ndarray();
    }

    /**
     * Copies a list of (features, labels) pairs into two flat buffers and
     * wraps them in an input tensor and a label tensor.
     *
     * Returns: `tuple(inputTensor, labelTensor)`.
     */
    private auto makeTensors(in Tuple!(T[], T[])[] dataset)
    {
        immutable size = dataset.length;
        auto inputBuf = new T[size * InputDim];
        auto labelBuf = new T[size * OutputDim];
        for (size_t i = 0, inputPos = 0, labelPos = 0; i < size; i++, inputPos += InputDim, labelPos += OutputDim)
        {
            auto temp = dataset[i];
            // Slice assignment requires equal lengths; report a mismatch
            // with a message that names the offending side.
            assert(temp[0].length == InputDim, "sample features must contain InputDim elements");
            assert(temp[1].length == OutputDim, "sample labels must contain OutputDim elements");
            inputBuf[inputPos .. inputPos + InputDim] = temp[0];
            labelBuf[labelPos .. labelPos + OutputDim] = temp[1];
        }

        auto inputTensor = new InputTensor(inputBuf);
        auto labelTensor = new OutputTensor(labelBuf);

        return tuple(inputTensor, labelTensor);
    }

    /// Applies the linear layer to `x` and passes the result through `sigmoid`.
    private auto forward(U)(U x)
    {
        return sigmoid(weights(x));
    }

    /**
     * Squared-error loss between `output` and `label`.
     *
     * The squared differences are reduced with `sum` once when
     * `OutputDim == 1` and with two nested `sum` calls otherwise.
     */
    private auto calculateLoss(U, V)(U output, V label)
    {
        import golem : sum;

        auto temp = label - output;
        static if (OutputDim == 1)
            return sum(temp * temp);
        else
            return sum(sum(temp * temp));
    }
}
140 
// Single-output model: fit on a small fixed data set and check that two of
// the training samples land on the expected side of 0.5.
unittest
{
    alias Sample = Tuple!(float[], float[]);

    // Each sample pairs three input features with one target value.
    Sample[] samples = [
        tuple([20.0f, 0.0f, 1], [0.0f]),
        tuple([24.0f, 0.5f, 1], [0.0f]),
        tuple([30.0f, 0.0f, 0], [0.0f]),
        tuple([36.0f, 1.0f, 0], [0.0f]),
        tuple([48.0f, 1.0f, 1], [0.0f]),
        tuple([58.0f, 1.0f, 1], [1.0f]),
        tuple([55.0f, 1.5f, 1], [1.0f]),
        tuple([18.0f, 0.5f, 0], [0.0f]),
        tuple([24.0f, 0.5f, 0], [0.0f]),
        tuple([30.0f, 0.0f, 1], [0.0f]),
        tuple([34.0f, 1.0f, 1], [0.0f]),
        tuple([50.0f, 1.5f, 0], [0.0f]),
        tuple([64.0f, 1.0f, 0], [1.0f]),
        tuple([57.0f, 2.0f, 1], [1.0f]),
    ];

    auto classifier = new LogisticRegression!(float, 3, 1);

    // The same samples serve as both the training and the evaluation set.
    classifier.fit(samples, samples);

    // Sample 0 is labelled 0, sample 5 is labelled 1.
    auto negativeScore = classifier.predict(samples[0][0]);
    assert(negativeScore[0] < 0.5);

    auto positiveScore = classifier.predict(samples[5][0]);
    assert(positiveScore[0] > 0.5);
}
167 
// Two-output model: fit on a small fixed data set and check selected
// outputs of three training samples against the 0.5 threshold.
unittest
{
    alias Sample = Tuple!(float[], float[]);

    // Each sample pairs three input features with two target values.
    Sample[] samples = [
        tuple([20.0f, 0.0f, 1], [0.0f, 0]),
        tuple([24.0f, 0.5f, 1], [0.0f, 0]),
        tuple([30.0f, 0.0f, 0], [0.0f, 0]),
        tuple([36.0f, 1.0f, 0], [0.0f, 1]),
        tuple([48.0f, 1.0f, 1], [0.0f, 0]),
        tuple([58.0f, 1.0f, 1], [1.0f, 0]),
        tuple([55.0f, 1.5f, 1], [1.0f, 0]),
        tuple([18.0f, 0.5f, 0], [0.0f, 0]),
        tuple([24.0f, 0.5f, 0], [0.0f, 0]),
        tuple([30.0f, 0.0f, 1], [0.0f, 0]),
        tuple([34.0f, 1.0f, 1], [0.0f, 0]),
        tuple([50.0f, 1.5f, 0], [0.0f, 1]),
        tuple([64.0f, 1.0f, 0], [1.0f, 1]),
        tuple([57.0f, 2.0f, 1], [1.0f, 0]),
    ];

    auto classifier = new LogisticRegression!(float, 3, 2);

    // The same samples serve as both the training and the evaluation set.
    classifier.fit(samples, samples);

    // First output: sample 0 is labelled 0, sample 5 is labelled 1.
    auto firstLow = classifier.predict(samples[0][0]);
    assert(firstLow[0] < 0.5);

    auto firstHigh = classifier.predict(samples[5][0]);
    assert(firstHigh[0] > 0.5);

    // Second output: sample 11 is labelled 1.
    auto secondHigh = classifier.predict(samples[11][0]);
    assert(secondHigh[1] > 0.5);
}