1 module golem.util;
2 
3 import mir.ndslice;
4 
/**
 * Expands a compile-time shape array into an `AliasSeq` of its values,
 * e.g. `expandShape!([2, 3])` is `AliasSeq!(2, 3)`.
 *
 * An empty shape yields an empty sequence.
 */
template expandShape(size_t[] Shape)
{
    import std.meta : AliasSeq;

    // Recurse down to the empty slice, not to a single element, so that an
    // empty `Shape` no longer indexes `Shape[0]` out of bounds.
    static if (Shape.length == 0)
    {
        alias expandShape = AliasSeq!();
    }
    else
    {
        alias expandShape = AliasSeq!(Shape[0], expandShape!(Shape[1 .. $]));
    }
}
18 
/**
 * Drops trailing dimensions of size 1 from `shape`, always keeping at
 * least the first dimension.
 *
 * Params:
 *     shape = dimension sizes; must contain at least one entry.
 * Returns: a prefix slice of `shape` (no copy is made).
 */
size_t[] trimRightOneDims(size_t[] shape)
in
{
    assert(shape.length >= 1);
}
do
{
    // Shrink the exclusive end index while the last kept dim is 1, but never
    // below 1 so the leading dimension always survives.
    size_t end = shape.length;
    while (end > 1 && shape[end - 1] == 1)
    {
        --end;
    }
    return shape[0 .. end];
}
38 
unittest
{
    // Trailing 1-sized dims are trimmed; interior ones are kept.
    assert(trimRightOneDims([0, 2, 2, 2]) == [0, 2, 2, 2]);
    assert(trimRightOneDims([0, 2, 2, 1]) == [0, 2, 2]);
    assert(trimRightOneDims([0, 2, 1, 1]) == [0, 2]);
    assert(trimRightOneDims([0, 1, 1, 1]) == [0]);
    assert(trimRightOneDims([1, 1, 1, 1]) == [1]);
    assert(trimRightOneDims([0, 1, 1, 2]) == [0, 1, 1, 2]);
    assert(trimRightOneDims([1, 2, 2, 2]) == [1, 2, 2, 2]);

    // Single-dimension shapes are returned unchanged, even when the dim is 1.
    assert(trimRightOneDims([1]) == [1]);
    assert(trimRightOneDims([5]) == [5]);

    // The result is a prefix slice of the input, not a copy.
    size_t[] input = [3, 1, 1];
    auto trimmed = trimRightOneDims(input);
    assert(trimmed.ptr == input.ptr && trimmed.length == 1);
}
49 
/**
 * Expands the half-open index range `[From, To)` into the compile-time
 * sequence `From, From + 1, ..., To - 1`.
 *
 * `From == To` (permitted by the constraint) yields an empty sequence.
 */
template expandIndex(size_t From, size_t To)
if (From <= To)
{
    import std.meta : AliasSeq;

    // Base case on the empty range. The previous `From == To - 1` test
    // missed `From == To`: it recursed to `From + 1 > To` and failed the
    // constraint, and wrapped `To - 1` around when `To == 0`.
    static if (From == To)
        alias expandIndex = AliasSeq!();
    else
        alias expandIndex = AliasSeq!(From, expandIndex!(From + 1, To));
}
60 
unittest
{
    // [2, 4) expands to exactly the two indices 2 and 3.
    alias seq = expandIndex!(2, 4);
    static assert(seq.length == 2);
    static foreach (k; 0 .. seq.length)
        static assert(seq[k] == 2 + k);
}
68 
unittest
{
    // [3, 6) expands to exactly the three indices 3, 4 and 5.
    alias seq = expandIndex!(3, 6);
    static assert(seq.length == 3);
    static foreach (k; 0 .. seq.length)
        static assert(seq[k] == 3 + k);
}
77 
/**
 * Computes the number of elements described by `shape` as the product of
 * its dimensions, after skipping all leading dimensions equal to 0.
 *
 * NOTE(review): the leading 0 is presumably a placeholder for an
 * unspecified (e.g. batch) dimension; confirm against callers.
 *
 * Params:
 *     shape = dimension sizes; may be empty.
 * Returns: the product of the dimensions left after skipping leading zeros,
 *          or 1 (the empty product) when nothing is left.
 */
size_t elementSize(size_t[] shape)
{
    // Skip every leading 0, not just the first. Stop at an empty slice
    // instead of indexing `shape[0]` past the end (this previously crashed
    // on `[]` and on all-zero shapes such as `[0]`).
    while (shape.length != 0 && shape[0] == 0)
    {
        shape = shape[1 .. $];
    }

    size_t s = 1;
    foreach (x; shape)
    {
        s *= x;
    }
    return s;
}
91 
92 
/**
 * Index of the first element of `Ts` for which the predicate template `F`
 * evaluates to true, or `-1` when no element matches (including when `Ts`
 * is empty).
 *
 * NOTE(review): a found index has type `size_t` while `-1` is an `int`.
 */
package template staticIndexOf(alias F, Ts...)
{
    // staticIndexOfImpl already yields -1 for an empty list, so no special
    // case is needed here; start the search at position 0.
    enum staticIndexOf = staticIndexOfImpl!(F, 0, Ts);
}
104 
/**
 * Recursive worker for `staticIndexOf`: scans `Ts` left to right.
 *
 * Params:
 *     F   = predicate template applied as `F!(element)`.
 *     pos = index that `Ts[0]` has in the original, unsliced sequence.
 *     Ts  = elements still to be examined.
 * Returns: `pos` plus the offset of the first matching element, or `-1`
 *          when none of `Ts` matches.
 */
package template staticIndexOfImpl(alias F, size_t pos, Ts...)
{
    static if (Ts.length == 0)
        enum staticIndexOfImpl = -1; // exhausted without a match
    else static if (F!(Ts[0]))
        enum staticIndexOfImpl = pos; // first match wins
    else
        enum staticIndexOfImpl = staticIndexOfImpl!(F, pos + 1, Ts[1 .. $]);
}
123 
124 
/**
 * Moves the last `M` dimensions of slice `x` to the front, keeping the
 * relative order of both the moved and the remaining dimensions.
 * For example, shape `[2, 3, 4, 5]` becomes `[4, 5, 2, 3]` for `M = 2`.
 *
 * Params:
 *     M = number of trailing dimensions to move; must not exceed `T.N`.
 *     x = the slice whose dimensions are permuted.
 * Returns: `x` permuted via `transposed`, or `x` itself when `M == 0`.
 */
package auto bringToFront(size_t M, T)(T x)
if (isSlice!T)
{
    // Without this check `T.N - M` wraps around for M > T.N and the failure
    // surfaces as an obscure constraint error inside expandIndex.
    static assert(M <= T.N, "bringToFront: M must not exceed the number of dimensions");

    static if (M == 0)
    {
        // Nothing to move; avoids requesting an empty index range.
        return x;
    }
    else
    {
        return x.transposed!(expandIndex!(T.N - M, T.N));
    }
}
130 
unittest
{
    import mir.ndslice;

    // Source slice of shape [2, 3, 4, 5].
    auto source = iota(2, 3, 4, 5);

    // Moving the last two dims to the front keeps their order: [4, 5, 2, 3].
    auto lastTwo = source.bringToFront!2;
    assert(lastTwo.shape == [4, 5, 2, 3]);

    // Moving the last three dims leaves only the first one behind: [3, 4, 5, 2].
    auto lastThree = source.bringToFront!3;
    assert(lastThree.shape == [3, 4, 5, 2]);
}