Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 294 additions & 0 deletions crates/burn-autodiff/src/ops/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1684,6 +1684,258 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
panic!("Can't differentiate adaptive avg pool2d backward.");
}

fn avg_pool3d(
x: AutodiffTensor<B>,
kernel_size: [usize; 3],
stride: [usize; 3],
padding: [usize; 3],
count_include_pad: bool,
ceil_mode: bool,
) -> AutodiffTensor<B> {
#[derive(Debug)]
struct AvgPool3D;

impl<B: Backend> Backward<B, 1> for AvgPool3D {
type State = (NodeId, [usize; 3], [usize; 3], [usize; 3], bool, bool);

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B>(&ops.node);
let (x_state, kernel_size, stride, padding, count_include_pad, ceil_mode) =
ops.state;
let x = checkpointer.retrieve_node_output(x_state);

if let Some(node) = node_parent {
let grad = B::avg_pool3d_backward(
x,
grad,
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
);
grads.register::<B>(node.id, grad);
}
}
}

match AvgPool3D
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
prep.finish(
(
x_state,
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
),
B::avg_pool3d(
x.primitive.clone(),
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
),
)
}
OpsKind::UnTracked(prep) => prep.finish(B::avg_pool3d(
x.primitive,
kernel_size,
stride,
padding,
count_include_pad,
ceil_mode,
)),
}
}

fn avg_pool3d_backward(
_x: AutodiffTensor<B>,
_grad: AutodiffTensor<B>,
_kernel_size: [usize; 3],
_stride: [usize; 3],
_padding: [usize; 3],
_count_include_pad: bool,
_ceil_mode: bool,
) -> AutodiffTensor<B> {
panic!("Can't differentiate avg pool 3d backward.");
}

fn max_pool3d(
x: AutodiffTensor<B>,
kernel_size: [usize; 3],
stride: [usize; 3],
padding: [usize; 3],
dilation: [usize; 3],
ceil_mode: bool,
) -> AutodiffTensor<B> {
match MaxPool3D
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
let output = B::max_pool3d_with_indices(
x.primitive,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
);
prep.finish(
(
x_state,
output.indices,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
),
output.output,
)
}
OpsKind::UnTracked(prep) => prep.finish(B::max_pool3d(
x.primitive,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
)),
}
}

fn max_pool3d_with_indices(
x: AutodiffTensor<B>,
kernel_size: [usize; 3],
stride: [usize; 3],
padding: [usize; 3],
dilation: [usize; 3],
ceil_mode: bool,
) -> MaxPool3dWithIndices<Self> {
match MaxPool3D
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);

let output = B::max_pool3d_with_indices(
x.primitive,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
);

let output_tensor = prep.finish(
(
x_state,
output.indices.clone(),
kernel_size,
stride,
padding,
dilation,
ceil_mode,
),
output.output,
);

MaxPool3dWithIndices::new(output_tensor, output.indices)
}
OpsKind::UnTracked(prep) => {
let output = B::max_pool3d_with_indices(
x.primitive,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
);
let output_tensor = prep.finish(output.output);

MaxPool3dWithIndices::new(output_tensor, output.indices)
}
}
}

fn max_pool3d_with_indices_backward(
_x: AutodiffTensor<B>,
_kernel_size: [usize; 3],
_stride: [usize; 3],
_padding: [usize; 3],
_dilation: [usize; 3],
_ceil_mode: bool,
_output_grad: AutodiffTensor<B>,
_indices: IntTensor<B>,
) -> MaxPool3dBackward<Self> {
panic!("Can't differentiate max pool3d with indices backward.");
}

fn adaptive_avg_pool3d(x: AutodiffTensor<B>, output_size: [usize; 3]) -> AutodiffTensor<B> {
#[derive(Debug)]
struct AdaptiveAvgPool3D;

impl<B: Backend> Backward<B, 1> for AdaptiveAvgPool3D {
type State = NodeId;

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B>(&ops.node);
let state = checkpointer.retrieve_node_output(ops.state);

if let Some(node) = node_parent {
let grad = B::adaptive_avg_pool3d_backward(state, grad);
grads.register::<B>(node.id, grad);
}
}
}

match AdaptiveAvgPool3D
.prepare::<C>([x.node.clone()])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
prep.finish(x_state, B::adaptive_avg_pool3d(x.primitive, output_size))
}
OpsKind::UnTracked(prep) => {
prep.finish(B::adaptive_avg_pool3d(x.primitive, output_size))
}
}
}

fn adaptive_avg_pool3d_backward(
_x: AutodiffTensor<B>,
_grad: AutodiffTensor<B>,
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
panic!("Can't differentiate adaptive avg pool3d backward.");
}

fn interpolate(
x: AutodiffTensor<B>,
output_size: [usize; 2],
Expand Down Expand Up @@ -1814,3 +2066,45 @@ impl<B: Backend> Backward<B, 1> for MaxPool2D {
}
}
}

#[derive(Debug)]
struct MaxPool3D;

impl<B: Backend> Backward<B, 1> for MaxPool3D {
type State = (
NodeId,
IntTensor<B>,
[usize; 3],
[usize; 3],
[usize; 3],
[usize; 3],
bool,
);

fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_parent] = ops.parents;
let grad = grads.consume::<B>(&ops.node);
let (x_state, indices, kernel_size, stride, padding, dilation, ceil_mode) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);

if let Some(node) = node_parent {
let grad = B::max_pool3d_with_indices_backward(
x,
kernel_size,
stride,
padding,
dilation,
ceil_mode,
grad,
indices,
);

grads.register::<B>(node.id, grad.x_grad);
}
}
}
Loading
Loading