diff --git a/Cargo.lock b/Cargo.lock index 16d43547e7..79f29058e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2305,6 +2305,7 @@ dependencies = [ [[package]] name = "cubek" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "cubecl", "cubek-attention", @@ -2318,6 +2319,7 @@ dependencies = [ [[package]] name = "cubek-attention" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "bytemuck", "cubecl", @@ -2331,6 +2333,7 @@ dependencies = [ [[package]] name = "cubek-convolution" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "bytemuck", "cubecl", @@ -2345,6 +2348,7 @@ dependencies = [ [[package]] name = "cubek-matmul" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "bytemuck", "cubecl", @@ -2356,6 +2360,7 @@ dependencies = [ [[package]] name = "cubek-quant" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "cubecl", "cubecl-common", @@ -2366,6 +2371,7 @@ dependencies = [ [[package]] name = "cubek-random" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "cubecl", "cubecl-common", @@ -2378,6 +2384,7 @@ dependencies = [ [[package]] name = "cubek-reduce" version = "0.1.0-pre.1" +source = "git+https://github.com/tracel-ai/cubek?rev=8097a621dfd3a6a89f8d0433994e8c1adba377c2#8097a621dfd3a6a89f8d0433994e8c1adba377c2" dependencies = [ "cubecl", "half", diff --git a/crates/burn-autodiff/src/ops/module.rs b/crates/burn-autodiff/src/ops/module.rs index 32497c2ce6..7fd0d47463 100644 --- a/crates/burn-autodiff/src/ops/module.rs +++ b/crates/burn-autodiff/src/ops/module.rs @@ -1684,6 +1684,258 @@ impl ModuleOps> for Autodiff, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + ) -> AutodiffTensor { + #[derive(Debug)] + struct AvgPool3D; + + impl Backward for AvgPool3D { + type State = (NodeId, [usize; 3], [usize; 3], [usize; 3], bool, bool); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) = + ops.state; + let x = checkpointer.retrieve_node_output(x_state); + + if let Some(node) = node_parent { + let grad = B::avg_pool3d_backward( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ); + grads.register::(node.id, grad); + } + } + } + + match AvgPool3D + .prepare::([x.node.clone()]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + prep.finish( + ( + x_state, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ), + B::avg_pool3d( + x.primitive.clone(), + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ), + ) + } + OpsKind::UnTracked(prep) => 
prep.finish(B::avg_pool3d( + x.primitive, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + )), + } + } + + fn avg_pool3d_backward( + _x: AutodiffTensor, + _grad: AutodiffTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> AutodiffTensor { + panic!("Can't differentiate avg pool 3d backward."); + } + + fn max_pool3d( + x: AutodiffTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> AutodiffTensor { + match MaxPool3D + .prepare::([x.node.clone()]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + let output = B::max_pool3d_with_indices( + x.primitive, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ); + prep.finish( + ( + x_state, + output.indices, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ), + output.output, + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::max_pool3d( + x.primitive, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + )), + } + } + + fn max_pool3d_with_indices( + x: AutodiffTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> MaxPool3dWithIndices { + match MaxPool3D + .prepare::([x.node.clone()]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + + let output = B::max_pool3d_with_indices( + x.primitive, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ); + + let output_tensor = prep.finish( + ( + x_state, + output.indices.clone(), + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ), + output.output, + ); + + MaxPool3dWithIndices::new(output_tensor, output.indices) + } + OpsKind::UnTracked(prep) => { + let output = B::max_pool3d_with_indices( + x.primitive, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ); + let output_tensor = prep.finish(output.output); + + MaxPool3dWithIndices::new(output_tensor, output.indices) + } + } + } + + fn max_pool3d_with_indices_backward( + _x: AutodiffTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + _output_grad: AutodiffTensor, + _indices: IntTensor, + ) -> MaxPool3dBackward { + panic!("Can't differentiate max pool3d with indices backward."); + } + + fn adaptive_avg_pool3d(x: AutodiffTensor, output_size: [usize; 3]) -> AutodiffTensor { + #[derive(Debug)] + struct AdaptiveAvgPool3D; + + impl Backward for AdaptiveAvgPool3D { + type State = NodeId; + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let state = checkpointer.retrieve_node_output(ops.state); + + if let Some(node) = node_parent { + let grad = B::adaptive_avg_pool3d_backward(state, grad); + grads.register::(node.id, grad); + } + } + } + + match AdaptiveAvgPool3D + .prepare::([x.node.clone()]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + prep.finish(x_state, B::adaptive_avg_pool3d(x.primitive, output_size)) + } + OpsKind::UnTracked(prep) => { + prep.finish(B::adaptive_avg_pool3d(x.primitive, output_size)) + } + } + } + + fn adaptive_avg_pool3d_backward( + _x: AutodiffTensor, + _grad: AutodiffTensor, + ) -> as Backend>::FloatTensorPrimitive { + 
panic!("Can't differentiate adaptive avg pool3d backward."); + } + fn interpolate( x: AutodiffTensor, output_size: [usize; 2], @@ -1814,3 +2066,45 @@ impl Backward for MaxPool2D { } } } + +#[derive(Debug)] +struct MaxPool3D; + +impl Backward for MaxPool3D { + type State = ( + NodeId, + IntTensor, + [usize; 3], + [usize; 3], + [usize; 3], + [usize; 3], + bool, + ); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_parent] = ops.parents; + let grad = grads.consume::(&ops.node); + let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state; + let x = checkpointer.retrieve_node_output(x_state); + + if let Some(node) = node_parent { + let grad = B::max_pool3d_with_indices_backward( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + grad, + indices, + ); + + grads.register::(node.id, grad.x_grad); + } + } +} diff --git a/crates/burn-backend/src/backend/ops/modules/base.rs b/crates/burn-backend/src/backend/ops/modules/base.rs index de428c2021..46b6e16157 100644 --- a/crates/burn-backend/src/backend/ops/modules/base.rs +++ b/crates/burn-backend/src/backend/ops/modules/base.rs @@ -85,6 +85,23 @@ pub struct MaxPool2dWithIndices { pub indices: IntTensor, } +/// Gradient computed during the backward pass for each tensor used by [max_pool3d](ModuleOps::max_pool3d). +#[derive(new)] +pub struct MaxPool3dBackward { + /// Gradient. + pub x_grad: FloatTensor, +} + +/// Results from [max_pool3d](ModuleOps::max_pool3d_with_indices). +#[derive(new)] +pub struct MaxPool3dWithIndices { + /// The output tensor. + pub output: FloatTensor, + + /// The indices tensor. + pub indices: IntTensor, +} + /// Check that the parameter value is non-zero. // NOTE: for now we keep usize but we could refactor the parameters to hold `NonZeroUsize`. pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize { @@ -901,6 +918,78 @@ pub trait ModuleOps { indices: IntTensor, ) -> MaxPool2dBackward; + /// Three dimensional avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, depth, height, width], + fn avg_pool3d( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + ) -> FloatTensor; + /// Backward pass for the [avg pooling 3d](ModuleOps::avg_pool3d) operation. + fn avg_pool3d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + ) -> FloatTensor; + /// Three dimensional adaptive avg pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, depth, height, width], + fn adaptive_avg_pool3d(x: FloatTensor, output_size: [usize; 3]) -> FloatTensor; + /// Backward pass for the [adaptive avg pooling 3d](ModuleOps::adaptive_avg_pool3d) operation. + fn adaptive_avg_pool3d_backward(x: FloatTensor, grad: FloatTensor) -> FloatTensor; + + /// Three dimensional max pooling. + /// + /// # Shapes + /// + /// x: [batch_size, channels, depth, height, width], + fn max_pool3d( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> FloatTensor; + + /// Three dimensional max pooling with indices. 
+ /// + /// # Shapes + /// + /// x: [batch_size, channels, depth, height, width], + fn max_pool3d_with_indices( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> MaxPool3dWithIndices; + /// Backward pass for the [max pooling 3d](ModuleOps::max_pool3d_with_indices) operation. + #[allow(clippy::too_many_arguments)] + fn max_pool3d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool3dBackward; + /// Down/up samples the input. /// /// # Shapes diff --git a/crates/burn-candle/src/ops/module.rs b/crates/burn-candle/src/ops/module.rs index 283d8004fc..134a52209b 100644 --- a/crates/burn-candle/src/ops/module.rs +++ b/crates/burn-candle/src/ops/module.rs @@ -2,8 +2,8 @@ use burn_backend::{ Shape, ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, - InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, - UnfoldOptions, + InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, + MaxPool3dBackward, MaxPool3dWithIndices, ModuleOps, UnfoldOptions, }, tensor::{FloatTensor, IntTensor}, }; @@ -284,6 +284,75 @@ impl ModuleOps for Candle, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + panic!("avg_pool3d is not supported by Candle") + } + + fn avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + panic!("avg_pool3d_backward is not supported by Candle") + } + + fn max_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> FloatTensor { + panic!("max_pool3d is not supported by Candle") + } + + fn max_pool3d_with_indices( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> MaxPool3dWithIndices { + panic!("max_pool3d_with_indices is not supported by Candle") + } + + fn max_pool3d_with_indices_backward( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + _output_grad: FloatTensor, + _indices: IntTensor, + ) -> MaxPool3dBackward { + panic!("max_pool3d_with_indices_backward is not supported by Candle") + } + + fn adaptive_avg_pool3d(_x: FloatTensor, _output_size: [usize; 3]) -> FloatTensor { + panic!("adaptive_avg_pool3d is not supported by Candle") + } + + fn adaptive_avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + ) -> FloatTensor { + panic!("adaptive_avg_pool3d_backward is not supported by Candle") + } + fn interpolate( x: FloatTensor, output_size: [usize; 2], diff --git a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool3d.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool3d.rs new file mode 100644 index 0000000000..e76e0d2208 --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool3d.rs @@ -0,0 +1,142 @@ +use crate::{ + CubeRuntime, + kernel::into_contiguous, + ops::{ + max_line_size, numeric::empty_device_dtype, permute_ncdhw_to_ndhwc, permute_ndhwc_to_ncdhw, + }, + tensor::CubeTensor, +}; +use 
burn_backend::Shape; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +#[cube(launch)] +fn adaptive_avg_pool3d_direct( + input: &Tensor>, + output: &mut Tensor>, + #[define(E)] _dtype: StorageType, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + // Output shape is [batch, out_d, out_h, out_w, channels] in NDHWC format + let (out_d, out_h, out_w, channels) = ( + output.shape(1), + output.shape(2), + output.shape(3), + output.shape(4), + ); + let channel_lines = channels / output.line_size(); + let (in_stride_b, in_stride_d, in_stride_h, in_stride_w, in_stride_c) = ( + input.stride(0), + input.stride(1), + input.stride(2), + input.stride(3), + input.stride(4), + ); + let (in_d, in_h, in_w) = (input.shape(1), input.shape(2), input.shape(3)); + + // Decode position: c, ow, oh, od, b + let c = (ABSOLUTE_POS % channel_lines) * input.line_size(); + let pos = ABSOLUTE_POS / channel_lines; + let ow = pos % out_w; + let pos = pos / out_w; + let oh = pos % out_h; + let pos = pos / out_h; + let od = pos % out_d; + let b = pos / out_d; + + let id_start = start_index(od, out_d, in_d); + let id_end = end_index(od, out_d, in_d); + + let ih_start = start_index(oh, out_h, in_h); + let ih_end = end_index(oh, out_h, in_h); + + let iw_start = start_index(ow, out_w, in_w); + let iw_end = end_index(ow, out_w, in_w); + + let mut sum = Line::empty(input.line_size()).fill(E::from_int(0)); + + let index_input_0 = b * in_stride_b; + let index_input_1 = c * in_stride_c; + + for id in id_start..id_end { + let index_input_2 = id * in_stride_d; + + for ih in ih_start..ih_end { + let index_input_3 = ih * in_stride_h; + + for iw in iw_start..iw_end { + let index_input_4 = iw * in_stride_w; + + let index_input = + index_input_0 + index_input_1 + index_input_2 + index_input_3 + index_input_4; + sum += input[index_input / input.line_size()]; + } + } + } + + let num_id = id_end - id_start; + let num_ih = ih_end - ih_start; + let num_iw = iw_end - iw_start; + + output[ABSOLUTE_POS] = sum / Line::cast_from(num_id * num_ih * num_iw); +} + +#[cube] +fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { + (output_size_index * input_size) / output_size +} + +#[cube] +fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { + let index = (output_size_index + 1) * input_size; + let index = index.div_ceil(output_size); + + if input_size < index { + input_size + } else { + index + } +} + +pub(crate) fn adaptive_avg_pool3d( + input: CubeTensor, + output_size: [usize; 3], +) -> CubeTensor { + let [batch_size, channels, _, _, _] = input.shape.dims(); + + let input = into_contiguous(permute_ncdhw_to_ndhwc(input)); + let line_size = max_line_size(&input); + + let output_shape = Shape::new([ + batch_size, + output_size[0], + output_size[1], + output_size[2], + channels, + ]); + let num_elems: usize = output_shape.num_elements(); + let output = empty_device_dtype( + input.client.clone(), + input.device.clone(), + output_shape, + input.dtype, + ); + + let working_units = num_elems / line_size as usize; + let cube_dim = CubeDim::new(&input.client, working_units); + let cube_count = calculate_cube_count_elemwise(&input.client, working_units, cube_dim); + + adaptive_avg_pool3d_direct::launch( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + output.dtype.into(), + ) + .expect("Kernel to never fail"); + + permute_ndhwc_to_ncdhw(output) +} diff --git 
a/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool3d_backward.rs b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool3d_backward.rs new file mode 100644 index 0000000000..fe61fa4e20 --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/adaptive_avg_pool3d_backward.rs @@ -0,0 +1,143 @@ +use crate::{ + CubeRuntime, + kernel::into_contiguous, + ops::{ + max_line_size, numeric::empty_device_dtype, permute_ncdhw_to_ndhwc, permute_ndhwc_to_ncdhw, + }, + tensor::CubeTensor, +}; +use burn_backend::Shape; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +#[cube(launch)] +fn adaptive_avg_pool3d_backward_direct( + grad: &Tensor>, + output: &mut Tensor>, + #[define(E)] _dtype: StorageType, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + // Output shape is [batch, depth, height, width, channels] in NDHWC format + let (out_d, out_h, out_w, channels) = ( + output.shape(1), + output.shape(2), + output.shape(3), + output.shape(4), + ); + let channel_lines = channels / grad.line_size(); + let (grad_stride_b, grad_stride_d, grad_stride_h, grad_stride_w, grad_stride_c) = ( + grad.stride(0), + grad.stride(1), + grad.stride(2), + grad.stride(3), + grad.stride(4), + ); + let (grad_d, grad_h, grad_w) = (grad.shape(1), grad.shape(2), grad.shape(3)); + + // Decode position: c, iw, ih, id, b + let c = (ABSOLUTE_POS % channel_lines) * grad.line_size(); + let pos = ABSOLUTE_POS / channel_lines; + let iw = pos % out_w; + let pos = pos / out_w; + let ih = pos % out_h; + let pos = pos / out_h; + let id = pos % out_d; + let b = pos / out_d; + + let od_start = start_index(id, out_d, grad_d); + let od_end = end_index(id, out_d, grad_d); + + let oh_start = start_index(ih, out_h, grad_h); + let oh_end = end_index(ih, out_h, grad_h); + + let ow_start = start_index(iw, out_w, grad_w); + let ow_end = end_index(iw, out_w, grad_w); + + let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0)); + + let index_base = b * grad_stride_b + (c * grad_stride_c); + + for od in od_start..od_end { + let id_start = start_index(od, grad_d, out_d); + let id_end = end_index(od, grad_d, out_d); + + if id >= id_start && id < id_end { + for oh in oh_start..oh_end { + let ih_start = start_index(oh, grad_h, out_h); + let ih_end = end_index(oh, grad_h, out_h); + + if ih >= ih_start && ih < ih_end { + for ow in ow_start..ow_end { + let iw_start = start_index(ow, grad_w, out_w); + let iw_end = end_index(ow, grad_w, out_w); + + if iw >= iw_start && iw < iw_end { + let num_id = id_end - id_start; + let num_ih = ih_end - ih_start; + let num_iw = iw_end - iw_start; + + let index = index_base + + (od * grad_stride_d) + + (oh * grad_stride_h) + + (ow * grad_stride_w); + grad_acc += grad[index / grad.line_size()] + / Line::cast_from(num_id * num_ih * num_iw); + } + } + } + } + } + } + + output[ABSOLUTE_POS] = grad_acc; +} + +#[cube] +fn start_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { + (output_size_index * input_size) / output_size +} + +#[cube] +fn end_index(output_size_index: usize, output_size: usize, input_size: usize) -> usize { + let index = (output_size_index + 1) * input_size; + let index = index.div_ceil(output_size); + + if input_size < index { + input_size + } else { + index + } +} + +pub(crate) fn adaptive_avg_pool3d_backward( + x: CubeTensor, + out_grad: CubeTensor, +) -> CubeTensor { + let [batches, channels, depth, height, width] = x.shape.dims(); + + let out_grad = into_contiguous(permute_ncdhw_to_ndhwc(out_grad)); + let line_size = max_line_size(&out_grad); 
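+
+    // Each input position accumulates the gradient of every output window that covered it,
+    // scaled by that window's size; the result is built in NDHWC and permuted back to NCDHW.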
+ + let out_shape = Shape::new([batches, depth, height, width, channels]); + let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); + + let num_elems = output.shape.num_elements(); + + let working_units = num_elems / line_size as usize; + let cube_dim = CubeDim::new(&x.client, working_units); + let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); + + adaptive_avg_pool3d_backward_direct::launch( + &x.client, + cube_count, + cube_dim, + out_grad.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + output.dtype.into(), + ) + .expect("Kernel to never fail"); + + permute_ndhwc_to_ncdhw(output) +} diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool3d.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool3d.rs new file mode 100644 index 0000000000..0c58adb3f9 --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/avg_pool3d.rs @@ -0,0 +1,185 @@ +use super::pool3d::{ + Pool3dDirectArgsLaunch, Pool3dDirectStrategy, Pool3dDirectStrategyFamily, pool3d_direct, +}; +use crate::{ + CubeRuntime, + kernel::into_contiguous, + ops::{ + max_line_size, numeric::empty_device_dtype, permute_ncdhw_to_ndhwc, permute_ndhwc_to_ncdhw, + }, + tensor::CubeTensor, +}; +use burn_backend::{Shape, ops::conv::calculate_pool_output_size}; +use cubecl::prelude::*; +use cubecl::{CubeDim, calculate_cube_count_elemwise, prelude::ScalarArg}; + +struct AvgPool3dStrategy; + +impl Pool3dDirectStrategyFamily for AvgPool3dStrategy { + type Indices = (); + type Config = AvgPool3dStrategyConfig; + type Pool3d = Self; +} + +#[derive(CubeType, Debug, PartialEq, Eq, Hash, Clone, Copy)] +pub struct AvgPool3dStrategyConfig { + count_include_pad: bool, + /// Total padded depth (input_depth + 2 * padding_0) + padded_d: u32, + /// Total padded height (input_height + 2 * padding_1) + padded_h: u32, + /// Total padded width (input_width + 2 * padding_2) + padded_w: u32, +} + +#[cube] +impl Pool3dDirectStrategy for AvgPool3dStrategy { + type Accumulator = (Line, u32); + type Config = AvgPool3dStrategyConfig; + type Indices = (); + + fn initialize( + #[comptime] _config: &Self::Config, + #[comptime] line_size: LineSize, + ) -> Self::Accumulator { + let sum = Line::empty(line_size).fill(N::from_int(0)); + // Count will be set dynamically: either by accumulate (count_include_pad=false) + // or by count_position (count_include_pad=true) + let count = 0u32; + + (sum, count) + } + + fn accumulate( + #[comptime] config: &Self::Config, + accumulator: &mut Self::Accumulator, + _index: usize, + result: Line, + ) { + let (sum, count) = accumulator; + + // Only count valid positions when count_include_pad=false + if comptime![!config.count_include_pad] { + *count += 1; + } + + *sum += result; + } + + fn count_position( + #[comptime] config: &Self::Config, + accumulator: &mut Self::Accumulator, + id: u32, + ih: u32, + iw: u32, + ) { + // When count_include_pad=true, count positions within padded bounds + // (excludes ceil_mode extensions beyond the padded input) + if comptime![config.count_include_pad] + && id < config.padded_d + && ih < config.padded_h + && iw < config.padded_w + { + let (_sum, count) = accumulator; + *count += 1; + } + } + + fn store( + #[comptime] _config: &Self::Config, + position: usize, + output: &mut Tensor>, + _output_indices: &mut (), + accumulator: Self::Accumulator, + ) { + let (sum, count) = accumulator; + output[position] = sum / Line::cast_from(count); + } +} + +pub(crate) fn avg_pool3d( + x: CubeTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + 
padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, +) -> CubeTensor { + let [batch_size, channels, in_d, in_h, in_w] = x.shape.dims(); + let dilation = 1; + + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation, + in_d, + ceil_mode, + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation, + in_h, + ceil_mode, + ); + let size_2 = calculate_pool_output_size( + kernel_size[2], + stride[2], + padding[2], + dilation, + in_w, + ceil_mode, + ); + + // Padded dimensions (for count_include_pad with ceil_mode) + let padded_0 = in_d + 2 * padding[0]; + let padded_1 = in_h + 2 * padding[1]; + let padded_2 = in_w + 2 * padding[2]; + + let x = into_contiguous(permute_ncdhw_to_ndhwc(x)); + let line_size = max_line_size(&x); + + let shape_out = Shape::new([batch_size, size_0, size_1, size_2, channels]); + let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype); + + let working_units = output.shape.num_elements() / line_size as usize; + let cube_dim = CubeDim::new(&x.client, working_units); + let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); + + pool3d_direct::launch::( + &x.client, + cube_count, + cube_dim, + x.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + (), + Pool3dDirectArgsLaunch::new( + ScalarArg::new(stride[0] as u32), + ScalarArg::new(stride[1] as u32), + ScalarArg::new(stride[2] as u32), + ScalarArg::new(dilation as u32), + ScalarArg::new(dilation as u32), + ScalarArg::new(dilation as u32), + ScalarArg::new(padding[0] as u32), + ScalarArg::new(padding[1] as u32), + ScalarArg::new(padding[2] as u32), + ), + ( + kernel_size[0] as u32, + kernel_size[1] as u32, + kernel_size[2] as u32, + ), + AvgPool3dStrategyConfig { + count_include_pad, + padded_d: padded_0 as u32, + padded_h: padded_1 as u32, + padded_w: padded_2 as u32, + }, + output.dtype.into(), + ) + .expect("Kernel to never fail"); + + permute_ndhwc_to_ncdhw(output) +} diff --git a/crates/burn-cubecl/src/kernel/pool/avg_pool3d_backward.rs b/crates/burn-cubecl/src/kernel/pool/avg_pool3d_backward.rs new file mode 100644 index 0000000000..2c591c68af --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/avg_pool3d_backward.rs @@ -0,0 +1,204 @@ +use crate::{ + CubeRuntime, + ops::{ + max_line_size, numeric::empty_device_dtype, permute_ncdhw_to_ndhwc, permute_ndhwc_to_ncdhw, + }, + tensor::CubeTensor, +}; +use burn_backend::Shape; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +use super::max_pool3d_backward::{Pool3dBackwardArgs, Pool3dBackwardArgsLaunch}; + +#[cube(launch_unchecked)] +fn avg_pool3d_backward_kernel( + grad: &Tensor>, + output: &mut Tensor>, + args: &Pool3dBackwardArgs, + #[comptime] kernel_size_0: i32, + #[comptime] kernel_size_1: i32, + #[comptime] kernel_size_2: i32, + #[comptime] count_include_pad: bool, + #[define(E)] _dtype: StorageType, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + let line_size = grad.line_size(); + + // Output shape is [batch, depth, height, width, channels] in NDHWC format + let channel_lines = output.shape(4) / line_size; + let channel = (ABSOLUTE_POS % channel_lines) * output.line_size(); + let pos = ABSOLUTE_POS / channel_lines; + let iw = pos as u32 % output.shape(3) as u32; + let pos = pos / output.shape(3); + let ih = pos as u32 % output.shape(2) as u32; + let pos = pos / output.shape(2); + let id = pos as u32 % output.shape(1) as u32; + let batch = pos / output.shape(1); + + let 
mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0)); + + let (od_start, od_end, oh_start, oh_end, ow_start, ow_end) = loop_ranges( + id as i32, + ih as i32, + iw as i32, + grad.shape(1) as u32, + grad.shape(2) as u32, + grad.shape(3) as u32, + args, + kernel_size_0, + kernel_size_1, + kernel_size_2, + ); + + let padding_0 = args.padding_0 as u32; + let padding_1 = args.padding_1 as u32; + let padding_2 = args.padding_2 as u32; + let stride_0 = args.stride_0 as u32; + let stride_1 = args.stride_1 as u32; + let stride_2 = args.stride_2 as u32; + let kernel_size_0 = comptime![kernel_size_0 as u32]; + let kernel_size_1 = comptime![kernel_size_1 as u32]; + let kernel_size_2 = comptime![kernel_size_2 as u32]; + + let index_base = batch * grad.stride(0) + channel * grad.stride(4); + let border_back = output.shape(1) as u32 + padding_0; + let border_bottom = output.shape(2) as u32 + padding_1; + let border_right = output.shape(3) as u32 + padding_2; + let begin_d = id + padding_0; + let begin_h = ih + padding_1; + let begin_w = iw + padding_2; + + for od in od_start..od_end { + let id_start = od * stride_0; + let id_end = Min::min(id_start + kernel_size_0, border_back); + let id_start = Max::max(id_start, padding_0); + + if begin_d >= id_start && id < id_end { + for oh in oh_start..oh_end { + let ih_start = oh * stride_1; + let ih_end = Min::min(ih_start + kernel_size_1, border_bottom); + let ih_start = Max::max(ih_start, padding_1); + + if begin_h >= ih_start && ih < ih_end { + for ow in ow_start..ow_end { + let index = index_base + + od as usize * grad.stride(1) + + oh as usize * grad.stride(2) + + ow as usize * grad.stride(3); + + let iw_start = ow * stride_2; + let iw_end = Min::min(iw_start + kernel_size_2, border_right); + let iw_start = Max::max(iw_start, padding_2); + + if begin_w >= iw_start && iw < iw_end { + if count_include_pad { + grad_acc += grad[index / line_size] + / Line::cast_from( + kernel_size_0 * kernel_size_1 * kernel_size_2, + ); + } else { + let id_diff = id_end - id_start; + let ih_diff = ih_end - ih_start; + let iw_diff = iw_end - iw_start; + let count = Line::cast_from(id_diff * ih_diff * iw_diff); + grad_acc += grad[index / line_size] / count; + } + } + } + } + } + } + } + + output[ABSOLUTE_POS] = grad_acc; +} + +#[cube] +#[allow(clippy::too_many_arguments)] +fn loop_ranges( + id: i32, + ih: i32, + iw: i32, + grad_d: u32, + grad_h: u32, + grad_w: u32, + args: &Pool3dBackwardArgs, + #[comptime] kernel_size_0: i32, + #[comptime] kernel_size_1: i32, + #[comptime] kernel_size_2: i32, +) -> (u32, u32, u32, u32, u32, u32) { + let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0; + let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1; + let kms_2 = args.dilation_2 * kernel_size_2 - args.stride_2; + + let od_start = Max::max((id + args.padding_0 - kms_0) / args.stride_0, 0) as u32; + let oh_start = Max::max((ih + args.padding_1 - kms_1) / args.stride_1, 0) as u32; + let ow_start = Max::max((iw + args.padding_2 - kms_2) / args.stride_2, 0) as u32; + + let od_end = Min::min(Max::max(kms_0, 0) as u32 + od_start, grad_d - 1) + 1; + let oh_end = Min::min(Max::max(kms_1, 0) as u32 + oh_start, grad_h - 1) + 1; + let ow_end = Min::min(Max::max(kms_2, 0) as u32 + ow_start, grad_w - 1) + 1; + + (od_start, od_end, oh_start, oh_end, ow_start, ow_end) +} + +pub(crate) fn avg_pool3d_backward( + x: CubeTensor, + grad: CubeTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + _ceil_mode: bool, +) -> CubeTensor { 
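+    // The gradient w.r.t. `x` has the same shape as `x`: every input voxel sums the gradient
+    // of each pooling window that covered it, divided by the full kernel volume when
+    // `count_include_pad` is set, or by the clipped window size otherwise.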
+ let [batches, channels, depth, height, width] = x.shape.dims(); + + let grad = permute_ncdhw_to_ndhwc(grad); + + let line_size = if x.strides[4] == grad.strides[4] { + max_line_size(&x) + } else { + 1 + }; + + let dilation = 1; + + let out_shape = Shape::new([batches, depth, height, width, channels]); + let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); + + let working_units = output.shape.num_elements() / line_size as usize; + let cube_dim = CubeDim::new(&x.client, working_units); + let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); + + unsafe { + avg_pool3d_backward_kernel::launch_unchecked( + &grad.client, + cube_count, + cube_dim, + grad.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + Pool3dBackwardArgsLaunch::new( + ScalarArg::new(stride[0] as i32), + ScalarArg::new(stride[1] as i32), + ScalarArg::new(stride[2] as i32), + ScalarArg::new(dilation), + ScalarArg::new(dilation), + ScalarArg::new(dilation), + ScalarArg::new(padding[0] as i32), + ScalarArg::new(padding[1] as i32), + ScalarArg::new(padding[2] as i32), + ), + kernel_size[0] as i32, + kernel_size[1] as i32, + kernel_size[2] as i32, + count_include_pad, + output.dtype.into(), + ) + } + .expect("Kernel to never fail"); + + permute_ndhwc_to_ncdhw(output) +} diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool3d.rs b/crates/burn-cubecl/src/kernel/pool/max_pool3d.rs new file mode 100644 index 0000000000..a6e1eba942 --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/max_pool3d.rs @@ -0,0 +1,279 @@ +use super::pool3d::{ + Pool3dDirectArgsLaunch, Pool3dDirectStrategy, Pool3dDirectStrategyFamily, pool3d_direct, +}; +use crate::{ + CubeRuntime, + kernel::into_contiguous, + ops::{ + max_line_size, numeric::empty_device_dtype, permute_ncdhw_to_ndhwc, permute_ndhwc_to_ncdhw, + }, + tensor::CubeTensor, +}; +use burn_backend::{DType, Shape, ops::conv::calculate_pool_output_size}; +use cubecl::{CubeDim, calculate_cube_count_elemwise, prelude::*}; + +struct MaxPool3dStrategy; +struct MaxPool3dWithIndicesStrategy; + +impl Pool3dDirectStrategyFamily for MaxPool3dStrategy { + type Indices = (); + type Config = (); + type Pool3d = Self; +} + +impl Pool3dDirectStrategyFamily for MaxPool3dWithIndicesStrategy { + type Indices = Tensor>; + type Config = (); + type Pool3d = Self; +} + +#[cube] +impl Pool3dDirectStrategy for MaxPool3dStrategy { + type Accumulator = Line; + type Config = (); + type Indices = (); + + fn initialize( + #[comptime] _config: &Self::Config, + #[comptime] line_size: LineSize, + ) -> Self::Accumulator { + Line::empty(line_size).fill(N::min_value()) + } + + fn accumulate( + #[comptime] _config: &Self::Config, + accumulator: &mut Self::Accumulator, + _index: LineSize, + result: Line, + ) { + *accumulator = Max::max(*accumulator, result); + } + + fn count_position( + #[comptime] _config: &Self::Config, + _accumulator: &mut Self::Accumulator, + _id: u32, + _ih: u32, + _iw: u32, + ) { + } + + fn store( + #[comptime] _config: &Self::Config, + position: usize, + output: &mut Tensor>, + _output_indices: &mut (), + accumulator: Self::Accumulator, + ) { + output[position] = accumulator; + } +} + +#[cube] +impl Pool3dDirectStrategy for MaxPool3dWithIndicesStrategy { + type Accumulator = (Line, Line); + type Config = (); + type Indices = Tensor>; + + fn initialize( + #[comptime] _config: &Self::Config, + #[comptime] line_size: LineSize, + ) -> Self::Accumulator { + let val = Line::empty(line_size).fill(N::min_value()); + let idx = 
Line::empty(line_size).fill(0i32); + (val, idx) + } + + fn accumulate( + #[comptime] _config: &Self::Config, + accumulator: &mut Self::Accumulator, + index: usize, + result: Line, + ) { + let indices = Line::cast_from(index); + accumulator.1 = select_many(result.greater_than(accumulator.0), indices, accumulator.1); + accumulator.0 = Max::max(result, accumulator.0); + } + + fn count_position( + #[comptime] _config: &Self::Config, + _accumulator: &mut Self::Accumulator, + _id: u32, + _ih: u32, + _iw: u32, + ) { + } + + fn store( + #[comptime] _config: &Self::Config, + position: usize, + output: &mut Tensor>, + output_indices: &mut Tensor>, + accumulator: Self::Accumulator, + ) { + output[position] = accumulator.0; + output_indices[position] = accumulator.1; + } +} + +pub(crate) fn max_pool3d( + x: CubeTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, +) -> CubeTensor { + let [batch_size, channels, _, _, _] = x.shape.dims(); + + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ceil_mode, + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], + ceil_mode, + ); + let size_2 = calculate_pool_output_size( + kernel_size[2], + stride[2], + padding[2], + dilation[2], + x.shape[4], + ceil_mode, + ); + + let x = into_contiguous(permute_ncdhw_to_ndhwc(x)); + + let line_size = max_line_size(&x); + + let shape_out = Shape::new([batch_size, size_0, size_1, size_2, channels]); + let output = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, x.dtype); + + let working_units = output.shape.num_elements() / line_size as usize; + let cube_dim = CubeDim::new(&x.client, working_units); + let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); + + pool3d_direct::launch::( + &x.client, + cube_count, + cube_dim, + x.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + (), + Pool3dDirectArgsLaunch::new( + ScalarArg::new(stride[0] as u32), + ScalarArg::new(stride[1] as u32), + ScalarArg::new(stride[2] as u32), + ScalarArg::new(dilation[0] as u32), + ScalarArg::new(dilation[1] as u32), + ScalarArg::new(dilation[2] as u32), + ScalarArg::new(padding[0] as u32), + ScalarArg::new(padding[1] as u32), + ScalarArg::new(padding[2] as u32), + ), + ( + kernel_size[0] as u32, + kernel_size[1] as u32, + kernel_size[2] as u32, + ), + (), + output.dtype.into(), + ) + .expect("Kernel to never fail"); + + permute_ndhwc_to_ncdhw(output) +} + +pub(crate) fn max_pool3d_with_indices( + x: CubeTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + dtype_indices: DType, +) -> (CubeTensor, CubeTensor) { + let [batch_size, channels, _, _, _] = x.shape.dims(); + + let size_0 = calculate_pool_output_size( + kernel_size[0], + stride[0], + padding[0], + dilation[0], + x.shape[2], + ceil_mode, + ); + let size_1 = calculate_pool_output_size( + kernel_size[1], + stride[1], + padding[1], + dilation[1], + x.shape[3], + ceil_mode, + ); + let size_2 = calculate_pool_output_size( + kernel_size[2], + stride[2], + padding[2], + dilation[2], + x.shape[4], + ceil_mode, + ); + + let x = into_contiguous(permute_ncdhw_to_ndhwc(x)); + let line_size = max_line_size(&x); + + let shape_out = Shape::new([batch_size, size_0, size_1, size_2, channels]); + let output = empty_device_dtype( + x.client.clone(), + x.device.clone(), + 
shape_out.clone(), + x.dtype, + ); + let indices = empty_device_dtype(x.client.clone(), x.device.clone(), shape_out, dtype_indices); + + let working_units = output.shape.num_elements() / line_size as usize; + let cube_dim = CubeDim::new(&x.client, working_units); + let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); + + pool3d_direct::launch::( + &x.client, + cube_count, + cube_dim, + x.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + indices.as_tensor_arg(line_size), + Pool3dDirectArgsLaunch::new( + ScalarArg::new(stride[0] as u32), + ScalarArg::new(stride[1] as u32), + ScalarArg::new(stride[2] as u32), + ScalarArg::new(dilation[0] as u32), + ScalarArg::new(dilation[1] as u32), + ScalarArg::new(dilation[2] as u32), + ScalarArg::new(padding[0] as u32), + ScalarArg::new(padding[1] as u32), + ScalarArg::new(padding[2] as u32), + ), + ( + kernel_size[0] as u32, + kernel_size[1] as u32, + kernel_size[2] as u32, + ), + (), + output.dtype.into(), + ) + .expect("Kernel to never fail"); + + let output = permute_ndhwc_to_ncdhw(output); + let indices = permute_ndhwc_to_ncdhw(indices); + (output, indices) +} diff --git a/crates/burn-cubecl/src/kernel/pool/max_pool3d_backward.rs b/crates/burn-cubecl/src/kernel/pool/max_pool3d_backward.rs new file mode 100644 index 0000000000..82d8352cf3 --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/max_pool3d_backward.rs @@ -0,0 +1,179 @@ +use crate::{ + CubeRuntime, + kernel::into_contiguous, + ops::{ + max_line_size, numeric::empty_device_dtype, permute_ncdhw_to_ndhwc, permute_ndhwc_to_ncdhw, + }, + tensor::CubeTensor, +}; +use burn_backend::Shape; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +#[derive(CubeLaunch, CubeType)] +pub(crate) struct Pool3dBackwardArgs { + pub stride_0: i32, + pub stride_1: i32, + pub stride_2: i32, + pub dilation_0: i32, + pub dilation_1: i32, + pub dilation_2: i32, + pub padding_0: i32, + pub padding_1: i32, + pub padding_2: i32, +} + +#[cube(launch_unchecked)] +fn max_pool3d_with_indices_backward_kernel( + grad: &Tensor>, + indices: &Tensor>, + output: &mut Tensor>, + args: &Pool3dBackwardArgs, + #[comptime] kernel_size_0: i32, + #[comptime] kernel_size_1: i32, + #[comptime] kernel_size_2: i32, + #[define(E, I)] _dtypes: [StorageType; 2], +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + let line_size = grad.line_size(); + + // Output shape is [batch, depth, height, width, channels] in NDHWC format + let channels = output.shape(4) / line_size; + let channel = (ABSOLUTE_POS % channels) * output.line_size(); + let pos = ABSOLUTE_POS / channels; + let iw = pos % output.shape(3); + let pos = pos / output.shape(3); + let ih = pos % output.shape(2); + let pos = pos / output.shape(2); + let id = pos % output.shape(1); + let batch = pos / output.shape(1); + + let index_current = id * output.shape(2) * output.shape(3) + ih * output.shape(3) + iw; + + let (od_start, od_end, oh_start, oh_end, ow_start, ow_end) = loop_ranges( + id as i32, + ih as i32, + iw as i32, + grad.shape(1) as u32, + grad.shape(2) as u32, + grad.shape(3) as u32, + args, + kernel_size_0, + kernel_size_1, + kernel_size_2, + ); + + let mut grad_acc = Line::empty(grad.line_size()).fill(E::from_int(0)); + + let index_base = batch * grad.stride(0) + channel * grad.stride(4); + + for od in od_start..od_end { + for oh in oh_start..oh_end { + for ow in ow_start..ow_end { + let index = index_base + + od as usize * grad.stride(1) + + oh as usize * grad.stride(2) + + ow as usize * grad.stride(3); + let 
index_max = Line::::cast_from(indices[index / line_size]); + + grad_acc += select_many( + index_max.equal(Line::cast_from(index_current)), + grad[index / line_size], + Line::new(E::from_int(0)), + ); + } + } + } + + output[ABSOLUTE_POS] = grad_acc; +} + +#[cube] +#[allow(clippy::too_many_arguments)] +fn loop_ranges( + id: i32, + ih: i32, + iw: i32, + grad_d: u32, + grad_h: u32, + grad_w: u32, + args: &Pool3dBackwardArgs, + #[comptime] kernel_size_0: i32, + #[comptime] kernel_size_1: i32, + #[comptime] kernel_size_2: i32, +) -> (u32, u32, u32, u32, u32, u32) { + let kms_0 = args.dilation_0 * kernel_size_0 - args.stride_0; + let kms_1 = args.dilation_1 * kernel_size_1 - args.stride_1; + let kms_2 = args.dilation_2 * kernel_size_2 - args.stride_2; + + let od_start = Max::max((id + args.padding_0 - kms_0) / args.stride_0, 0) as u32; + let oh_start = Max::max((ih + args.padding_1 - kms_1) / args.stride_1, 0) as u32; + let ow_start = Max::max((iw + args.padding_2 - kms_2) / args.stride_2, 0) as u32; + + let od_end = Min::min(Max::max(kms_0, 0) as u32 + od_start, grad_d - 1) + 1; + let oh_end = Min::min(Max::max(kms_1, 0) as u32 + oh_start, grad_h - 1) + 1; + let ow_end = Min::min(Max::max(kms_2, 0) as u32 + ow_start, grad_w - 1) + 1; + + (od_start, od_end, oh_start, oh_end, ow_start, ow_end) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn max_pool3d_with_indices_backward( + x: CubeTensor, + grad: CubeTensor, + indices: CubeTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + _ceil_mode: bool, +) -> CubeTensor { + let [batches, channels, depth, height, width] = x.shape.dims(); + + let grad = into_contiguous(permute_ncdhw_to_ndhwc(grad)); + let indices = into_contiguous(permute_ncdhw_to_ndhwc(indices)); + + let line_size = if grad.strides[4] == indices.strides[4] { + max_line_size(&grad) + } else { + 1 + }; + + let out_shape = Shape::new([batches, depth, height, width, channels]); + let output = empty_device_dtype(x.client.clone(), x.device.clone(), out_shape, x.dtype); + + let working_units = output.shape.num_elements() / line_size as usize; + let cube_dim = CubeDim::new(&x.client, working_units); + let cube_count = calculate_cube_count_elemwise(&x.client, working_units, cube_dim); + + unsafe { + max_pool3d_with_indices_backward_kernel::launch_unchecked( + &x.client, + cube_count, + cube_dim, + grad.as_tensor_arg(line_size), + indices.as_tensor_arg(line_size), + output.as_tensor_arg(line_size), + Pool3dBackwardArgsLaunch::new( + ScalarArg::new(stride[0] as i32), + ScalarArg::new(stride[1] as i32), + ScalarArg::new(stride[2] as i32), + ScalarArg::new(dilation[0] as i32), + ScalarArg::new(dilation[1] as i32), + ScalarArg::new(dilation[2] as i32), + ScalarArg::new(padding[0] as i32), + ScalarArg::new(padding[1] as i32), + ScalarArg::new(padding[2] as i32), + ), + kernel_size[0] as i32, + kernel_size[1] as i32, + kernel_size[2] as i32, + [x.dtype.into(), indices.dtype.into()], + ) + .expect("Kernel to never fail") + }; + + permute_ndhwc_to_ncdhw(output) +} diff --git a/crates/burn-cubecl/src/kernel/pool/mod.rs b/crates/burn-cubecl/src/kernel/pool/mod.rs index 73b42490ac..abb001d498 100644 --- a/crates/burn-cubecl/src/kernel/pool/mod.rs +++ b/crates/burn-cubecl/src/kernel/pool/mod.rs @@ -1,15 +1,28 @@ mod adaptive_avg_pool2d; mod adaptive_avg_pool2d_backward; +mod adaptive_avg_pool3d; +mod adaptive_avg_pool3d_backward; mod avg_pool2d; mod avg_pool2d_backward; +mod avg_pool3d; +mod avg_pool3d_backward; mod max_pool2d; mod 
max_pool2d_backward; +mod max_pool3d; +mod max_pool3d_backward; pub(super) mod pool2d; +pub(super) mod pool3d; pub(crate) use adaptive_avg_pool2d::*; pub(crate) use adaptive_avg_pool2d_backward::*; +pub(crate) use adaptive_avg_pool3d::*; +pub(crate) use adaptive_avg_pool3d_backward::*; pub(crate) use avg_pool2d::*; pub(crate) use avg_pool2d_backward::*; +pub(crate) use avg_pool3d::*; +pub(crate) use avg_pool3d_backward::*; pub(crate) use max_pool2d::*; pub(crate) use max_pool2d_backward::*; +pub(crate) use max_pool3d::*; +pub(crate) use max_pool3d_backward::*; diff --git a/crates/burn-cubecl/src/kernel/pool/pool3d.rs b/crates/burn-cubecl/src/kernel/pool/pool3d.rs new file mode 100644 index 0000000000..71879ccbad --- /dev/null +++ b/crates/burn-cubecl/src/kernel/pool/pool3d.rs @@ -0,0 +1,157 @@ +use core::hash::Hash; +use cubecl::prelude::*; + +pub trait Pool3dDirectStrategyFamily: Send + Sync + 'static { + type Indices: LaunchArg; + type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq; + type Pool3d: Pool3dDirectStrategy; +} + +#[cube] +pub(crate) trait Pool3dDirectStrategy: Send + Sync + 'static { + type Accumulator: CubeType; + type Config: CubeType + Clone + Send + Sync + core::fmt::Debug + Hash + core::cmp::Eq; + + type Indices: LaunchArg; + + fn initialize( + #[comptime] config: &Self::Config, + #[comptime] line_size: LineSize, + ) -> Self::Accumulator; + + fn accumulate( + #[comptime] config: &Self::Config, + accumulator: &mut Self::Accumulator, + index: usize, + result: Line, + ); + + /// Count a position within the kernel window (for avg_pool count_include_pad). + /// Called for each position in the kernel window with the current id/ih/iw coordinates. + /// Only avg_pool uses this; max_pool implements as no-op. 
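+    /// Keeping the count separate from `accumulate` lets `count_include_pad = true` include
+    /// padded positions in the divisor even though they contribute nothing to the sum.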
+ fn count_position( + #[comptime] config: &Self::Config, + accumulator: &mut Self::Accumulator, + id: u32, + ih: u32, + iw: u32, + ); + + fn store( + #[comptime] config: &Self::Config, + position: usize, + output: &mut Tensor>, + output_indices: &mut Self::Indices, + accumulator: Self::Accumulator, + ); +} + +#[derive(CubeLaunch, CubeType)] +pub struct Pool3dDirectArgs { + pub strides_0: u32, + pub strides_1: u32, + pub strides_2: u32, + pub dilation_0: u32, + pub dilation_1: u32, + pub dilation_2: u32, + pub padding_0: u32, + pub padding_1: u32, + pub padding_2: u32, +} + +#[cube(launch)] +pub fn pool3d_direct( + input: &Tensor>, + output: &mut Tensor>, + indices: &mut S::Indices, + args: &Pool3dDirectArgs, + #[comptime] kernel_size: (u32, u32, u32), + #[comptime] config: &S::Config, + #[define(E)] _dtype: StorageType, +) { + if ABSOLUTE_POS >= output.len() { + terminate!(); + } + + // Output shape is [batch, out_d, out_h, out_w, channels] in NDHWC format + let (out_d, out_h, out_w, channels) = ( + output.shape(1), + output.shape(2), + output.shape(3), + output.shape(4), + ); + let channel_lines = channels / input.line_size(); + let (in_stride_b, in_stride_d, in_stride_h, in_stride_w, in_stride_c) = ( + input.stride(0), + input.stride(1), + input.stride(2), + input.stride(3), + input.stride(4), + ); + let (in_d, in_h, in_w) = ( + input.shape(1) as u32, + input.shape(2) as u32, + input.shape(3) as u32, + ); + + // Decode position: c, ow, oh, od, b + let c = (ABSOLUTE_POS % channel_lines) * input.line_size(); + let pos = ABSOLUTE_POS / channel_lines; + let ow = pos as u32 % out_w as u32; + let pos = pos / out_w; + let oh = pos as u32 % out_h as u32; + let pos = pos / out_h; + let od = pos as u32 % out_d as u32; + let b = pos / out_d; + + let mut accumulator = S::Pool3d::::initialize(config, input.line_size()); + + let in_b_off = b * in_stride_b; + let in_c_off = c * in_stride_c; + + let border_back = in_d + args.padding_0; + let border_bottom = in_h + args.padding_1; + let border_right = in_w + args.padding_2; + + for kd in 0..kernel_size.0 { + let id = od * args.strides_0 + kd * args.dilation_0; + let within_padding_d = id >= args.padding_0 && id < border_back; + + for kh in 0..kernel_size.1 { + let ih = oh * args.strides_1 + kh * args.dilation_1; + let within_padding_h = ih >= args.padding_1 && ih < border_bottom; + + for kw in 0..kernel_size.2 { + let iw = ow * args.strides_2 + kw * args.dilation_2; + let within_padding_w = iw >= args.padding_2 && iw < border_right; + + // Let strategy handle position counting (only used by avg_pool) + S::Pool3d::::count_position(config, &mut accumulator, id, ih, iw); + + // Only accumulate values from valid input positions + if within_padding_d && within_padding_h && within_padding_w { + let id_pad = id - args.padding_0; + let ih_pad = ih - args.padding_1; + let iw_pad = iw - args.padding_2; + + let in_d_off = id_pad as usize * in_stride_d; + let in_h_off = ih_pad as usize * in_stride_h; + let in_w_off = iw_pad as usize * in_stride_w; + + let index_input = in_b_off + in_c_off + in_d_off + in_h_off + in_w_off; + + S::Pool3d::::accumulate( + config, + &mut accumulator, + id_pad as usize * in_h as usize * in_w as usize + + ih_pad as usize * in_w as usize + + iw_pad as usize, + input[index_input / input.line_size()], + ); + } + } + } + } + + S::Pool3d::::store(config, ABSOLUTE_POS, output, indices, accumulator); +} diff --git a/crates/burn-cubecl/src/ops/base.rs b/crates/burn-cubecl/src/ops/base.rs index ebddccba86..c539f43716 100644 --- 
a/crates/burn-cubecl/src/ops/base.rs +++ b/crates/burn-cubecl/src/ops/base.rs @@ -186,6 +186,22 @@ pub fn permute_nhwc_to_nchw_shape(shape: Shape) -> Shape { shape.permute(&dims).expect("Shape permute should succeed") } +/// Convenience wrapper to permute a 5D tensor's dimensions from NCDHW to NDHWC. +/// Internally this delegates to [`permute_nchw_to_nhwc`], which handles the +/// corresponding N-dimensional permutation pattern. +pub fn permute_ncdhw_to_ndhwc(tensor: CubeTensor) -> CubeTensor { + // This is the same as permute_nchw_to_nhwc but more explicit for 5D + permute_nchw_to_nhwc(tensor) +} + +/// Convenience wrapper to permute a 5D tensor's dimensions from NDHWC to NCDHW +/// Internally this delegates to [`permute_nhwc_to_nchw`], which handles the corresponding +/// N-dimensional permutation pattern and supports arbitrary ranks, including 5D. +pub fn permute_ndhwc_to_ncdhw(tensor: CubeTensor) -> CubeTensor { + // This is the same as permute_nhwc_to_nchw but more explicit for 5D + permute_nhwc_to_nchw(tensor) +} + pub(crate) fn expand(tensor: CubeTensor, target_shape: Shape) -> CubeTensor { let ndims_in = tensor.shape.num_dims(); let ndims_out = target_shape.num_dims(); diff --git a/crates/burn-cubecl/src/ops/module.rs b/crates/burn-cubecl/src/ops/module.rs index fb1094a2ca..d44d4b7d32 100644 --- a/crates/burn-cubecl/src/ops/module.rs +++ b/crates/burn-cubecl/src/ops/module.rs @@ -5,7 +5,7 @@ use crate::{ }; use burn_backend::ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, - MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + MaxPool2dBackward, MaxPool2dWithIndices, MaxPool3dBackward, MaxPool3dWithIndices, ModuleOps, }; use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor}; @@ -276,6 +276,109 @@ where kernel::pool::adaptive_avg_pool2d_backward(x, grad) } + fn avg_pool3d( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + ) -> FloatTensor { + kernel::pool::avg_pool3d( + x, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ) + } + + fn avg_pool3d_backward( + x: FloatTensor, + grad: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + ) -> FloatTensor { + kernel::pool::avg_pool3d_backward( + x, + grad, + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ) + } + + fn max_pool3d( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> FloatTensor { + kernel::pool::max_pool3d(x, kernel_size, stride, padding, dilation, ceil_mode) + } + + fn max_pool3d_with_indices( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> MaxPool3dWithIndices { + let (output, indices) = kernel::pool::max_pool3d_with_indices( + x, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + I::dtype(), + ); + + MaxPool3dWithIndices::new(output, indices) + } + + fn max_pool3d_with_indices_backward( + x: FloatTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + output_grad: FloatTensor, + indices: IntTensor, + ) -> MaxPool3dBackward { + MaxPool3dBackward::new(kernel::pool::max_pool3d_with_indices_backward( + x, + output_grad, + indices, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + )) + 
} + + fn adaptive_avg_pool3d(x: FloatTensor, output_size: [usize; 3]) -> FloatTensor { + kernel::pool::adaptive_avg_pool3d(x, output_size) + } + + fn adaptive_avg_pool3d_backward( + x: FloatTensor, + grad: FloatTensor, + ) -> FloatTensor { + kernel::pool::adaptive_avg_pool3d_backward(x, grad) + } + fn interpolate( x: FloatTensor, output_size: [usize; 2], diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index d52c1815a0..6ed7449066 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -7,7 +7,7 @@ use burn_backend::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, - MaxPool2dWithIndices, ModuleOps, + MaxPool2dWithIndices, MaxPool3dBackward, MaxPool3dWithIndices, ModuleOps, }, tensor::{FloatTensor, IntTensor}, }; @@ -1087,6 +1087,75 @@ impl ModuleOps> for Fusion { .output() } + fn avg_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("avg_pool3d is not yet implemented for Fusion backend") + } + + fn avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("avg_pool3d_backward is not yet implemented for Fusion backend") + } + + fn max_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("max_pool3d is not yet implemented for Fusion backend") + } + + fn max_pool3d_with_indices( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> MaxPool3dWithIndices { + unimplemented!("max_pool3d_with_indices is not yet implemented for Fusion backend") + } + + fn max_pool3d_with_indices_backward( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + _output_grad: FloatTensor, + _indices: IntTensor, + ) -> MaxPool3dBackward { + unimplemented!("max_pool3d_with_indices_backward is not yet implemented for Fusion backend") + } + + fn adaptive_avg_pool3d(_x: FloatTensor, _output_size: [usize; 3]) -> FloatTensor { + unimplemented!("adaptive_avg_pool3d is not yet implemented for Fusion backend") + } + + fn adaptive_avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + ) -> FloatTensor { + unimplemented!("adaptive_avg_pool3d_backward is not yet implemented for Fusion backend") + } + fn interpolate( x: FloatTensor, output_size: [usize; 2], diff --git a/crates/burn-ndarray/src/ops/module.rs b/crates/burn-ndarray/src/ops/module.rs index b6c4271707..840c977965 100644 --- a/crates/burn-ndarray/src/ops/module.rs +++ b/crates/burn-ndarray/src/ops/module.rs @@ -323,6 +323,77 @@ where } } + fn avg_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("avg_pool3d is not yet implemented for NdArray backend") + } + + fn avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + 
_ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("avg_pool3d_backward is not yet implemented for NdArray backend") + } + + fn max_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("max_pool3d is not yet implemented for NdArray backend") + } + + fn max_pool3d_with_indices( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> MaxPool3dWithIndices { + unimplemented!("max_pool3d_with_indices is not yet implemented for NdArray backend") + } + + fn max_pool3d_with_indices_backward( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + _output_grad: FloatTensor, + _indices: NdArrayTensor, + ) -> MaxPool3dBackward { + unimplemented!( + "max_pool3d_with_indices_backward is not yet implemented for NdArray backend" + ) + } + + fn adaptive_avg_pool3d(_x: FloatTensor, _output_size: [usize; 3]) -> FloatTensor { + unimplemented!("adaptive_avg_pool3d is not yet implemented for NdArray backend") + } + + fn adaptive_avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + ) -> FloatTensor { + unimplemented!("adaptive_avg_pool3d_backward is not yet implemented for NdArray backend") + } + fn conv3d( x: FloatTensor, weight: FloatTensor, diff --git a/crates/burn-nn/src/modules/pool/adaptive_avg_pool3d.rs b/crates/burn-nn/src/modules/pool/adaptive_avg_pool3d.rs new file mode 100644 index 0000000000..44b1e5c56f --- /dev/null +++ b/crates/burn-nn/src/modules/pool/adaptive_avg_pool3d.rs @@ -0,0 +1,79 @@ +use burn_core as burn; + +use burn::config::Config; +use burn::module::Module; +use burn::module::{Content, DisplaySettings, ModuleDisplay}; +use burn::tensor::Tensor; +use burn::tensor::backend::Backend; + +use burn::tensor::module::adaptive_avg_pool3d; + +/// Configuration to create a [3D adaptive avg pooling](AdaptiveAvgPool3d) layer using the [init function](AdaptiveAvgPool3dConfig::init). +#[derive(Config, Debug)] +pub struct AdaptiveAvgPool3dConfig { + /// The size of the output. + pub output_size: [usize; 3], +} + +/// Applies a 3D adaptive avg pooling over input tensors. +/// +/// Should be created with [AdaptiveAvgPool3dConfig]. +#[derive(Module, Clone, Debug)] +#[module(custom_display)] +pub struct AdaptiveAvgPool3d { + /// The size of the output. + pub output_size: [usize; 3], +} + +impl ModuleDisplay for AdaptiveAvgPool3d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + let output_size = alloc::format!("{:?}", self.output_size); + + content.add("output_size", &output_size).optional() + } +} + +impl AdaptiveAvgPool3dConfig { + /// Initialize a new [adaptive avg pool 3d](AdaptiveAvgPool3d) module. + pub fn init(&self) -> AdaptiveAvgPool3d { + AdaptiveAvgPool3d { + output_size: self.output_size, + } + } +} + +impl AdaptiveAvgPool3d { + /// Applies the forward pass on the input tensor. + /// + /// See [adaptive_avg_pool3d](burn::tensor::module::adaptive_avg_pool3d) for more information. 
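+    ///
+    /// # Example
+    ///
+    /// A minimal usage sketch (illustrative only: the backend type `B` and the
+    /// `device` handle are assumed to already be in scope):
+    ///
+    /// ```ignore
+    /// let pool = AdaptiveAvgPool3dConfig::new([4, 4, 4]).init();
+    /// let input = Tensor::<B, 5>::zeros([2, 3, 8, 8, 8], &device);
+    /// let output = pool.forward(input); // [2, 3, 4, 4, 4]
+    /// ```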
+ /// + /// # Shapes + /// + /// - input: `[batch_size, channels, depth_in, height_in, width_in]` + /// - output: `[batch_size, channels, depth_out, height_out, width_out]` + pub fn forward(&self, input: Tensor) -> Tensor { + adaptive_avg_pool3d(input, self.output_size) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn display() { + let config = AdaptiveAvgPool3dConfig::new([3, 3, 3]); + let layer = config.init(); + + assert_eq!( + alloc::format!("{layer}"), + "AdaptiveAvgPool3d {output_size: [3, 3, 3]}" + ); + } +} diff --git a/crates/burn-nn/src/modules/pool/avg_pool3d.rs b/crates/burn-nn/src/modules/pool/avg_pool3d.rs new file mode 100644 index 0000000000..7505949c71 --- /dev/null +++ b/crates/burn-nn/src/modules/pool/avg_pool3d.rs @@ -0,0 +1,162 @@ +use crate::conv::checks::check_same_padding_support; +use burn_core as burn; + +use crate::PaddingConfig3d; +use burn::config::Config; +use burn::module::{Content, DisplaySettings, ModuleDisplay}; +use burn::module::{Ignored, Module}; +use burn::tensor::Tensor; +use burn::tensor::backend::Backend; + +use burn::tensor::module::avg_pool3d; + +/// Configuration to create a [3D avg pooling](AvgPool3d) layer using the [init function](AvgPool3dConfig::init). +#[derive(Config, Debug)] +pub struct AvgPool3dConfig { + /// The size of the kernel. + pub kernel_size: [usize; 3], + /// The strides. + #[config(default = "kernel_size")] + pub strides: [usize; 3], + /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. + #[config(default = "PaddingConfig3d::Valid")] + pub padding: PaddingConfig3d, + /// If the padding is counted in the denominator when computing the average. + #[config(default = "true")] + pub count_include_pad: bool, + /// If true, use ceiling instead of floor for output size calculation. + #[config(default = "false")] + pub ceil_mode: bool, +} + +/// Applies a 3D avg pooling over input tensors. +/// +/// Should be created with [AvgPool3dConfig](AvgPool3dConfig). +/// +/// # Remarks +/// +/// The zero-padding values will be included in the calculation +/// of the average. This means that the zeros are counted as +/// legitimate values, and they contribute to the denominator +/// when calculating the average. This is equivalent to +/// `torch.nn.AvgPool3d` with `count_include_pad=True`. +#[derive(Module, Clone, Debug)] +#[module(custom_display)] +pub struct AvgPool3d { + /// Stride of the pooling. + pub stride: [usize; 3], + /// Size of the kernel. + pub kernel_size: [usize; 3], + /// Padding configuration. + pub padding: Ignored, + /// If the padding is counted in the denominator when computing the average. + pub count_include_pad: bool, + /// If true, use ceiling instead of floor for output size calculation. + pub ceil_mode: bool, +} + +impl ModuleDisplay for AvgPool3d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("count_include_pad", &self.count_include_pad) + .add("ceil_mode", &self.ceil_mode) + .optional() + } +} + +impl AvgPool3dConfig { + /// Initialize a new [avg pool 3d](AvgPool3d) module. 
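+    ///
+    /// A hedged usage sketch of the resulting module (the backend type `B` and
+    /// `device` are assumed to be available in the caller's scope):
+    ///
+    /// ```ignore
+    /// // The stride defaults to the kernel size and the padding to Valid, so
+    /// // an 8x8x8 spatial volume pools down to 4x4x4 with a [2, 2, 2] kernel.
+    /// let pool = AvgPool3dConfig::new([2, 2, 2]).init();
+    /// let input = Tensor::<B, 5>::zeros([1, 2, 8, 8, 8], &device);
+    /// let output = pool.forward(input); // [1, 2, 4, 4, 4]
+    /// ```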
+ pub fn init(&self) -> AvgPool3d { + if self.padding == PaddingConfig3d::Same { + check_same_padding_support(&self.kernel_size); + } + AvgPool3d { + stride: self.strides, + kernel_size: self.kernel_size, + padding: Ignored(self.padding.clone()), + count_include_pad: self.count_include_pad, + ceil_mode: self.ceil_mode, + } + } +} + +impl AvgPool3d { + /// Applies the forward pass on the input tensor. + /// + /// See [avg_pool3d](burn::tensor::module::avg_pool3d) for more information. + /// + /// # Shapes + /// + /// - input: `[batch_size, channels, depth_in, height_in, width_in]` + /// - output: `[batch_size, channels, depth_out, height_out, width_out]` + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims(); + let padding = self.padding.calculate_padding_3d( + depth_in, + height_in, + width_in, + &self.kernel_size, + &self.stride, + ); + + avg_pool3d( + input, + self.kernel_size, + self.stride, + padding, + self.count_include_pad, + self.ceil_mode, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = AvgPool3dConfig::new([2, 2, 2]).with_padding(PaddingConfig3d::Same); + let _ = config.init(); + } + + #[test] + fn display() { + let config = AvgPool3dConfig::new([3, 3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{layer}"), + "AvgPool3d {kernel_size: [3, 3, 3], stride: [3, 3, 3], padding: Valid, count_include_pad: true, ceil_mode: false}" + ); + } + + #[rstest] + #[case([2, 2, 2])] + #[case([1, 2, 3])] + fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 3]) { + let config = AvgPool3dConfig::new(kernel_size); + + assert_eq!( + config.strides, kernel_size, + "Expected strides ({:?}) to match kernel size ({:?}) in default AvgPool3dConfig::new constructor", + config.strides, config.kernel_size + ); + } +} diff --git a/crates/burn-nn/src/modules/pool/max_pool3d.rs b/crates/burn-nn/src/modules/pool/max_pool3d.rs new file mode 100644 index 0000000000..c1bc75d78d --- /dev/null +++ b/crates/burn-nn/src/modules/pool/max_pool3d.rs @@ -0,0 +1,154 @@ +use crate::conv::checks::check_same_padding_support; +use burn_core as burn; + +use crate::PaddingConfig3d; +use burn::config::Config; +use burn::module::{Content, DisplaySettings, ModuleDisplay}; +use burn::module::{Ignored, Module}; +use burn::tensor::Tensor; +use burn::tensor::backend::Backend; + +use burn::tensor::module::max_pool3d; + +/// Configuration to create a [3D max pooling](MaxPool3d) layer using the [init function](MaxPool3dConfig::init). +#[derive(Debug, Config)] +pub struct MaxPool3dConfig { + /// The size of the kernel. + pub kernel_size: [usize; 3], + /// The strides. + #[config(default = "kernel_size")] + pub strides: [usize; 3], + /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. + #[config(default = "PaddingConfig3d::Valid")] + pub padding: PaddingConfig3d, + /// The dilation. + #[config(default = "[1, 1, 1]")] + pub dilation: [usize; 3], + /// If true, use ceiling instead of floor for output size calculation. + #[config(default = "false")] + pub ceil_mode: bool, +} + +/// Applies a 3D max pooling over input tensors. 
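+///
+/// The pooling window slides over the depth, height, and width dimensions;
+/// `dilation` inserts spacing between kernel elements along each axis, and
+/// `ceil_mode` switches the output-size calculation from floor to ceiling.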
+/// +/// Should be created with [MaxPool3dConfig](MaxPool3dConfig). +#[derive(Module, Clone, Debug)] +#[module(custom_display)] +pub struct MaxPool3d { + /// The strides. + pub stride: [usize; 3], + /// The size of the kernel. + pub kernel_size: [usize; 3], + /// The padding configuration. + pub padding: Ignored, + /// The dilation. + pub dilation: [usize; 3], + /// If true, use ceiling instead of floor for output size calculation. + pub ceil_mode: bool, +} + +impl ModuleDisplay for MaxPool3d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("kernel_size", &alloc::format!("{:?}", &self.kernel_size)) + .add("stride", &alloc::format!("{:?}", &self.stride)) + .add("padding", &self.padding) + .add("dilation", &alloc::format!("{:?}", &self.dilation)) + .add("ceil_mode", &self.ceil_mode) + .optional() + } +} + +impl MaxPool3dConfig { + /// Initialize a new [max pool 3d](MaxPool3d) module. + pub fn init(&self) -> MaxPool3d { + if self.padding == PaddingConfig3d::Same { + check_same_padding_support(&self.kernel_size); + } + MaxPool3d { + stride: self.strides, + kernel_size: self.kernel_size, + padding: Ignored(self.padding.clone()), + dilation: self.dilation, + ceil_mode: self.ceil_mode, + } + } +} + +impl MaxPool3d { + /// Applies the forward pass on the input tensor. + /// + /// See [max_pool3d](burn::tensor::module::max_pool3d) for more information. + /// + /// # Shapes + /// + /// - input: `[batch_size, channels, depth_in, height_in, width_in]` + /// - output: `[batch_size, channels, depth_out, height_out, width_out]` + pub fn forward(&self, input: Tensor) -> Tensor { + let [_batch_size, _channels_in, depth_in, height_in, width_in] = input.dims(); + let padding = self.padding.calculate_padding_3d( + depth_in, + height_in, + width_in, + &self.kernel_size, + &self.stride, + ); + + max_pool3d( + input, + self.kernel_size, + self.stride, + padding, + self.dilation, + self.ceil_mode, + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = MaxPool3dConfig::new([2, 2, 2]).with_padding(PaddingConfig3d::Same); + let _ = config.init(); + } + + #[test] + fn display() { + let config = MaxPool3dConfig::new([3, 3, 3]); + + let layer = config.init(); + + assert_eq!( + alloc::format!("{layer}"), + "MaxPool3d {kernel_size: [3, 3, 3], stride: [3, 3, 3], padding: Valid, dilation: [1, 1, 1], ceil_mode: false}" + ); + } + + #[rstest] + #[case([2, 2, 2])] + #[case([1, 2, 3])] + fn default_strides_match_kernel_size(#[case] kernel_size: [usize; 3]) { + let config = MaxPool3dConfig::new(kernel_size); + + assert_eq!( + config.strides, kernel_size, + "Expected strides ({:?}) to match kernel size ({:?}) in default MaxPool3dConfig::new constructor", + config.strides, config.kernel_size + ); + } +} diff --git a/crates/burn-nn/src/modules/pool/mod.rs b/crates/burn-nn/src/modules/pool/mod.rs index 622a4b66f4..4fdecbb023 100644 --- a/crates/burn-nn/src/modules/pool/mod.rs +++ b/crates/burn-nn/src/modules/pool/mod.rs @@ -1,13 +1,19 @@ mod adaptive_avg_pool1d; mod adaptive_avg_pool2d; +mod adaptive_avg_pool3d; mod avg_pool1d; mod avg_pool2d; +mod avg_pool3d; mod max_pool1d; mod max_pool2d; +mod max_pool3d; pub use adaptive_avg_pool1d::*; pub use adaptive_avg_pool2d::*; +pub use adaptive_avg_pool3d::*; pub use 
avg_pool1d::*; pub use avg_pool2d::*; +pub use avg_pool3d::*; pub use max_pool1d::*; pub use max_pool2d::*; +pub use max_pool3d::*; diff --git a/crates/burn-router/src/ops/module.rs b/crates/burn-router/src/ops/module.rs index 9b2bb0ef64..393ff4cd7b 100644 --- a/crates/burn-router/src/ops/module.rs +++ b/crates/burn-router/src/ops/module.rs @@ -3,7 +3,8 @@ use alloc::boxed::Box; use burn_backend::Element; use burn_backend::ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, - MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, + MaxPool3dBackward, MaxPool3dWithIndices, ModuleOps, }; use burn_backend::tensor::{FloatTensor, IntElem, IntTensor}; use burn_ir::{ @@ -483,6 +484,75 @@ impl ModuleOps for BackendRouter { .output() } + fn avg_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("avg_pool3d is not yet implemented for BackendRouter") + } + + fn avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _count_include_pad: bool, + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("avg_pool3d_backward is not yet implemented for BackendRouter") + } + + fn max_pool3d( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> FloatTensor { + unimplemented!("max_pool3d is not yet implemented for BackendRouter") + } + + fn max_pool3d_with_indices( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + ) -> MaxPool3dWithIndices { + unimplemented!("max_pool3d_with_indices is not yet implemented for BackendRouter") + } + + fn max_pool3d_with_indices_backward( + _x: FloatTensor, + _kernel_size: [usize; 3], + _stride: [usize; 3], + _padding: [usize; 3], + _dilation: [usize; 3], + _ceil_mode: bool, + _output_grad: FloatTensor, + _indices: IntTensor, + ) -> MaxPool3dBackward { + unimplemented!("max_pool3d_with_indices_backward is not yet implemented for BackendRouter") + } + + fn adaptive_avg_pool3d(_x: FloatTensor, _output_size: [usize; 3]) -> FloatTensor { + unimplemented!("adaptive_avg_pool3d is not yet implemented for BackendRouter") + } + + fn adaptive_avg_pool3d_backward( + _x: FloatTensor, + _grad: FloatTensor, + ) -> FloatTensor { + unimplemented!("adaptive_avg_pool3d_backward is not yet implemented for BackendRouter") + } + fn interpolate( x: FloatTensor, output_size: [usize; 2], diff --git a/crates/burn-tch/src/ops/module.rs b/crates/burn-tch/src/ops/module.rs index 6ff386c3e1..7f02d32331 100644 --- a/crates/burn-tch/src/ops/module.rs +++ b/crates/burn-tch/src/ops/module.rs @@ -4,7 +4,7 @@ use burn_backend::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, - MaxPool2dWithIndices, ModuleOps, + MaxPool2dWithIndices, MaxPool3dBackward, MaxPool3dWithIndices, ModuleOps, }, }; @@ -366,6 +366,146 @@ impl ModuleOps for LibTorch { MaxPool2dBackward::new(TchTensor::new(grad)) } + fn avg_pool3d( + x: TchTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + 
) -> TchTensor { + let tensor = tch::Tensor::avg_pool3d( + &x.tensor, + [ + kernel_size[0] as i64, + kernel_size[1] as i64, + kernel_size[2] as i64, + ], + [stride[0] as i64, stride[1] as i64, stride[2] as i64], + [padding[0] as i64, padding[1] as i64, padding[2] as i64], + ceil_mode, + count_include_pad, + None, + ); + + TchTensor::new(tensor) + } + + fn avg_pool3d_backward( + x: TchTensor, + grad: TchTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, + ) -> TchTensor { + let tensor = tch::Tensor::avg_pool3d_backward( + &x.tensor, + &grad.tensor, + [ + kernel_size[0] as i64, + kernel_size[1] as i64, + kernel_size[2] as i64, + ], + [stride[0] as i64, stride[1] as i64, stride[2] as i64], + [padding[0] as i64, padding[1] as i64, padding[2] as i64], + ceil_mode, + count_include_pad, + None, + ); + + TchTensor::new(tensor) + } + + fn max_pool3d( + x: TchTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> TchTensor { + let tensor = tch::Tensor::max_pool3d( + &x.tensor, + [ + kernel_size[0] as i64, + kernel_size[1] as i64, + kernel_size[2] as i64, + ], + [stride[0] as i64, stride[1] as i64, stride[2] as i64], + [padding[0] as i64, padding[1] as i64, padding[2] as i64], + [dilation[0] as i64, dilation[1] as i64, dilation[2] as i64], + ceil_mode, + ); + + TchTensor::new(tensor) + } + + fn max_pool3d_with_indices( + x: TchTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + ) -> MaxPool3dWithIndices> { + let (tensor, indices) = tch::Tensor::max_pool3d_with_indices( + &x.tensor, + [ + kernel_size[0] as i64, + kernel_size[1] as i64, + kernel_size[2] as i64, + ], + [stride[0] as i64, stride[1] as i64, stride[2] as i64], + [padding[0] as i64, padding[1] as i64, padding[2] as i64], + [dilation[0] as i64, dilation[1] as i64, dilation[2] as i64], + ceil_mode, + ); + + MaxPool3dWithIndices::new(TchTensor::new(tensor), TchTensor::new(indices)) + } + + fn max_pool3d_with_indices_backward( + x: TchTensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, + output_grad: TchTensor, + indices: TchTensor, + ) -> MaxPool3dBackward> { + let grad = tch::Tensor::max_pool3d_with_indices_backward( + &x.tensor, + &output_grad.tensor, + [ + kernel_size[0] as i64, + kernel_size[1] as i64, + kernel_size[2] as i64, + ], + [stride[0] as i64, stride[1] as i64, stride[2] as i64], + [padding[0] as i64, padding[1] as i64, padding[2] as i64], + [dilation[0] as i64, dilation[1] as i64, dilation[2] as i64], + ceil_mode, + &indices.tensor, + ); + + MaxPool3dBackward::new(TchTensor::new(grad)) + } + + fn adaptive_avg_pool3d(x: TchTensor, output_size: [usize; 3]) -> TchTensor { + let tensor = tch::Tensor::adaptive_avg_pool3d(&x.tensor, output_size.map(|e| e as i64)); + + TchTensor::new(tensor) + } + + fn adaptive_avg_pool3d_backward(x: TchTensor, grad: TchTensor) -> TchTensor { + let tensor = tch::Tensor::internal_adaptive_avg_pool3d_backward(&x.tensor, &grad.tensor); + + TchTensor::new(tensor) + } + fn adaptive_avg_pool2d(x: TchTensor, output_size: [usize; 2]) -> TchTensor { let tensor = tch::Tensor::adaptive_avg_pool2d(&x.tensor, output_size.map(|e| e as i64)); diff --git a/crates/burn-tensor/src/tensor/module.rs b/crates/burn-tensor/src/tensor/module.rs index a5e41df734..809e8a29ac 100644 --- a/crates/burn-tensor/src/tensor/module.rs +++ 
b/crates/burn-tensor/src/tensor/module.rs @@ -364,6 +364,88 @@ where ))) } +/// Applies a [3D max pooling](crate::ops::ModuleOps::max_pool3d). +pub fn max_pool3d( + x: Tensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, +) -> Tensor +where + B: Backend, +{ + Tensor::new(TensorPrimitive::Float(B::max_pool3d( + x.primitive.tensor(), + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ))) +} + +/// Applies a [3D max pooling with indices](crate::ops::ModuleOps::max_pool3d_with_indices). +pub fn max_pool3d_with_indices( + x: Tensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + dilation: [usize; 3], + ceil_mode: bool, +) -> (Tensor, Tensor) +where + B: Backend, +{ + let output = B::max_pool3d_with_indices( + x.primitive.tensor(), + kernel_size, + stride, + padding, + dilation, + ceil_mode, + ); + + ( + Tensor::new(TensorPrimitive::Float(output.output)), + Tensor::new(output.indices), + ) +} + +/// Applies a [3D avg pooling](crate::ops::ModuleOps::avg_pool3d). +pub fn avg_pool3d( + x: Tensor, + kernel_size: [usize; 3], + stride: [usize; 3], + padding: [usize; 3], + count_include_pad: bool, + ceil_mode: bool, +) -> Tensor +where + B: Backend, +{ + Tensor::new(TensorPrimitive::Float(B::avg_pool3d( + x.primitive.tensor(), + kernel_size, + stride, + padding, + count_include_pad, + ceil_mode, + ))) +} + +/// Applies a [3D adaptive avg pooling](crate::ops::ModuleOps::adaptive_avg_pool3d). +pub fn adaptive_avg_pool3d(x: Tensor, output_size: [usize; 3]) -> Tensor +where + B: Backend, +{ + Tensor::new(TensorPrimitive::Float(B::adaptive_avg_pool3d( + x.primitive.tensor(), + output_size, + ))) +} + /// Applies a [2D interpolation](crate::ops::ModuleOps::interpolate). pub fn interpolate( x: Tensor,