Commit f5fe978

Fix dtype dispatch in cubecl module ops (#3658)
* Add dtype dispatch to cubecl module ops
* Fix output option
* Fix other int + float combined dispatch
* Clippy allow
1 parent 704c51c commit f5fe978
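
The change routes every module op through a dtype dispatch: the op receives tensors whose element type is only known at runtime (x.dtype), and the dispatch expands to a match that instantiates the generic kernel with the matching concrete element type. A minimal standalone Rust sketch of that idea, using hypothetical DType and fill names rather than the crate's execute_with_dtype! macro shown in the diffs below:

// Minimal standalone sketch (hypothetical names, not the actual macro): pick the concrete
// element type `E` from a runtime dtype tag and run the same generic code with it.
#[derive(Clone, Copy, Debug)]
enum DType {
    F32,
    F64,
}

// A "kernel" generic over the element type, like `kernel::conv::conv::<R, E, 2>` in the diffs.
fn fill<E: Default + Clone>(len: usize) -> Vec<E> {
    vec![E::default(); len]
}

// Dispatch: every arm runs the same generic code, monomorphized for a different `E`.
fn bytes_for(dtype: DType, len: usize) -> usize {
    match dtype {
        DType::F32 => fill::<f32>(len).len() * std::mem::size_of::<f32>(),
        DType::F64 => fill::<f64>(len).len() * std::mem::size_of::<f64>(),
    }
}

fn main() {
    assert_eq!(bytes_for(DType::F32, 4), 16);
    assert_eq!(bytes_for(DType::F64, 4), 32);
}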

4 files changed: 201 additions & 69 deletions

crates/burn-cubecl/src/kernel/conv/deform_conv_transpose2d.rs

Lines changed: 13 additions & 4 deletions
@@ -2,7 +2,7 @@ use std::marker::PhantomData;
 
 use burn_tensor::{
     Shape,
-    ops::{DeformConv2dBackward, DeformConvOptions, FloatTensorOps as _},
+    ops::{DeformConvOptions, FloatTensorOps as _},
 };
 use cubecl::{
     AtomicFeature, CubeDim, CubeLaunch, Feature, calculate_cube_count_elemwise,
@@ -30,7 +30,7 @@ use crate::{
 use super::{bilinear_interpolate, deform_im2col, index};
 
 /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
-#[allow(clippy::single_range_in_vec_init)]
+#[allow(clippy::single_range_in_vec_init, clippy::type_complexity)]
 pub(crate) fn deform_conv2d_backward<
     R: CubeRuntime,
     E: FloatElement,
@@ -44,7 +44,16 @@ pub(crate) fn deform_conv2d_backward<
     bias: Option<CubeTensor<R>>,
     out_grad: CubeTensor<R>,
     options: DeformConvOptions<2>,
-) -> Result<DeformConv2dBackward<CubeBackend<R, E, I, BT>>, ConvSetupError> {
+) -> Result<
+    (
+        CubeTensor<R>,
+        CubeTensor<R>,
+        CubeTensor<R>,
+        Option<CubeTensor<R>>,
+        Option<CubeTensor<R>>,
+    ),
+    ConvSetupError,
+> {
     let [_, _, out_h, out_w] = out_grad.shape.dims();
     let [_, _, kernel_h, kernel_w] = weight.shape.dims();
 
@@ -80,7 +89,7 @@ pub(crate) fn deform_conv2d_backward<
         (out_h, out_w),
     )?;
 
-    Ok(DeformConv2dBackward::new(
+    Ok((
         input_gradient,
         offset_gradient,
         weight_grad,

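The signature change above drops DeformConv2dBackward<CubeBackend<R, E, I, BT>> from the kernel's return type in favor of a plain tuple of gradient tensors, and allows clippy::type_complexity for the resulting tuple-in-Result type; the backend wrapper is rebuilt at the call site, inside the dtype-dispatch block in the module ops diff further down. A standalone sketch of that decoupling, with hypothetical Grad and DeformBackward types:

// Standalone sketch (hypothetical types): the inner function returns a plain tuple so its
// signature does not name the wrapper type; the caller rebuilds the wrapper where the
// concrete output type is known.
#[derive(Debug)]
struct Grad(f32);

#[derive(Debug)]
struct DeformBackward {
    input: Grad,
    offset: Grad,
    weight: Grad,
    mask: Option<Grad>,
    bias: Option<Grad>,
}

#[allow(clippy::type_complexity)]
fn backward_kernel() -> Result<(Grad, Grad, Grad, Option<Grad>, Option<Grad>), String> {
    Ok((Grad(0.1), Grad(0.2), Grad(0.3), None, None))
}

fn backward_op() -> DeformBackward {
    let (input, offset, weight, mask, bias) = backward_kernel().unwrap();
    DeformBackward { input, offset, weight, mask, bias }
}

fn main() {
    println!("{:?}", backward_op());
}
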
crates/burn-cubecl/src/ops/float_ops.rs

Lines changed: 28 additions & 12 deletions
@@ -190,9 +190,13 @@ where
         indices: IntTensor<Self>,
     ) -> FloatTensor<Self> {
         execute_with_dtype!(
-            float(tensor.dtype),
-            E,
-            kernel::gather::<R, E, I>(dim, tensor, indices)
+            int(indices.dtype),
+            I,
+            execute_with_dtype!(
+                float(tensor.dtype),
+                E,
+                kernel::gather::<R, E, I>(dim, tensor, indices)
+            )
         )
     }
 
@@ -203,9 +207,13 @@ where
         value: FloatTensor<Self>,
     ) -> FloatTensor<Self> {
         execute_with_dtype!(
-            float(tensor.dtype, value.dtype),
-            E,
-            kernel::scatter::<R, E, I>(dim, tensor, indices, value)
+            int(indices.dtype),
+            I,
+            execute_with_dtype!(
+                float(tensor.dtype, value.dtype),
+                E,
+                kernel::scatter::<R, E, I>(dim, tensor, indices, value)
+            )
         )
     }
 
@@ -215,9 +223,13 @@ where
         indices: IntTensor<Self>,
     ) -> FloatTensor<Self> {
         execute_with_dtype!(
-            float(tensor.dtype),
-            E,
-            kernel::select::<R, E, I>(tensor, dim, indices)
+            int(indices.dtype),
+            I,
+            execute_with_dtype!(
+                float(tensor.dtype),
+                E,
+                kernel::select::<R, E, I>(tensor, dim, indices)
+            )
         )
     }
 
@@ -228,9 +240,13 @@ where
         value: FloatTensor<Self>,
     ) -> FloatTensor<Self> {
         execute_with_dtype!(
-            float(tensor.dtype, value.dtype),
-            E,
-            kernel::select_assign::<R, E, I>(tensor, dim, indices, value)
+            int(indices.dtype),
+            I,
+            execute_with_dtype!(
+                float(tensor.dtype, value.dtype),
+                E,
+                kernel::select_assign::<R, E, I>(tensor, dim, indices, value)
+            )
         )
     }
 

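The index-based ops above (gather, scatter, select, select_assign) previously dispatched only on the float dtype, leaving the index element type I fixed to the backend default; the change wraps that float dispatch in an outer dispatch on indices.dtype so the index tensor's actual dtype selects I. A standalone Rust sketch of the nested dispatch, with hypothetical FloatDType, IntDType, and gather_kernel names standing in for the two nested execute_with_dtype! invocations:

// Standalone sketch (hypothetical names): the outer match picks the index type `I` from the
// indices' runtime dtype, the inner match picks the float type `E`, and both feed the same
// doubly generic kernel.
#[derive(Clone, Copy)]
enum FloatDType {
    F32,
    F64,
}

#[derive(Clone, Copy)]
enum IntDType {
    I32,
    I64,
}

// A gather-like kernel generic over both element types.
fn gather_kernel<E: Copy, I: Copy + TryInto<usize>>(values: &[E], indices: &[I]) -> Vec<E>
where
    <I as TryInto<usize>>::Error: std::fmt::Debug,
{
    indices.iter().map(|&i| values[i.try_into().unwrap()]).collect()
}

fn gather_len(float_dtype: FloatDType, int_dtype: IntDType) -> usize {
    match int_dtype {
        IntDType::I32 => match float_dtype {
            FloatDType::F32 => gather_kernel::<f32, i32>(&[1.0, 2.0], &[1, 0]).len(),
            FloatDType::F64 => gather_kernel::<f64, i32>(&[1.0, 2.0], &[1, 0]).len(),
        },
        IntDType::I64 => match float_dtype {
            FloatDType::F32 => gather_kernel::<f32, i64>(&[1.0, 2.0], &[1, 0]).len(),
            FloatDType::F64 => gather_kernel::<f64, i64>(&[1.0, 2.0], &[1, 0]).len(),
        },
    }
}

fn main() {
    assert_eq!(gather_len(FloatDType::F32, IntDType::I64), 2);
    assert_eq!(gather_len(FloatDType::F64, IntDType::I32), 2);
}
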
Lines changed: 120 additions & 52 deletions
@@ -1,6 +1,7 @@
 use crate::{
     CubeBackend, CubeRuntime, FloatElement, IntElement,
     element::BoolElement,
+    execute_with_dtype,
     kernel::{
         self,
         conv::{ConvStrategy, ConvTranspose2dStrategy},
@@ -25,7 +26,12 @@ where
         bias: Option<FloatTensor<Self>>,
         options: ConvOptions<1>,
     ) -> FloatTensor<Self> {
-        kernel::conv::conv::<R, F, 1>(x, weight, bias, options, ConvStrategy::default()).unwrap()
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::conv::conv::<R, E, 1>(x, weight, bias, options, ConvStrategy::default())
+                .unwrap()
+        )
     }
 
     fn conv2d(
@@ -34,7 +40,12 @@ where
         bias: Option<FloatTensor<Self>>,
         options: ConvOptions<2>,
     ) -> FloatTensor<Self> {
-        kernel::conv::conv::<R, F, 2>(x, weight, bias, options, ConvStrategy::default()).unwrap()
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::conv::conv::<R, E, 2>(x, weight, bias, options, ConvStrategy::default())
+                .unwrap()
+        )
     }
 
     fn deform_conv2d(
@@ -45,7 +56,11 @@ where
         bias: Option<FloatTensor<Self>>,
         options: DeformConvOptions<2>,
     ) -> FloatTensor<Self> {
-        kernel::conv::deform_conv2d::<R, F>(x, offset, weight, mask, bias, options).unwrap()
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::conv::deform_conv2d::<R, E>(x, offset, weight, mask, bias, options).unwrap()
+        )
     }
 
     fn deform_conv2d_backward(
@@ -57,16 +72,19 @@ where
         output_grad: FloatTensor<Self>,
         options: DeformConvOptions<2>,
     ) -> DeformConv2dBackward<Self> {
-        kernel::conv::deform_conv2d_backward::<R, F, I, BT>(
-            x,
-            offset,
-            weight,
-            mask,
-            bias,
-            output_grad,
-            options,
-        )
-        .unwrap()
+        execute_with_dtype!(float(x.dtype), E, {
+            let (x, o, w, m, b) = kernel::conv::deform_conv2d_backward::<R, E, I, BT>(
+                x,
+                offset,
+                weight,
+                mask,
+                bias,
+                output_grad,
+                options,
+            )
+            .unwrap();
+            DeformConv2dBackward::new(x, o, w, m, b)
+        })
     }
 
     fn conv3d(
@@ -75,7 +93,11 @@ where
         bias: Option<FloatTensor<Self>>,
         options: ConvOptions<3>,
     ) -> FloatTensor<Self> {
-        kernel::conv::conv::<R, F, 3>(x, weight, bias, options, ConvStrategy::Direct).unwrap()
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::conv::conv::<R, E, 3>(x, weight, bias, options, ConvStrategy::Direct).unwrap()
+        )
     }
 
     fn conv_transpose2d(
@@ -84,14 +106,18 @@ where
         bias: Option<FloatTensor<Self>>,
         options: ConvTransposeOptions<2>,
     ) -> FloatTensor<Self> {
-        kernel::conv::conv_transpose2d::<R, F, I>(
-            x,
-            weight,
-            bias,
-            options,
-            ConvTranspose2dStrategy::default(),
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::conv::conv_transpose2d::<R, E, I>(
+                x,
+                weight,
+                bias,
+                options,
+                ConvTranspose2dStrategy::default(),
+            )
+            .unwrap()
         )
-        .unwrap()
     }
 
     fn conv_transpose3d(
@@ -100,7 +126,11 @@ where
         bias: Option<FloatTensor<Self>>,
         options: ConvTransposeOptions<3>,
     ) -> FloatTensor<Self> {
-        kernel::conv::conv_transpose3d::<R, F>(x, weight, bias, options)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::conv::conv_transpose3d::<R, E>(x, weight, bias, options)
+        )
     }
 
     fn avg_pool2d(
@@ -110,7 +140,11 @@ where
         padding: [usize; 2],
         count_include_pad: bool,
     ) -> FloatTensor<Self> {
-        kernel::pool::avg_pool2d::<R, F>(x, kernel_size, stride, padding, count_include_pad)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::pool::avg_pool2d::<R, E>(x, kernel_size, stride, padding, count_include_pad)
+        )
     }
 
     fn avg_pool2d_backward(
@@ -121,13 +155,17 @@ where
         padding: [usize; 2],
         count_include_pad: bool,
     ) -> FloatTensor<Self> {
-        kernel::pool::avg_pool2d_backward::<R, F>(
-            x,
-            grad,
-            kernel_size,
-            stride,
-            padding,
-            count_include_pad,
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::pool::avg_pool2d_backward::<R, E>(
+                x,
+                grad,
+                kernel_size,
+                stride,
+                padding,
+                count_include_pad,
+            )
         )
     }
 
@@ -138,7 +176,11 @@ where
         padding: [usize; 2],
         dilation: [usize; 2],
     ) -> FloatTensor<Self> {
-        kernel::pool::max_pool2d::<R, F>(x, kernel_size, stride, padding, dilation)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::pool::max_pool2d::<R, E>(x, kernel_size, stride, padding, dilation)
+        )
     }
 
     fn max_pool2d_with_indices(
@@ -148,15 +190,17 @@ where
         padding: [usize; 2],
         dilation: [usize; 2],
     ) -> MaxPool2dWithIndices<Self> {
-        let (output, indices) = kernel::pool::max_pool2d_with_indices::<R, F, I>(
-            x,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-        );
+        execute_with_dtype!(float(x.dtype), E, {
+            let (output, indices) = kernel::pool::max_pool2d_with_indices::<R, E, I>(
+                x,
+                kernel_size,
+                stride,
+                padding,
+                dilation,
+            );
 
-        MaxPool2dWithIndices::new(output, indices)
+            MaxPool2dWithIndices::new(output, indices)
+        })
     }
 
     fn max_pool2d_with_indices_backward(
@@ -168,34 +212,54 @@ where
         output_grad: FloatTensor<Self>,
         indices: IntTensor<Self>,
     ) -> MaxPool2dBackward<Self> {
-        MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::<R, F, I>(
-            x,
-            output_grad,
-            indices,
-            kernel_size,
-            stride,
-            padding,
-            dilation,
-        ))
+        execute_with_dtype!(
+            int(indices.dtype),
+            I,
+            execute_with_dtype!(
+                float(x.dtype),
+                E,
+                MaxPool2dBackward::new(kernel::pool::max_pool2d_with_indices_backward::<R, E, I>(
+                    x,
+                    output_grad,
+                    indices,
+                    kernel_size,
+                    stride,
+                    padding,
+                    dilation,
+                ))
+            )
+        )
     }
 
     fn adaptive_avg_pool2d(x: FloatTensor<Self>, output_size: [usize; 2]) -> FloatTensor<Self> {
-        kernel::pool::adaptive_avg_pool2d::<R, F>(x, output_size)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::pool::adaptive_avg_pool2d::<R, E>(x, output_size)
+        )
    }
 
     fn adaptive_avg_pool2d_backward(
         x: FloatTensor<Self>,
         grad: FloatTensor<Self>,
     ) -> FloatTensor<Self> {
-        kernel::pool::adaptive_avg_pool2d_backward::<R, F>(x, grad)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::pool::adaptive_avg_pool2d_backward::<R, E>(x, grad)
+        )
     }
 
     fn interpolate(
         x: FloatTensor<Self>,
         output_size: [usize; 2],
         options: InterpolateOptions,
     ) -> FloatTensor<Self> {
-        kernel::interpolate::interpolate::<R, F>(x, output_size, options)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::interpolate::interpolate::<R, E>(x, output_size, options)
+        )
     }
 
     fn interpolate_backward(
@@ -204,6 +268,10 @@ where
         output_size: [usize; 2],
         options: InterpolateOptions,
     ) -> FloatTensor<Self> {
-        kernel::interpolate::interpolate_backward::<R, F>(x, grad, output_size, options)
+        execute_with_dtype!(
+            float(x.dtype),
+            E,
+            kernel::interpolate::interpolate_backward::<R, E>(x, grad, output_size, options)
+        )
     }
 }

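For ops whose result type does not mention the float element, such as DeformConv2dBackward<Self>, MaxPool2dWithIndices<Self>, and MaxPool2dBackward<Self>, the diff uses the block form of the dispatch so the wrapper is constructed inside the dispatched block and every branch yields the same concrete type. A standalone sketch of that shape, with hypothetical DType, kernel, and Output names:

// Standalone sketch (hypothetical names): each arm instantiates a different `E`, but all
// arms must produce the same type, so the `E`-free wrapper is built before the match ends.
#[derive(Clone, Copy)]
enum DType {
    F32,
    F64,
}

#[derive(Debug)]
struct Output {
    len: usize,
}

// A kernel generic over the element type; its return type depends on `E`.
fn kernel<E: Default + Clone>(n: usize) -> Vec<E> {
    vec![E::default(); n]
}

fn op(dtype: DType, n: usize) -> Output {
    match dtype {
        DType::F32 => {
            let values = kernel::<f32>(n);
            Output { len: values.len() }
        }
        DType::F64 => {
            let values = kernel::<f64>(n);
            Output { len: values.len() }
        }
    }
}

fn main() {
    println!("{:?}", op(DType::F32, 3));
    println!("{:?}", op(DType::F64, 2));
}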