1 module golem.data.common;
2 
3 import std.typecons;
4 
/**
 * Splits `source` into `N` consecutive folds for k-fold cross-validation.
 *
 * Each fold holds `source.length / N` elements, except the last one, which
 * also receives the remainder when the length is not divisible by `N`.
 *
 * Params:
 *   N      = number of folds (compile-time, must be greater than zero)
 *   source = the samples to partition; it is never modified
 *
 * Returns:
 *   A static array of `N` tuples. For entry `i`, element `[0]` is the
 *   training set (everything outside fold `i`, original order preserved) and
 *   element `[1]` is the test set (fold `i`).
 *
 *   Every training set is produced by concatenation and does not share memory
 *   with `source`, so it may be modified in place (e.g. shuffled) without
 *   disturbing the other folds. Every test set is a slice of `source`.
 */
Tuple!(T[], T[])[N] kfold(size_t N, T)(T[] source)
if (N > 0)
out(r)
{
    // Each entry must partition the input, and the test sets together must
    // cover every element exactly once.
    size_t count;
    foreach (t; r)
    {
        assert(t[0].length + t[1].length == source.length);
        count += t[1].length;
    }
    assert(count == source.length);
}
do
{
    typeof(return) result;

    // Base fold size; the remainder of the division goes to the last fold.
    immutable len = source.length / N;
    foreach (i; 0 .. N)
    {
        immutable start = i * len;
        immutable end = (i + 1 == N) ? source.length : start + len;

        // Build the training set by concatenation for every fold, including
        // the last one, so that it never shares memory with `source`.
        result[i][0] = source[0 .. start] ~ source[end .. $];
        result[i][1] = source[start .. end];
    }

    return result;
}
31 
// Splits 11 samples into 5 folds and checks every train/test pair; the last
// fold is expected to absorb the one leftover sample. Also checks that the
// folds can be iterated with std.parallelism.
unittest
{
    import std.parallelism : parallel;

    auto samples = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];
    auto folds = kfold!5(samples);

    assert(folds.length == 5);

    // Expected training set for each fold, in fold order.
    double[][] expectedTrain = [
        [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
        [1.0, 2.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0],
        [1.0, 2.0, 3.0, 4.0, 7.0, 8.0, 9.0, 10.0, 11.0],
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 9.0, 10.0, 11.0],
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    ];

    // Expected test set for each fold, in fold order.
    double[][] expectedTest = [
        [1.0, 2.0],
        [3.0, 4.0],
        [5.0, 6.0],
        [7.0, 8.0],
        [9.0, 10.0, 11.0],
    ];

    foreach (idx, ref fold; folds)
    {
        assert(fold[0] == expectedTrain[idx]);
        assert(fold[1] == expectedTest[idx]);
    }

    // The folds must be usable as the work list of a parallel foreach.
    foreach (fold; folds[].parallel)
    {
        auto trainingSet = fold[0];
        auto heldOut = fold[1];
    }
}