module golem.models.linear;

import golem.math : sigmoid;
import golem.tensor : Tensor, UseGradient;
import golem.nn : Linear;
import golem.optimizer : Adam, SGD, AdaBelief, createOptimizer;
import golem.trainer : EarlyStopping;
import std.typecons : Tuple, tuple;
import std.meta : AliasSeq;
import std.array : array;
import mir.ndslice : ndarray;

/// Options controlling `LogisticRegression.fit`.
struct LogisticFitOptions
{
    /// Upper bound on training epochs; early stopping may end training sooner.
    size_t maxEpoch = 500;
    /// Weight-decay coefficient handed to the optimizer's config.
    float penaltyWeightsDecay = 1e-3;
}

/**
Logistic-regression model: a single linear layer followed by an element-wise sigmoid.

Params:
    T         = element type of the tensors (e.g. `float`)
    InputDim  = number of input features per sample
    OutputDim = number of outputs per sample
    useGrad   = whether the weights track gradients (default: yes)
*/
class LogisticRegression(T, size_t InputDim, size_t OutputDim, UseGradient useGrad = UseGradient
        .yes)
{
    // [0, Dim] means a batch dimension of variable size.
    private alias InputTensor = Tensor!(T, [0, InputDim]);
    private alias OutputTensor = Tensor!(T, [0, OutputDim]);

    Linear!(T, InputDim, OutputDim, useGrad) weights;

    alias parameters = AliasSeq!(weights);

    this()
    {
        // Zero-initialized linear layer.
        weights = new typeof(weights)(T(0));
    }

    /**
    Train the model with full-batch gradient descent (AdaBelief).

    Each dataset element is a tuple of (input features, labels). The loss on
    `test` is monitored every epoch and training stops early when it no longer
    improves, or after `options.maxEpoch` epochs.

    Params:
        train   = training samples, each `tuple(features, labels)`
        test    = held-out samples used only for the early-stopping criterion
        options = epoch limit and weight-decay settings
    */
    void fit(in Tuple!(T[], T[])[] train, in Tuple!(T[], T[])[] test, LogisticFitOptions options = LogisticFitOptions
            .init)
    {
        auto optimizer = createOptimizer!AdaBelief(weights);
        optimizer.config.weightDecay = options.penaltyWeightsDecay;
        auto stopper = new EarlyStopping!T();

        auto dataset_train = makeTensors(train);
        auto dataset_test = makeTensors(test);
        foreach (epoch; 0 .. options.maxEpoch)
        {
            // One full-batch step on the training set.
            auto y_train = forward(dataset_train[0]);
            auto loss_train = calculateLoss(y_train, dataset_train[1]);

            optimizer.resetGrads();
            loss_train.backward();
            optimizer.trainStep();

            // Evaluate on the held-out set; stop once it plateaus.
            auto y_test = forward(dataset_test[0]);
            auto loss_test = calculateLoss(y_test, dataset_test[1]);

            if (stopper.shouldStop(loss_test))
                break;
        }
    }

    /// Persist the trained weights under `modelDirPath`.
    void save(string modelDirPath)
    {
        import golem : ModelArchiver;

        auto archiver = new ModelArchiver(modelDirPath);
        archiver.save(weights);
    }

    /// Restore previously saved weights from `modelDirPath`.
    void load(string modelDirPath)
    {
        import golem : ModelArchiver;

        auto archiver = new ModelArchiver(modelDirPath);
        archiver.load(weights);
    }

    /// Predict the outputs for a single sample of exactly `InputDim` features.
    T[] predict(T[] input)
    in
    {
        assert(input.length == InputDim);
    }
    do
    {
        auto inputTensor = new Tensor!(T, [1, InputDim], UseGradient.no)(input);
        auto output = forward(inputTensor);

        return output.value[0].array();
    }

    /// Predict the outputs for a batch of samples; each row must have `InputDim` features.
    T[][] predict(T[][] inputs)
    in
    {
        // Fix: the old contract checked `inputs.length % InputDim == 0`, but
        // `inputs.length` is the number of samples and is unrelated to InputDim.
        // The real invariant is that every sample row has exactly InputDim features.
        foreach (row; inputs)
            assert(row.length == InputDim);
    }
    do
    {
        auto inputTensor = new Tensor!(T, [0, InputDim], UseGradient.no)(inputs);
        auto output = forward(inputTensor);

        return output.value.ndarray();
    }

    // Flatten the (features, labels) tuples into two contiguous buffers and
    // wrap them as batch tensors.
    private auto makeTensors(in Tuple!(T[], T[])[] dataset)
    {
        immutable size = dataset.length;
        auto inputBuf = new T[size * InputDim];
        auto labelBuf = new T[size * OutputDim];
        for (size_t i = 0, inputPos = 0, labelPos = 0; i < size; i++, inputPos += InputDim, labelPos += OutputDim)
        {
            auto temp = dataset[i];
            inputBuf[inputPos .. inputPos + InputDim] = temp[0];
            labelBuf[labelPos .. labelPos + OutputDim] = temp[1];
        }

        auto inputTensor = new InputTensor(inputBuf);
        auto labelTensor = new OutputTensor(labelBuf);

        return tuple(inputTensor, labelTensor);
    }

    // Linear layer followed by element-wise sigmoid.
    private auto forward(U)(U x)
    {
        return sigmoid(weights(x));
    }

    // Sum-of-squared-errors loss. For OutputDim > 1 the inner sum reduces the
    // output axis, then the outer sum reduces over the batch.
    private auto calculateLoss(U, V)(U output, V label)
    {
        import golem : sum;

        auto temp = label - output;
        static if (OutputDim == 1)
            return sum(temp * temp);
        else
            return sum(sum(temp * temp));
    }
}

