From ce7d866b5d6492b69c1ae90800b21491dfc72aea Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Wed, 9 Oct 2024 15:49:00 -0400 Subject: [PATCH] Refactor Adaptive Avg Pool to CubeCL (#2351) --- .../src/kernel/pool/adaptive_avg_pool2d.rs | 89 ++++- .../pool/adaptive_avg_pool2d_backward.rs | 303 ++++-------------- .../src/kernel/pool/adaptive_pool2d_shader.rs | 229 ------------- crates/burn-jit/src/kernel/pool/mod.rs | 2 - 4 files changed, 153 insertions(+), 470 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs index 4f962a7501..233ccd2707 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs @@ -1,8 +1,78 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; use burn_tensor::Shape; -use cubecl::{CubeCountSettings, Execution}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; -use super::AdaptivePool2dEagerKernel; +#[cube(launch)] +fn adaptive_avg_pool2d_direct(input: &Tensor, output: &mut Tensor) { + let (output_stride_0, output_stride_1, output_stride_2, output_stride_3) = ( + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + ); + let (output_shape_0, output_shape_1, output_shape_2, output_shape_3) = ( + output.shape(0), + output.shape(1), + output.shape(2), + output.shape(3), + ); + let (input_stride_0, input_stride_1, input_stride_2, input_stride_3) = ( + input.stride(0), + input.stride(1), + input.stride(2), + input.stride(3), + ); + let (input_shape_2, input_shape_3) = (input.shape(2), input.shape(3)); + + let b = (ABSOLUTE_POS / output_stride_0) % output_shape_0; + let c = (ABSOLUTE_POS / output_stride_1) % output_shape_1; + let oh = (ABSOLUTE_POS / output_stride_2) % output_shape_2; + let ow = (ABSOLUTE_POS / output_stride_3) % output_shape_3; + + let ih_start = start_index(oh, output_shape_2, input_shape_2); + let ih_end = end_index(oh, output_shape_2, input_shape_2); + + let iw_start = start_index(ow, output_shape_3, input_shape_3); + let iw_end = end_index(ow, output_shape_3, input_shape_3); + + let mut sum = E::from_int(0); + + let index_input_0 = b * input_stride_0; + let index_input_1 = c * input_stride_1; + + for ih in ih_start..ih_end { + let index_input_2 = ih * input_stride_2; + + for iw in iw_start..iw_end { + let index_input_3 = iw * input_stride_3; + + let index_input = index_input_0 + index_input_1 + index_input_2 + index_input_3; + sum += input[index_input]; + } + } + + let num_ih = ih_end - ih_start; + let num_iw = iw_end - iw_start; + + output[ABSOLUTE_POS] = sum / E::cast_from(num_ih * num_iw); +} + +#[cube] +fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { + (output_size_index * input_size) / output_size +} + +#[cube] +fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { + let index = (output_size_index + 1) * input_size; + let index = (index + output_size - 1) / output_size; + + if input_size < index { + input_size + } else { + index + } +} pub(crate) fn adaptive_avg_pool2d( input: JitTensor, @@ -11,14 +81,19 @@ pub(crate) fn adaptive_avg_pool2d( let [batch_size, channels, _, _] = input.shape.dims(); let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]); + let num_elems: usize = output_shape.num_elements(); let output = empty_device(input.client.clone(), input.device.clone(), output_shape); - let kernel = AdaptivePool2dEagerKernel::::new(); + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); - Execution::start(kernel, input.client.clone()) - .inputs(&[input.as_handle_ref()]) - .outputs(&[output.as_handle_ref()]) - .execute(CubeCountSettings::Output { pos: 0 }); + adaptive_avg_pool2d_direct::launch::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + output.as_tensor_arg(1), + ); output } diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs index 580f71b8ff..f9eaa72bc4 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -1,245 +1,80 @@ -use std::marker::PhantomData; - -use cubecl::{ - cpa, - ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, - CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, - OutputInfo, -}; - -use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; - -#[derive(new)] -struct AdaptiveAvgPool2dBackwardEagerKernel { - _runtime: PhantomData, - _elem: PhantomData, -} - -struct AdaptiveAvgPool2dBackwardComputeShader { - grad: Variable, - output: Variable, - _elem: PhantomData, -} - -impl AdaptiveAvgPool2dBackwardComputeShader { - fn expand(self, scope: &mut Scope) { - let grad = self.grad; - let output = self.output; - let id = Variable::AbsolutePos; - - let grad_stride_0 = scope.create_local(Elem::UInt); - let grad_stride_1 = scope.create_local(Elem::UInt); - let grad_stride_2 = scope.create_local(Elem::UInt); - let grad_stride_3 = scope.create_local(Elem::UInt); - - let grad_shape_2 = scope.create_local(Elem::UInt); - let grad_shape_3 = scope.create_local(Elem::UInt); - - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); - - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); - - cpa!(scope, grad_stride_0 = stride(grad, 0u32)); - cpa!(scope, grad_stride_1 = stride(grad, 1u32)); - cpa!(scope, grad_stride_2 = stride(grad, 2u32)); - cpa!(scope, grad_stride_3 = stride(grad, 3u32)); - - cpa!(scope, grad_shape_2 = shape(grad, 2u32)); - cpa!(scope, grad_shape_3 = shape(grad, 3u32)); - - cpa!(scope, output_stride_0 = stride(output, 0u32)); - cpa!(scope, output_stride_1 = stride(output, 1u32)); - cpa!(scope, output_stride_2 = stride(output, 2u32)); - cpa!(scope, output_stride_3 = stride(output, 3u32)); - - cpa!(scope, output_shape_0 = shape(output, 0u32)); - cpa!(scope, output_shape_1 = shape(output, 1u32)); - cpa!(scope, output_shape_2 = shape(output, 2u32)); - cpa!(scope, output_shape_3 = shape(output, 3u32)); - - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let ih = scope.create_local(Elem::UInt); - let iw = scope.create_local(Elem::UInt); - - cpa!(scope, b = id / output_stride_0); - cpa!(scope, b = b % output_shape_0); - - cpa!(scope, c = id / output_stride_1); - cpa!(scope, c = c % output_shape_1); - - cpa!(scope, ih = id / output_stride_2); - cpa!(scope, ih = ih % output_shape_2); - - cpa!(scope, iw = id / output_stride_3); - cpa!(scope, iw = iw % output_shape_3); - - let oh_start = Self::start_index(scope, ih, output_shape_2, grad_shape_2); - let oh_end = Self::end_index(scope, ih, output_shape_2, grad_shape_2); - - let ow_start = Self::start_index(scope, iw, output_shape_3, grad_shape_3); - let ow_end = Self::end_index(scope, iw, output_shape_3, grad_shape_3); - - let grad_acc = scope.create_local(output.item()); - let contributed_h = scope.create_local(Elem::Bool); - let contributed_w = scope.create_local(Elem::Bool); - let contributed_tmp = scope.create_local(Elem::Bool); - - let count = scope.create_local(Elem::UInt); - let count_tmp = scope.create_local(Elem::UInt); - let count_float = scope.create_local(output.item()); - let the_grad = scope.create_local(output.item()); - let avg = scope.create_local(output.item()); - - let index_base = scope.create_local(Elem::UInt); - let index_tmp = scope.create_local(Elem::UInt); - let index = scope.create_local(Elem::UInt); - cpa!(scope, index_base = b * grad_stride_0); - cpa!(scope, index_tmp = c * grad_stride_1); - cpa!(scope, index_base += index_tmp); +use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +#[cube(launch)] +fn adaptive_avg_pool2d_backward_direct(grad: &Tensor, output: &mut Tensor) { + let (output_stride_0, output_stride_1, output_stride_2, output_stride_3) = ( + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + ); + let (output_shape_0, output_shape_1, output_shape_2, output_shape_3) = ( + output.shape(0), + output.shape(1), + output.shape(2), + output.shape(3), + ); + let (grad_stride_0, grad_stride_1, grad_stride_2, grad_stride_3) = ( + grad.stride(0), + grad.stride(1), + grad.stride(2), + grad.stride(3), + ); + let (grad_shape_2, grad_shape_3) = (grad.shape(2), grad.shape(3)); - cpa!( - scope, - range(oh_start, oh_end).for_each(|oh, scope| { - let ih_start = Self::start_index(scope, oh, grad_shape_2, output_shape_2); - let ih_end = Self::end_index(scope, oh, grad_shape_2, output_shape_2); - cpa!(scope, contributed_h = ih >= ih_start); - cpa!(scope, contributed_tmp = ih < ih_end); - cpa!(scope, contributed_h = contributed_h && contributed_tmp); + let b = (ABSOLUTE_POS / output_stride_0) % output_shape_0; + let c = (ABSOLUTE_POS / output_stride_1) % output_shape_1; + let ih = (ABSOLUTE_POS / output_stride_2) % output_shape_2; + let iw = (ABSOLUTE_POS / output_stride_3) % output_shape_3; - cpa!(scope, if(contributed_h).then(|scope|{ - cpa!( - scope, - range(ow_start, ow_end).for_each(|ow, scope| { - let iw_start = Self::start_index(scope, ow, grad_shape_3, output_shape_3); - let iw_end = Self::end_index(scope, ow, grad_shape_3, output_shape_3); + let oh_start = start_index(ih, output_shape_2, grad_shape_2); + let oh_end = end_index(ih, output_shape_2, grad_shape_2); - cpa!(scope, contributed_w = iw >= iw_start); - cpa!(scope, contributed_tmp = iw < iw_end); - cpa!(scope, contributed_w = contributed_w && contributed_tmp); + let ow_start = start_index(iw, output_shape_3, grad_shape_3); + let ow_end = end_index(iw, output_shape_3, grad_shape_3); + let mut grad_acc = E::from_int(0); - cpa!(scope, if(contributed_w).then(|scope|{ - cpa!(scope, count = ih_end - ih_start); - cpa!(scope, count_tmp = iw_end - iw_start); - cpa!(scope, count *= count_tmp); - cpa!(scope, count_float = cast(count)); + let index_base = b * grad_stride_0 + (c * grad_stride_1); - cpa!(scope, index = index_base); - cpa!(scope, index_tmp = oh * grad_stride_2); - cpa!(scope, index += index_tmp); - cpa!(scope, index_tmp = ow * grad_stride_3); - cpa!(scope, index += index_tmp); + for oh in oh_start..oh_end { + let ih_start = start_index(oh, grad_shape_2, output_shape_2); + let ih_end = end_index(oh, grad_shape_2, output_shape_2); - cpa!(scope, the_grad = grad[index]); - cpa!(scope, avg = the_grad / count_float); - cpa!(scope, grad_acc += avg); - })); - }) - ); - })); - }) - ); + if ih >= ih_start && ih < ih_end { + for ow in ow_start..ow_end { + let iw_start = start_index(ow, grad_shape_3, output_shape_3); + let iw_end = end_index(ow, grad_shape_3, output_shape_3); - cpa!(scope, output[id] = grad_acc); - } + if iw >= iw_start && iw < iw_end { + let num_ih = ih_end - ih_start; + let num_iw = iw_end - iw_start; - fn start_index( - scope: &mut Scope, - output_size_index: Variable, - output_size: Variable, - input_size: Variable, - ) -> Variable { - let elem = E::cube_elem(); - let numerator_float = scope.create_local(elem); - let div = scope.create_local(elem); - let index = scope.create_local(Elem::UInt); - - cpa!(scope, index = output_size_index * input_size); - cpa!(scope, numerator_float = cast(index)); - cpa!(scope, div = cast(output_size)); - cpa!(scope, div = numerator_float / div); - cpa!(scope, div = floor(div)); - cpa!(scope, index = cast(div)); - index + let index = index_base + (oh * grad_stride_2) + (ow * grad_stride_3); + grad_acc += grad[index] / E::cast_from(num_iw * num_ih); + } + } + } } - fn end_index( - scope: &mut Scope, - output_size_index: Variable, - output_size: Variable, - input_size: Variable, - ) -> Variable { - let elem = E::cube_elem(); - let numerator_float = scope.create_local(elem); - let div = scope.create_local(elem); - let index = scope.create_local(Elem::UInt); - let min = scope.create_local(Elem::Bool); - let end_index = scope.create_local(Elem::UInt); - - cpa!(scope, index = output_size_index + 1u32); - cpa!(scope, index *= input_size); - cpa!(scope, numerator_float = cast(index)); - cpa!(scope, div = cast(output_size)); - cpa!(scope, div = numerator_float / div); - cpa!(scope, div = ceil(div)); - cpa!(scope, index = cast(div)); - - cpa!(scope, min = input_size < index); - cpa!(scope, if(min).then(|scope|{ - cpa!(scope, end_index = input_size); - }).else(|scope|{ - cpa!(scope, end_index = index); - })); - end_index - } + output[ABSOLUTE_POS] = grad_acc; } -impl Kernel for AdaptiveAvgPool2dBackwardEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let item = E::cube_elem().into(); - - let grad = Variable::GlobalInputArray { id: 0, item }; - let output = Variable::GlobalOutputArray { id: 0, item }; - - scope.write_global_custom(output); - - AdaptiveAvgPool2dBackwardComputeShader { - grad, - output, - _elem: PhantomData::, - } - .expand(&mut scope); - - let grad = InputInfo::Array { - item, - visibility: Visibility::Read, - }; - let scalars = InputInfo::Scalar { - elem: Elem::UInt, - size: 6, - }; - let output = OutputInfo::Array { item }; - - let info = KernelExpansion { - inputs: vec![grad, scalars], - outputs: vec![output], - scope, - }; +#[cube] +fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { + (output_size_index * input_size) / output_size +} - let settings = KernelSettings::default(); - KernelIntegrator::new(info).integrate(settings) - } +#[cube] +fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 { + let index = (output_size_index + 1) * input_size; + let index = (index + output_size - 1) / output_size; - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::() + if input_size < index { + input_size + } else { + index } } @@ -257,12 +92,16 @@ pub(crate) fn adaptive_avg_pool2d_backward( output_buffer, ); - let kernel = AdaptiveAvgPool2dBackwardEagerKernel::::new(); + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); - Execution::start(kernel, x.client) - .inputs(&[out_grad.as_handle_ref()]) - .outputs(&[output.as_handle_ref()]) - .execute(CubeCountSettings::Output { pos: 0 }); + adaptive_avg_pool2d_backward_direct::launch::( + &x.client, + cube_count, + cube_dim, + out_grad.as_tensor_arg(1), + output.as_tensor_arg(1), + ); output } diff --git a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs deleted file mode 100644 index a9c6ae9cbf..0000000000 --- a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs +++ /dev/null @@ -1,229 +0,0 @@ -use cubecl::{ - cpa, - ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, - InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, -}; -use std::marker::PhantomData; - -use crate::{kernel::Kernel, JitElement, JitRuntime}; - -pub(crate) struct AdaptivePool2dComputeShader { - input: Variable, - output: Variable, - _elem: PhantomData, - _runtime: PhantomData, -} - -impl AdaptivePool2dComputeShader { - fn expand(self, scope: &mut Scope) { - let input = self.input; - let output = self.output; - let id = Variable::AbsolutePos; - - let input_stride_0 = scope.create_local(Elem::UInt); - let input_stride_1 = scope.create_local(Elem::UInt); - let input_stride_2 = scope.create_local(Elem::UInt); - let input_stride_3 = scope.create_local(Elem::UInt); - - let input_shape_0 = scope.create_local(Elem::UInt); - let input_shape_1 = scope.create_local(Elem::UInt); - let input_shape_2 = scope.create_local(Elem::UInt); - let input_shape_3 = scope.create_local(Elem::UInt); - - let output_stride_0 = scope.create_local(Elem::UInt); - let output_stride_1 = scope.create_local(Elem::UInt); - let output_stride_2 = scope.create_local(Elem::UInt); - let output_stride_3 = scope.create_local(Elem::UInt); - - let output_shape_0 = scope.create_local(Elem::UInt); - let output_shape_1 = scope.create_local(Elem::UInt); - let output_shape_2 = scope.create_local(Elem::UInt); - let output_shape_3 = scope.create_local(Elem::UInt); - - cpa!(scope, input_stride_0 = stride(input, 0u32)); - cpa!(scope, input_stride_1 = stride(input, 1u32)); - cpa!(scope, input_stride_2 = stride(input, 2u32)); - cpa!(scope, input_stride_3 = stride(input, 3u32)); - - cpa!(scope, input_shape_0 = shape(input, 2u32)); - cpa!(scope, input_shape_1 = shape(input, 3u32)); - cpa!(scope, input_shape_2 = shape(input, 2u32)); - cpa!(scope, input_shape_3 = shape(input, 3u32)); - - cpa!(scope, output_stride_0 = stride(output, 0u32)); - cpa!(scope, output_stride_1 = stride(output, 1u32)); - cpa!(scope, output_stride_2 = stride(output, 2u32)); - cpa!(scope, output_stride_3 = stride(output, 3u32)); - - cpa!(scope, output_shape_0 = shape(output, 0u32)); - cpa!(scope, output_shape_1 = shape(output, 1u32)); - cpa!(scope, output_shape_2 = shape(output, 2u32)); - cpa!(scope, output_shape_3 = shape(output, 3u32)); - - let b = scope.create_local(Elem::UInt); - let c = scope.create_local(Elem::UInt); - let oh = scope.create_local(Elem::UInt); - let ow = scope.create_local(Elem::UInt); - - cpa!(scope, b = id / output_stride_0); - cpa!(scope, b = b % output_shape_0); - - cpa!(scope, c = id / output_stride_1); - cpa!(scope, c = c % output_shape_1); - - cpa!(scope, oh = id / output_stride_2); - cpa!(scope, oh = oh % output_shape_2); - - cpa!(scope, ow = id / output_stride_3); - cpa!(scope, ow = ow % output_shape_3); - - let ih_start = Self::start_index(scope, oh, output_shape_2, input_shape_2); - let ih_end = Self::end_index(scope, oh, output_shape_2, input_shape_2); - let iw_start = Self::start_index(scope, ow, output_shape_3, input_shape_3); - let iw_end = Self::end_index(scope, ow, output_shape_3, input_shape_3); - - let result = scope.create_local(input.item()); - - let index_input = scope.create_local(Elem::UInt); - let index_input_0 = scope.create_local(Elem::UInt); - let index_input_1 = scope.create_local(Elem::UInt); - let index_input_2 = scope.create_local(Elem::UInt); - let index_input_3 = scope.create_local(Elem::UInt); - - cpa!(scope, index_input_0 = b * input_stride_0); - cpa!(scope, index_input_1 = c * input_stride_1); - - let sum = scope.zero(output.item()); - - cpa!( - scope, - range(ih_start, ih_end).for_each(|ih, scope| { - cpa!( - scope, - range(iw_start, iw_end).for_each(|iw, scope| { - cpa!(scope, index_input_2 = ih * input_stride_2); - cpa!(scope, index_input_3 = iw * input_stride_3); - - cpa!(scope, index_input = index_input_0); - cpa!(scope, index_input += index_input_1); - cpa!(scope, index_input += index_input_2); - cpa!(scope, index_input += index_input_3); - - cpa!(scope, result = input[index_input]); - - cpa!(scope, sum += result); - }) - ); - }) - ); - - let count = scope.create_local(Elem::UInt); - let count_tmp = scope.create_local(Elem::UInt); - let count_float = scope.create_local(output.item()); - let avg = scope.create_local(output.item()); - - cpa!(scope, count = ih_end - ih_start); - cpa!(scope, count_tmp = iw_end - iw_start); - cpa!(scope, count *= count_tmp); - - cpa!(scope, count_float = cast(count)); - cpa!(scope, avg = sum / count_float); - cpa!(scope, output[id] = avg); - } - - fn start_index( - scope: &mut Scope, - output_size_index: Variable, - output_size: Variable, - input_size: Variable, - ) -> Variable { - let elem = E::cube_elem(); - let numerator_float = scope.create_local(elem); - let div = scope.create_local(elem); - let index = scope.create_local(Elem::UInt); - - cpa!(scope, index = output_size_index * input_size); - cpa!(scope, numerator_float = cast(index)); - cpa!(scope, div = cast(output_size)); - cpa!(scope, div = numerator_float / div); - cpa!(scope, div = floor(div)); - cpa!(scope, index = cast(div)); - index - } - - fn end_index( - scope: &mut Scope, - output_size_index: Variable, - output_size: Variable, - input_size: Variable, - ) -> Variable { - let elem = E::cube_elem(); - let numerator_float = scope.create_local(elem); - let div = scope.create_local(elem); - let index = scope.create_local(Elem::UInt); - let min = scope.create_local(Elem::Bool); - let end_index = scope.create_local(Elem::UInt); - - cpa!(scope, index = output_size_index + 1u32); - cpa!(scope, index *= input_size); - cpa!(scope, numerator_float = cast(index)); - cpa!(scope, div = cast(output_size)); - cpa!(scope, div = numerator_float / div); - cpa!(scope, div = ceil(div)); - cpa!(scope, index = cast(div)); - - cpa!(scope, min = input_size < index); - cpa!(scope, if(min).then(|scope|{ - cpa!(scope, end_index = input_size); - }).else(|scope|{ - cpa!(scope, end_index = index); - })); - end_index - } -} - -#[derive(new)] -pub(crate) struct AdaptivePool2dEagerKernel { - _runtime: PhantomData, - _elem: PhantomData, -} - -impl Kernel for AdaptivePool2dEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let item = E::cube_elem().into(); - - let input = Variable::GlobalInputArray { id: 0, item }; - let output = Variable::GlobalOutputArray { id: 0, item }; - - scope.write_global_custom(output); - - AdaptivePool2dComputeShader { - input, - output, - _elem: PhantomData::, - _runtime: PhantomData::, - } - .expand(&mut scope); - - let input = InputInfo::Array { - item, - visibility: Visibility::Read, - }; - - let output = OutputInfo::Array { item }; - - let info = KernelExpansion { - inputs: vec![input], - outputs: vec![output], - scope, - }; - - let settings = KernelSettings::default(); - KernelIntegrator::new(info).integrate(settings) - } - - fn id(&self) -> cubecl::KernelId { - cubecl::KernelId::new::() - } -} diff --git a/crates/burn-jit/src/kernel/pool/mod.rs b/crates/burn-jit/src/kernel/pool/mod.rs index 928b677f2a..31f5d4b579 100644 --- a/crates/burn-jit/src/kernel/pool/mod.rs +++ b/crates/burn-jit/src/kernel/pool/mod.rs @@ -1,6 +1,5 @@ mod adaptive_avg_pool2d; mod adaptive_avg_pool2d_backward; -mod adaptive_pool2d_shader; mod avg_pool2d; mod avg_pool2d_backward; mod base; @@ -10,7 +9,6 @@ mod pool2d_shader; pub(crate) use adaptive_avg_pool2d::*; pub(crate) use adaptive_avg_pool2d_backward::*; -pub(crate) use adaptive_pool2d_shader::*; pub(crate) use avg_pool2d::*; pub(crate) use avg_pool2d_backward::*; pub(super) use base::*;