module golem.metrics;

import golem.tensor;

/**
Computes the accuracy of a single-output (binary, 0-1) classifier over a batch.

A row of `output` is taken as a positive prediction when its value is greater
than `threshold`; a row of `label` is taken as positive when its value is
greater than 0.5. A row counts as correct when both sides agree.

Params:
    output    = model outputs of shape `[batchSize, 1]`
    label     = expected values of shape `[batchSize, 1]`, presumably 0 or 1
    threshold = decision threshold applied to `output`; defaults to 0.5

Returns:
    The fraction of correctly classified rows. For an empty batch the
    computation is `0.0f / 0.0f`, which yields NaN.
*/
float accuracy(T, size_t[] Shape, UseGradient useGrad1, UseGradient useGrad2)(Tensor!(T, Shape, useGrad1) output, Tensor!(T, Shape, useGrad2) label, double threshold = 0.5)
if (Shape.length == 2 && Shape[1] == 1)
in(output.shape[0] == label.shape[0])
{
    const batchSize = output.shape[0];

    size_t trueAnswer;
    foreach (i; 0 .. batchSize)
    {
        const x = output.value[i, 0] > threshold;
        // Labels keep the fixed 0.5 cut-off: they are ground truth, not scores.
        const y = label.value[i, 0] > 0.5;
        if (x == y) ++trueAnswer;
    }

    return float(trueAnswer) / float(batchSize);
}

/// ditto
unittest
{
    auto xt = tensor!([0, 1])([0.1, 0.9, 0.8, 0.2]);
    auto y0 = tensor!([0, 1])([1.0, 0.0, 0.0, 1.0]); // true: 0
    auto y1 = tensor!([0, 1])([1.0, 0.0, 1.0, 1.0]); // true: 1
    auto y2 = tensor!([0, 1])([1.0, 1.0, 1.0, 1.0]); // true: 2
    auto y3 = tensor!([0, 1])([1.0, 1.0, 1.0, 0.0]); // true: 3
    auto y4 = tensor!([0, 1])([0.0, 1.0, 1.0, 0.0]); // true: 4

    assert(accuracy(xt, y0) == 0.0f);
    assert(accuracy(xt, y1) == 0.25f);
    assert(accuracy(xt, y2) == 0.5f);
    assert(accuracy(xt, y3) == 0.75f);
    assert(accuracy(xt, y4) == 1.0f);

    // A stricter threshold turns the 0.8 output into a negative prediction.
    assert(accuracy(xt, y4, 0.85) == 0.75f);
}


/**
Computes the accuracy of a multi-class classifier over a batch.

For each row, the index of the largest element of `output` is compared with
the index of the largest element of `label`. On ties `std.algorithm.maxIndex`
picks the first maximal index.

Params:
    output = model outputs of shape `[batchSize, classes]` with `classes > 1`
    label  = expected values of the same shape, presumably one-hot

Returns:
    The fraction of rows whose arg-max indices agree. For an empty batch the
    computation is `0.0f / 0.0f`, which yields NaN.
*/
float accuracy(T, size_t[] Shape, UseGradient useGrad1, UseGradient useGrad2)(Tensor!(T, Shape, useGrad1) output, Tensor!(T, Shape, useGrad2) label)
if (Shape.length == 2 && Shape[1] > 1)
in(output.shape[0] == label.shape[0])
{
    import std.algorithm : maxIndex;

    const batchSize = output.shape[0];

    size_t trueAnswer;
    foreach (i; 0 .. batchSize)
    {
        const x = maxIndex(output.value[i, 0 .. $]);
        const y = maxIndex(label.value[i, 0 .. $]);
        if (x == y) ++trueAnswer;
    }

    return float(trueAnswer) / float(batchSize);
}

/// ditto
unittest
{
    auto xt = tensor!([0, 2])([[0.1, 0.9], [0.8, 0.2], [0.4, 0.6], [0.2, 0.8]]);
    auto y0 = tensor!([0, 2])([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 0.0]]); // true: 0
    auto y1 = tensor!([0, 2])([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]); // true: 1
    auto y2 = tensor!([0, 2])([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]]); // true: 2
    auto y3 = tensor!([0, 2])([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]); // true: 3
    auto y4 = tensor!([0, 2])([[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]); // true: 4

    assert(accuracy(xt, y0) == 0.0f);
    assert(accuracy(xt, y1) == 0.25f);
    assert(accuracy(xt, y2) == 0.5f);
    assert(accuracy(xt, y3) == 0.75f);
    assert(accuracy(xt, y4) == 1.0f);
}

/**
Builds a `classes x classes` confusion matrix from two batches of class scores.

Each row of `x` and of `y` is reduced to the index of its largest element
(first maximal index on ties). The result entry `[yIndex, xIndex]` is then
incremented. So rows are indexed by `y` and columns by `x`.

If `x` holds predictions and `y` holds labels (the argument order `accuracy`
uses), rows correspond to true classes and columns to predicted classes.
NOTE(review): confirm callers pass predictions first.

Params:
    x = tensor of shape `[batchSize, classes]` with `classes > 1`, presumably predictions
    y = tensor of the same shape, presumably labels

Returns:
    A newly allocated `Tensor` of counts with shape `[classes, classes]` and
    no gradient tracking.
*/
Tensor!(size_t, [Shape[1], Shape[1]], UseGradient.no) confusionMatrix(T, size_t[] Shape, UseGradient useGrad1, UseGradient useGrad2)(Tensor!(T, Shape, useGrad1) x, Tensor!(T, Shape, useGrad2) y)
if (Shape.length == 2 && Shape[1] > 1)
in(x.shape[0] == y.shape[0])
{
    import std.algorithm : maxIndex;
    import mir.ndslice;

    auto result = slice!size_t([Shape[1], Shape[1]], 0);
    foreach (i; 0 .. x.shape[0])
    {
        const xindex = maxIndex(x.value[i]);
        const yindex = maxIndex(y.value[i]);
        result[yindex, xindex]++;
    }
    return new Tensor!(size_t, [Shape[1], Shape[1]], UseGradient.no)(result);
}

/// Every (label, prediction) pair occurs exactly once, so all counts are 1.
unittest
{
    auto x = tensor!([0, 4])([
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
    ]);

    auto y = tensor!([0, 4])([
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [1.0f, 0.0f, 0.0f, 0.0f],
        [0.0f, 0.0f, 0.0f, 1.0f],
        [0.0f, 0.0f, 1.0f, 0.0f],
        [0.0f, 1.0f, 0.0f, 0.0f],
    ]);

    auto m = confusionMatrix(x, y);
    static import numir;

    assert(m.value == numir.ones!size_t(4, 4));
}

/// Rows follow the arg-max of `y`, columns follow the arg-max of `x`.
unittest
{
    auto pred = tensor!([0, 2])([
        [0.9f, 0.1f],
        [0.8f, 0.2f],
        [0.3f, 0.7f],
    ]);
    auto truth = tensor!([0, 2])([
        [1.0f, 0.0f],
        [0.0f, 1.0f],
        [0.0f, 1.0f],
    ]);

    auto m = confusionMatrix(pred, truth);
    assert(m.value[0, 0] == 1);
    assert(m.value[0, 1] == 0);
    assert(m.value[1, 0] == 1);
    assert(m.value[1, 1] == 1);
}