
Commit 693b699

Add split_tensor_along method
- Add `TrySplitShapeAlong` and `TrySplitTensorAlong`.
- Minor linting and docs fixes.

TODO:
- Check if the tape should be returned. If not, it can be removed from the interface.
- Add a cuda kernel.
- Consider a different interface that could split into more than two tensors, possibly specified as a vec. That would bring it closer to the pytorch interface (chunks).
1 parent 4722a99 commit 693b699
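
For orientation, here is a minimal sketch of the new shape-level API, adapted from the doc examples added in `split_shape_along/mod.rs` below. The tensor-level `TrySplitTensorAlong` is assumed to follow the same `(axis, size_a, size_b)` pattern, but its exact signature (including whether the tape is returned) is still an open TODO in this commit, so only the shape-level call is shown:

    use dfdx_core::prelude::*;

    fn main() {
        // Split a (7, 3) shape into a (3, 3) and a (4, 3) shape along axis 0.
        // The two requested sizes must sum to the original dimension
        // (3 + 4 == 7); otherwise try_split_shape_along panics on its size check.
        let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
        assert_eq!(a, (3, Const::<3>));
        assert_eq!(b, (4, Const::<3>));
    }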

File tree: 9 files changed, +604 -8 lines changed


Diff for: dfdx-core/src/tensor_ops/concat_tensor_along/cpu_kernel.rs (+4 -4)

@@ -26,11 +26,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
         let buf = std::sync::Arc::get_mut(&mut c.data).unwrap();
         while i < n {
             for _ in 0..a_n {
-                buf[i] = a.data[a_idx.next().unwrap()];
+                (*buf)[i] = a.data[a_idx.next().unwrap()];
                 i += 1;
             }
             for _ in 0..b_n {
-                buf[i] = b.data[b_idx.next().unwrap()];
+                (*buf)[i] = b.data[b_idx.next().unwrap()];
                 i += 1;
             }
         }
@@ -59,11 +59,11 @@ impl<E: Dtype> super::ConcatAlongKernel<E> for Cpu {
         let n = grad_out.len();
         while i < n {
             for _ in 0..a_n {
-                grad_a[a_idx.next().unwrap()] += grad_out[i];
+                (*grad_a)[a_idx.next().unwrap()] += grad_out[i];
                 i += 1;
             }
             for _ in 0..b_n {
-                grad_b[b_idx.next().unwrap()] += grad_out[i];
+                (*grad_b)[b_idx.next().unwrap()] += grad_out[i];
                 i += 1;
             }
         }

Diff for: dfdx-core/src/tensor_ops/concat_tensor_along/mod.rs (+4 -4)

@@ -19,7 +19,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
 /// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
-/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_along(Axis::<0>);
+/// let _: Tensor<Rank2<6, 4>, f32, _> = (a, b).concat_tensor_along(Axis::<0>);
 /// ```
 ///
 /// Along Axis 1:
@@ -28,7 +28,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
 /// let b: Tensor<Rank2<3, 4>, f32, _> = dev.zeros();
-/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_along(Axis::<1>);
+/// let _: Tensor<Rank2<3, 8>, f32, _> = (a, b).concat_tensor_along(Axis::<1>);
 /// ```
 ///
 /// # [usize] dims
@@ -38,7 +38,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(2, Const));
 /// let b: Tensor<(usize, Const<3>), f32, _> = dev.zeros_like(&(4, Const));
-/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_along(Axis::<0>).realize();
+/// let _: Tensor<Rank2<6, 3>, f32, _> = (a, b).concat_tensor_along(Axis::<0>).realize();
 /// ```
 ///
 /// Along Axis 1:
@@ -47,7 +47,7 @@ mod webgpu_kernel;
 /// # let dev: Cpu = Default::default();
 /// let a: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 2));
 /// let b: Tensor<(Const<2>, usize), f32, _> = dev.zeros_like(&(Const, 4));
-/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_along(Axis::<1>).realize();
+/// let _: Tensor<Rank2<2, 6>, f32, _> = (a, b).concat_tensor_along(Axis::<1>).realize();
 /// ```
 pub trait TryConcatTensorAlong<Ax>: Sized {
     type Output;

Diff for: dfdx-core/src/tensor_ops/mod.rs (+4)

@@ -200,6 +200,8 @@ mod sigmoid;
 mod sin;
 mod slice;
 mod softmax;
+mod split_shape_along;
+mod split_tensor_along;
 mod sqrt;
 mod square;
 mod stack;
@@ -267,6 +269,8 @@ pub use sigmoid::sigmoid;
 pub use sin::sin;
 pub use slice::slice;
 pub use softmax::softmax;
+pub use split_shape_along::TrySplitShapeAlong;
+pub use split_tensor_along::TrySplitTensorAlong;
 pub use sqrt::sqrt;
 pub use square::square;
 pub use stack::{AddDim, TryStack};

Diff for: dfdx-core/src/tensor_ops/split_shape_along/mod.rs (+158, new file)

@@ -0,0 +1,158 @@
+use crate::{shapes::*, tensor::*};
+
+/// Split a shape in two along a given axis.
+///
+/// # [Const] dims **requires nightly**
+///
+/// Along Axis 0:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b): (Rank2<3, 3>, Rank2<4, 3>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<0>, Const::<3>, Const::<4>);
+/// ```
+///
+/// Along Axis 1:
+/// ```ignore
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b): (Rank2<7, 2>, Rank2<7, 1>) = (Const::<7>, Const::<3>).split_shape_along(Axis::<1>, Const::<2>, Const::<1>);
+/// ```
+///
+/// # [usize] dims
+/// Along Axis 0:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b) = (7, Const::<3>).split_shape_along(Axis::<0>, 3, 4);
+/// assert_eq!(a, (3, Const::<3>));
+/// assert_eq!(b, (4, Const::<3>));
+/// ```
+///
+/// Along Axis 1:
+/// ```rust
+/// # use dfdx_core::prelude::*;
+/// # let dev: Cpu = Default::default();
+/// let (a, b) = (Const::<7>, 3).split_shape_along(Axis::<1>, 2, 1);
+/// assert_eq!(a, (Const::<7>, 2));
+/// assert_eq!(b, (Const::<7>, 1));
+/// ```
+pub trait TrySplitShapeAlong<Ax, A: Dim, B: Dim>: Shape {
+    type Output;
+
+    /// Splits self along the given axis.
+    fn split_shape_along(self, ax: Ax, a: A, b: B) -> Self::Output {
+        self.try_split_shape_along(ax, a, b).unwrap()
+    }
+    /// Fallibly splits self along the given axis.
+    fn try_split_shape_along(self, ax: Ax, a: A, b: B) -> Result<Self::Output, Error>;
+}
+
+macro_rules! impl_split {
+    ($Ax:expr, $NumDims:expr, [$($Head:tt),*], [$($Tail:tt),*]) => {
+        impl<A: Dim, B: Dim, AB: Dim, $($Head: Dim, )* $($Tail: Dim, )*> TrySplitShapeAlong<Axis<$Ax>, A, B>
+            for
+            (
+                $($Head, )*
+                AB,
+                $($Tail, )*
+            )
+        where
+            ($($Head, )* A, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
+            ($($Head, )* B, $($Tail, )*): Shape<Concrete = <Self as Shape>::Concrete>,
+        {
+            type Output =
+            (
+                ($($Head, )* A, $($Tail, )*),
+                ($($Head, )* B, $($Tail, )*),
+            );
+
+            fn try_split_shape_along(self, _: Axis<$Ax>, a: A, b: B) -> Result<Self::Output, Error> {
+                let dims = self.concrete();
+                let mut lhs_dims = dims;
+                let mut rhs_dims = dims;
+                lhs_dims[$Ax] = a.size();
+                rhs_dims[$Ax] = b.size();
+                assert_eq!(dims[$Ax], lhs_dims[$Ax] + rhs_dims[$Ax]);
+
+                Ok((
+                    <($($Head, )* A, $($Tail, )*)>::from_concrete(&lhs_dims).unwrap(),
+                    <($($Head, )* B, $($Tail, )*)>::from_concrete(&rhs_dims).unwrap(),
+                ))
+            }
+        }
+    };
+}
+
+impl_split!(0, 1, [], []);
+impl_split!(0, 2, [], [D1]);
+impl_split!(0, 3, [], [D1, D2]);
+impl_split!(0, 4, [], [D1, D2, D3]);
+impl_split!(0, 5, [], [D1, D2, D3, D4]);
+impl_split!(0, 6, [], [D1, D2, D3, D4, D5]);
+
+impl_split!(1, 2, [D0], []);
+impl_split!(1, 3, [D0], [D2]);
+impl_split!(1, 4, [D0], [D2, D3]);
+impl_split!(1, 5, [D0], [D2, D3, D4]);
+impl_split!(1, 6, [D0], [D2, D3, D4, D5]);
+
+impl_split!(2, 3, [D0, D1], []);
+impl_split!(2, 4, [D0, D1], [D3]);
+impl_split!(2, 5, [D0, D1], [D3, D4]);
+impl_split!(2, 6, [D0, D1], [D3, D4, D5]);
+
+impl_split!(3, 4, [D0, D1, D2], []);
+impl_split!(3, 5, [D0, D1, D2], [D4]);
+impl_split!(3, 6, [D0, D1, D2], [D4, D5]);
+
+impl_split!(4, 5, [D0, D1, D2, D3], []);
+impl_split!(4, 6, [D0, D1, D2, D3], [D5]);
+
+impl_split!(5, 6, [D0, D1, D2, D3, D4], []);
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_split_shape() {
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        let a: (Const<5>, Const<5>) = (Const, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        let a: (usize, Const<5>) = (5, Const);
+        let b: (Const<3>, Const<5>) = (Const, Const);
+        assert_eq!(
+            (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+            (a, b)
+        );
+
+        #[cfg(feature = "nightly")]
+        {
+            let a: (Const<5>, Const<5>) = (Const, Const);
+            let b: (Const<3>, Const<5>) = (Const, Const);
+            assert_eq!(
+                (Const::<8>, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0),
+                (a, b)
+            );
+        }
+    }
+
+    #[test]
+    #[should_panic = "left: 8\n right: 7"]
+    fn test_split_shape_fails() {
+        let a: (usize, Const<5>) = (4, Const);
+        let b: (usize, Const<5>) = (3, Const);
+        (8, Const::<5>).split_shape_along(Axis::<0>, a.0, b.0);
+    }
+}
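
The macro expansion above covers splits along any axis of shapes with up to six dimensions. As a small illustrative sketch (not part of the commit's tests), assuming the trait is re-exported through the prelude as the doc examples suggest, an axis-1 split of a three-dimensional shape with mixed `Const` and `usize` dims looks like this:

    use dfdx_core::prelude::*;

    fn main() {
        // Only the split axis changes size; the Const dims on the other axes
        // are carried through unchanged into both output shapes.
        // Explicit usize annotations avoid relying on integer-literal inference.
        let shape: (Const<2>, usize, Const<4>) = (Const, 7, Const);
        let (lhs, rhs) = shape.split_shape_along(Axis::<1>, 3usize, 4usize);
        assert_eq!(lhs, (Const::<2>, 3usize, Const::<4>));
        assert_eq!(rhs, (Const::<2>, 4usize, Const::<4>));
    }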
New file (+99):

@@ -0,0 +1,99 @@
+use super::AorB;
+use crate::{
+    shapes::*,
+    tensor::{cpu::NdIndex, *},
+};
+
+impl<E: Dtype> super::SplitAlongKernel<E> for Cpu {
+    fn forward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        ab: &Tensor<AB, E, Self>,
+        a: &mut Tensor<A, E, Self>,
+        b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        let mut a_n = 1;
+        let mut b_n = 1;
+        {
+            let a_idx = NdIndex::new(a.shape, a.strides);
+            let b_idx = NdIndex::new(b.shape, b.strides);
+            for i in ax..A::NUM_DIMS {
+                a_n *= a_idx.shape[i];
+                b_n *= b_idx.shape[i];
+            }
+        }
+
+        let n_ab = ab.data.len();
+
+        let buf_a = std::sync::Arc::get_mut(&mut a.data).unwrap();
+        let buf_b = std::sync::Arc::get_mut(&mut b.data).unwrap();
+
+        let mut i = 0;
+        let mut k = 0;
+        let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
+        while i < n_ab {
+            for j in 0..a_n {
+                (*buf_a)[j + k * a_n] = ab.data[ab_idx.next().unwrap()];
+                i += 1;
+            }
+            for j in 0..b_n {
+                (*buf_b)[j + k * b_n] = ab.data[ab_idx.next().unwrap()];
+                i += 1;
+            }
+            k += 1;
+        }
+        Ok(())
+    }
+
+    fn backward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        ax: usize,
+        ab: &GhostTensor<AB, E, Self>,
+        grad_ab: &mut Self::Vec,
+        a: &GhostTensor<A, E, Self>,
+        b: &GhostTensor<B, E, Self>,
+        a_or_b: AorB,
+        grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        let a_idx = NdIndex::new(a.shape, a.strides);
+        let b_idx = NdIndex::new(b.shape, b.strides);
+
+        let mut a_n = 1;
+        let mut b_n = 1;
+        for i in ax..A::NUM_DIMS {
+            a_n *= a_idx.shape[i];
+            b_n *= b_idx.shape[i];
+        }
+
+        let mut i = 0;
+        let mut j = 0;
+        let n = grad_ab.len();
+        let mut ab_idx = NdIndex::new(ab.shape, ab.strides);
+        while i + j < n {
+            match a_or_b {
+                AorB::A => {
+                    for _ in 0..a_n {
+                        (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
+                        i += 1;
+                    }
+                    for _ in 0..b_n {
+                        ab_idx.next().unwrap();
+                        j += 1;
+                    }
+                }
+                AorB::B => {
+                    for _ in 0..a_n {
+                        ab_idx.next().unwrap();
+                        j += 1;
+                    }
+                    for _ in 0..b_n {
+                        (*grad_ab)[ab_idx.next().unwrap()] = grad_out[i];
+                        i += 1;
+                    }
+                }
+            };
+        }
+
+        Ok(())
+    }
+}
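
The forward pass above walks the joined buffer once: for each index prefix ahead of the split axis it copies `a_n` contiguous elements into `a`, then `b_n` into `b`, where `a_n` and `b_n` are the number of elements per slice from the split axis onward. A self-contained sketch of that interleaving on a plain row-major buffer, with a hypothetical helper name `split_interleaved` (not part of the commit):

    // Dependency-free illustration of the forward copy pattern used by the CPU
    // kernel above; a_n and b_n play the same roles as in the kernel.
    fn split_interleaved(ab: &[f32], a_n: usize, b_n: usize) -> (Vec<f32>, Vec<f32>) {
        let mut a = Vec::new();
        let mut b = Vec::new();
        let mut i = 0;
        while i < ab.len() {
            // a_n elements belong to the first output...
            a.extend_from_slice(&ab[i..i + a_n]);
            i += a_n;
            // ...then b_n elements belong to the second output, and so on.
            b.extend_from_slice(&ab[i..i + b_n]);
            i += b_n;
        }
        (a, b)
    }

    fn main() {
        // A 2x5 tensor, stored row-major, split along axis 1 into 2x2 and 2x3 parts.
        let ab: Vec<f32> = (0..10).map(|x| x as f32).collect();
        let (a, b) = split_interleaved(&ab, 2, 3);
        assert_eq!(a, vec![0.0, 1.0, 5.0, 6.0]);
        assert_eq!(b, vec![2.0, 3.0, 4.0, 7.0, 8.0, 9.0]);
    }

The backward kernel performs the same walk, but only scatters `grad_out` into the positions owned by whichever half (`AorB::A` or `AorB::B`) is being back-propagated, skipping over the other half's slots.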
New file (+31):

@@ -0,0 +1,31 @@
+use super::AorB;
+use crate::{
+    shapes::*,
+    tensor::{Cuda, Error, GhostTensor, Tensor},
+};
+use cudarc::types::CudaTypeName;
+
+impl<E: Dtype + CudaTypeName> super::SplitAlongKernel<E> for Cuda {
+    fn forward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &Tensor<AB, E, Self>,
+        _a: &mut Tensor<A, E, Self>,
+        _b: &mut Tensor<B, E, Self>,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+
+    fn backward<AB: Shape, A: Shape, B: Shape>(
+        &self,
+        _ax: usize,
+        _ab: &GhostTensor<AB, E, Self>,
+        _grad_ab: &mut Self::Vec,
+        _a: &GhostTensor<A, E, Self>,
+        _b: &GhostTensor<B, E, Self>,
+        _a_or_b: AorB,
+        _grad_out: &Self::Vec,
+    ) -> Result<(), Error> {
+        todo!()
+    }
+}
