module golem.data.common;

import std.typecons;

/**
 * Splits `source` into `N` folds for k-fold cross-validation.
 *
 * Fold `i` uses the `i`-th contiguous chunk of `source` as its test set and
 * all remaining elements, in their original order, as its training set.
 * Each of the first `N - 1` chunks holds `source.length / N` elements. The
 * last chunk also absorbs the remainder, so it holds
 * `source.length / N + source.length % N` elements. `source` is never
 * shuffled or reordered.
 *
 * If `source.length < N`, the first `N - 1` test sets are empty. The last
 * fold's test set is then all of `source` and its training set is empty.
 *
 * Memory: every test set, and the training set of the last fold, is a slice
 * of `source` and therefore aliases it. The training sets of the first
 * `N - 1` folds are new arrays produced by concatenation.
 *
 * Params:
 *     N = number of folds; must be at least 1 (enforced at compile time).
 *     source = the data to split.
 *
 * Returns: a static array of `N` tuples, each laid out as `(train, test)`.
 */
Tuple!(T[], T[])[N] kfold(size_t N, T)(T[] source)
if (N > 0) // N == 0 would underflow `N - 1` and divide by zero below
out (r)
{
    // Sanity checks: each fold's train and test sizes add up to the input
    // size, and the test-set sizes sum to the input size.
    size_t count;
    foreach (t; r)
    {
        assert(t[0].length + t[1].length == source.length);
        count += t[1].length;
    }
    assert(count == source.length);
}
do
{
    typeof(return) result;

    // Size of every test chunk except the last one.
    immutable len = source.length / N;

    foreach (i; 0 .. N - 1)
    {
        immutable pos = i * len;
        // `~` allocates a fresh array, so this training set does not alias source.
        result[i][0] = source[0 .. pos] ~ source[pos + len .. $];
        result[i][1] = source[pos .. pos + len];
    }

    // The last fold's test chunk runs to the end and picks up the remainder.
    immutable tail = len * (N - 1);
    result[N - 1][0] = source[0 .. tail];
    result[N - 1][1] = source[tail .. $];

    return result;
}

///
unittest
{
    auto dataSource = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0];
    auto dataLoader = dataSource.kfold!5();

    assert(dataLoader.length == 5);

    assert(dataLoader[0][0] == [3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
    assert(dataLoader[0][1] == [1.0, 2.0]);

    assert(dataLoader[1][0] == [1.0, 2.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
    assert(dataLoader[1][1] == [3.0, 4.0]);

    assert(dataLoader[2][0] == [1.0, 2.0, 3.0, 4.0, 7.0, 8.0, 9.0, 10.0, 11.0]);
    assert(dataLoader[2][1] == [5.0, 6.0]);

    assert(dataLoader[3][0] == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 9.0, 10.0, 11.0]);
    assert(dataLoader[3][1] == [7.0, 8.0]);

    assert(dataLoader[4][0] == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    assert(dataLoader[4][1] == [9.0, 10.0, 11.0]);

    import std.parallelism: parallel;

    // Folds are independent, so they can be processed in parallel.
    foreach (dataset; parallel(dataLoader[]))
    {
        auto train = dataset[0];
        auto test = dataset[1];
        assert(train.length + test.length == dataSource.length);
    }
}

unittest
{
    // Fewer elements than folds: the first N - 1 test sets are empty and the
    // last fold tests on every element.
    auto folds = [1, 2].kfold!3();
    assert(folds[0][0] == [1, 2] && folds[0][1].length == 0);
    assert(folds[1][0] == [1, 2] && folds[1][1].length == 0);
    assert(folds[2][0].length == 0 && folds[2][1] == [1, 2]);

    // A single fold tests on everything and trains on nothing.
    auto single = [1, 2, 3].kfold!1();
    assert(single[0][0].length == 0 && single[0][1] == [1, 2, 3]);

    // Zero folds is rejected at compile time.
    static assert(!__traits(compiles, [1, 2, 3].kfold!0()));
}