ditto
// Checks splitEvenOdd2D on one 2x2 plane holding [[1, 2], [3, 4]].
// The leading 0 in the shape appears to be a runtime-sized dimension; the four
// values give it size 1, as the asserted result shapes show.
// x is created without a UseGradient argument, so both halves support backward().
auto x = tensor!([0, 1, 2, 2])([1.0, 2.0, 3.0, 4.0]);

// Split by height using the default template argument: the even row goes to
// sh[0] and the odd row goes to sh[1].
auto sh = splitEvenOdd2D(x);
assert(sh[0].shape == [1, 1, 1, 2]);
assert(sh[0].value == [[[[1.0, 2.0]]]]);
assert(sh[1].shape == [1, 1, 1, 2]);
assert(sh[1].value == [[[[3.0, 4.0]]]]);
sh[0].backward();
sh[1].backward();

// Split by width (axis 3): the even column goes to sw[0] and the odd column
// goes to sw[1].
auto sw = splitEvenOdd2D!3(x);
assert(sw[0].shape == [1, 1, 2, 1]);
assert(sw[0].value == [[[[1.0], [3.0]]]]);
assert(sw[1].shape == [1, 1, 2, 1]);
assert(sw[1].value == [[[[2.0], [4.0]]]]);
sw[0].backward();
sw[1].backward();
ditto
// Checks splitEvenOdd2D on two 2x2 planes, [[1, 2], [3, 4]] and [[5, 6], [7, 8]],
// created with UseGradient.no. The split is applied to each plane
// independently, and the results must not expose backward().
auto x = tensor!([0, 2, 2, 2], UseGradient.no)([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);

// Split by height (axis 2): the even rows of both planes go to sh[0] and the
// odd rows go to sh[1].
auto sh = splitEvenOdd2D!2(x);
assert(sh[0].shape == [1, 2, 1, 2]);
assert(sh[0].value == [[[[1.0, 2.0]], [[5.0, 6.0]]]]);
assert(sh[1].shape == [1, 2, 1, 2]);
assert(sh[1].value == [[[[3.0, 4.0]], [[7.0, 8.0]]]]);
static assert(!canBackward!(typeof(sh)));

// Split by width (axis 3): the even columns of both planes go to sw[0] and the
// odd columns go to sw[1].
auto sw = splitEvenOdd2D!3(x);
assert(sw[0].shape == [1, 2, 2, 1]);
assert(sw[0].value == [[[[1.0], [3.0]], [[5.0], [7.0]]]]);
assert(sw[1].shape == [1, 2, 2, 1]);
assert(sw[1].value == [[[[2.0], [4.0]], [[6.0], [8.0]]]]);
static assert(!canBackward!(typeof(sw)));