Skip to content

Commit

Permalink
fix deform conv (#3212)
Browse files Browse the repository at this point in the history
  • Loading branch information
hust17yixuan authored Dec 9, 2024
1 parent 174b647 commit 5e2b9a7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
3 changes: 2 additions & 1 deletion mmcv/ops/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def symbolic(g,
def _npu_backward(ctx, grad_output):
input_tensor, weight, offset_out, offset_all, sort_index_for_npu_bp = \
ctx.saved_tensors
import torch_npu
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch.npu_deformable_conv2dbk(
torch_npu.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
Expand Down
6 changes: 4 additions & 2 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
kernel_w, kernel_h, ctx.deform_groups)
select_offset = offset.index_select(1, sort_index_fp)
offset_all = torch.cat([select_offset, mask], dim=1)
output, offset_out = torch.npu_deformable_conv2d(
import torch_npu
output, offset_out = torch_npu.npu_deformable_conv2d(
input_tensor,
weight,
offset_all,
Expand All @@ -80,8 +81,9 @@ def _npu_forward(ctx, input_tensor, offset, mask, weight, bias):
def _npu_backward(ctx, grad_output):
input_tensor, weight, offset_out, offset_all, sort_index_bp = \
ctx.saved_tensors
import torch_npu
grad_input, grad_weight, grad_offset_all, grad_bias = \
torch.npu_deformable_conv2dbk(
torch_npu.npu_deformable_conv2dbk(
input_tensor, grad_output, offset_out, weight, offset_all,
kernel_size=[weight.shape[3], weight.shape[2]],
stride=[1, 1, ctx.stride[0], ctx.stride[1]],
Expand Down

0 comments on commit 5e2b9a7

Please sign in to comment.