unittest
{
    Tuple!(float[], float[])[] dataset = [
        tuple([20.0f, 0.0f, 1], [0.0f]),
        tuple([24.0f, 0.5f, 1], [0.0f]),
        tuple([30.0f, 0.0f, 0], [0.0f]),
        tuple([36.0f, 1.0f, 0], [0.0f]),
        tuple([48.0f, 1.0f, 1], [0.0f]),
        tuple([58.0f, 1.0f, 1], [1.0f]),
        tuple([55.0f, 1.5f, 1], [1.0f]),
        tuple([18.0f, 0.5f, 0], [0.0f]),
        tuple([24.0f, 0.5f, 0], [0.0f]),
        tuple([30.0f, 0.0f, 1], [0.0f]),
        tuple([34.0f, 1.0f, 1], [0.0f]),
        tuple([50.0f, 1.5f, 0], [0.0f]),
        tuple([64.0f, 1.0f, 0], [1.0f]),
        tuple([57.0f, 2.0f, 1], [1.0f]),
    ];

    auto model = new LogisticRegression!(float, 3, 1);

    model.fit(dataset, dataset);

    assert(model.predict(dataset[0][0])[0] < 0.5);
    assert(model.predict(dataset[5][0])[0] > 0.5);
}

unittest
{
    Tuple!(float[], float[])[] dataset = [
        tuple([20.0f, 0.0f, 1], [0.0f, 0]),
        tuple([24.0f, 0.5f, 1], [0.0f, 0]),
        tuple([30.0f, 0.0f, 0], [0.0f, 0]),
        tuple([36.0f, 1.0f, 0], [0.0f, 1]),
        tuple([48.0f, 1.0f, 1], [0.0f, 0]),
        tuple([58.0f, 1.0f, 1], [1.0f, 0]),
        tuple([55.0f, 1.5f, 1], [1.0f, 0]),
        tuple([18.0f, 0.5f, 0], [0.0f, 0]),
        tuple([24.0f, 0.5f, 0], [0.0f, 0]),
        tuple([30.0f, 0.0f, 1], [0.0f, 0]),
        tuple([34.0f, 1.0f, 1], [0.0f, 0]),
        tuple([50.0f, 1.5f, 0], [0.0f, 1]),
        tuple([64.0f, 1.0f, 0], [1.0f, 1]),
        tuple([57.0f, 2.0f, 1], [1.0f, 0]),
    ];

    auto model = new LogisticRegression!(float, 3, 2);

    model.fit(dataset, dataset);

    assert(model.predict(dataset[0][0])[0] < 0.5);
    assert(model.predict(dataset[5][0])[0] > 0.5);
    assert(model.predict(dataset[11][0])[1] > 0.5);
}