module golem.util;

import mir.ndslice;

/**
 * Expands a compile-time shape array into an `AliasSeq` of its elements,
 * e.g. `expandShape!([2, 3])` is `AliasSeq!(2, 3)`.
 * An empty shape expands to an empty `AliasSeq`.
 */
template expandShape(size_t[] Shape)
{
    import std.meta : AliasSeq;

    static if (Shape.length == 0)
    {
        // Base case for the recursion; also makes an empty shape valid
        // instead of failing on `Shape[0]`.
        alias expandShape = AliasSeq!();
    }
    else
    {
        alias expandShape = AliasSeq!(Shape[0], expandShape!(Shape[1 .. $]));
    }
}

unittest
{
    enum size_t[] shape = [2, 3];
    alias s = expandShape!shape;
    static assert(s.length == 2);
    static assert(s[0] == 2);
    static assert(s[1] == 3);

    enum size_t[] single = [7];
    static assert(expandShape!single.length == 1);
    static assert(expandShape!single[0] == 7);

    enum size_t[] empty = [];
    static assert(expandShape!empty.length == 0);
}

/**
 * Removes trailing dimensions equal to 1 from `shape`, always keeping at
 * least the first dimension (even when it is 1).
 *
 * Params:
 *     shape = dimensions to trim; must contain at least one entry.
 * Returns: a slice of `shape` (not a copy), e.g. `[0, 2, 1, 1]` -> `[0, 2]`.
 */
size_t[] trimRightOneDims(size_t[] shape)
in
{
    assert(shape.length >= 1);
}
do
{
    size_t n = shape.length;
    // Stop once only one dimension is left so the leading one is never dropped.
    while (n > 1 && shape[n - 1] == 1)
    {
        --n;
    }
    return shape[0 .. n];
}

unittest
{
    assert(trimRightOneDims([0, 2, 2, 2]) == [0, 2, 2, 2]);
    assert(trimRightOneDims([0, 2, 2, 1]) == [0, 2, 2]);
    assert(trimRightOneDims([0, 2, 1, 1]) == [0, 2]);
    assert(trimRightOneDims([0, 1, 1, 1]) == [0]);
    assert(trimRightOneDims([1, 1, 1, 1]) == [1]);
    assert(trimRightOneDims([0, 1, 1, 2]) == [0, 1, 1, 2]);
    assert(trimRightOneDims([1, 2, 2, 2]) == [1, 2, 2, 2]);
    assert(trimRightOneDims([3]) == [3]);
    assert(trimRightOneDims([1]) == [1]);
}

/**
 * Yields the half-open index range `[From, To)` as an `AliasSeq`,
 * e.g. `expandIndex!(2, 4)` is `AliasSeq!(2, 3)`.
 * `From == To` (permitted by the constraint) yields an empty `AliasSeq`.
 */
template expandIndex(size_t From, size_t To)
    if (From <= To)
{
    import std.meta : AliasSeq;

    // Recursing down to the empty range gives a base case for From == To and
    // never computes `To - 1`, which would underflow when To == 0.
    static if (From == To)
        alias expandIndex = AliasSeq!();
    else
        alias expandIndex = AliasSeq!(From, expandIndex!(From + 1, To));
}

unittest
{
    alias s = expandIndex!(2, 4);
    static assert(s.length == 2);
    static assert(s[0] == 2);
    static assert(s[1] == 3);
}

unittest
{
    alias s = expandIndex!(3, 6);
    static assert(s.length == 3);
    static assert(s[0] == 3);
    static assert(s[1] == 4);
    static assert(s[2] == 5);
}

unittest
{
    static assert(expandIndex!(2, 2).length == 0);
    static assert(expandIndex!(0, 0).length == 0);
    static assert(expandIndex!(0, 1).length == 1);
    static assert(expandIndex!(0, 1)[0] == 0);
}

/**
 * Product of the entries of `shape`, ignoring any leading zero entries.
 *
 * NOTE(review): a leading 0 presumably marks an unspecified (batch)
 * dimension — confirm against callers.
 *
 * Returns: the product of the dimensions left after skipping leading zeros;
 * 1 (the empty product) when none remain, i.e. for an empty shape or a shape
 * made only of zeros.
 */
size_t elementSize(size_t[] shape)
{
    // Bounded by the length check so an empty / all-zero shape never indexes
    // past the end (previously a RangeError, or a CTFE failure).
    while (shape.length != 0 && shape[0] == 0)
    {
        shape = shape[1 .. $];
    }

    size_t s = 1;
    foreach (x; shape)
    {
        s *= x;
    }
    return s;
}

unittest
{
    assert(elementSize([2, 3]) == 6);
    assert(elementSize([0, 2, 3]) == 6);
    assert(elementSize([0, 0, 3]) == 3);
    assert(elementSize([2, 0, 3]) == 0);
    assert(elementSize([0]) == 1);
    assert(elementSize([]) == 1);
}

/**
 * Index of the first element `T` of `Ts` for which the predicate `F!(T)` is
 * true, or -1 when there is none (including when `Ts` is empty).
 *
 * The result is always a `ptrdiff_t`, so "found" and "not found" results
 * share one type.
 */
package template staticIndexOf(alias F, Ts...)
{
    enum ptrdiff_t staticIndexOf = staticIndexOfImpl!(F, 0, Ts);
}

/**
 * Recursive helper for `staticIndexOf`. `pos` is the index, in the original
 * sequence, of `Ts[0]`; -1 is produced once `Ts` is exhausted.
 */
package template staticIndexOfImpl(alias F, size_t pos, Ts...)
{
    static if (Ts.length == 0)
    {
        enum ptrdiff_t staticIndexOfImpl = -1;
    }
    else static if (F!(Ts[0]))
    {
        enum ptrdiff_t staticIndexOfImpl = pos;
    }
    else
    {
        enum ptrdiff_t staticIndexOfImpl = staticIndexOfImpl!(F, pos + 1, Ts[1 .. $]);
    }
}

unittest
{
    import std.traits : isFloatingPoint;

    static assert(staticIndexOf!(isFloatingPoint) == -1);
    static assert(staticIndexOf!(isFloatingPoint, int, long) == -1);
    static assert(staticIndexOf!(isFloatingPoint, float, int) == 0);
    static assert(staticIndexOf!(isFloatingPoint, int, float, double) == 1);
}

/**
 * Moves the last `M` dimensions of slice `x` to the front, keeping the
 * relative order inside both groups: shape `[2, 3, 4, 5]` with `M = 2`
 * becomes `[4, 5, 2, 3]`.
 *
 * `M == T.N` selects every dimension in order (no reordering). `M > T.N` is
 * rejected at compile time because `T.N - M` underflows and fails
 * `expandIndex`'s constraint.
 * NOTE(review): `M == 0` asks for `transposed!()` with no dimensions —
 * confirm mir accepts that before relying on it.
 */
package auto bringToFront(size_t M, T)(T x)
    if (isSlice!T)
{
    return x.transposed!(expandIndex!(T.N - M, T.N));
}

unittest
{
    import mir.ndslice;

    auto x = iota(2, 3, 4, 5);
    auto y = x.bringToFront!2;
    assert(y.shape == [4, 5, 2, 3]);
    auto z = x.bringToFront!3;
    assert(z.shape == [3, 4, 5, 2]);
}