Skip to content

Commit

Permalink
fix mdconv addmm bug for parrots (#450)
Browse files Browse the repository at this point in the history
* fix mdconv addmm bug for parrots

* fix mdconv ctv save tensor

Co-authored-by: hanyachao <[email protected]>
  • Loading branch information
hanyc0914 and hanyachao authored Aug 19, 2020
1 parent 6159dac commit 5e3f56f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
28 changes: 20 additions & 8 deletions mmcv/ops/csrc/parrots/modulated_deform_conv_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,26 @@ void ModulatedDeformConvBackwardCUDAKernelLauncher(
weight.dim(2), weight.dim(3)});
for (size_t g = 0; g < group; g++) {
auto columns_g = columns[g];
gemm(ctx, 1, true,
weight[g].view(
{weight.dim(1), weight.dim(2) * weight.dim(3) * weight.dim(4)}),
false,
grad_output[b][g].view(
{grad_output.dim(2), grad_output.dim(3) * grad_output.dim(4)}),
0, columns_g);
auto columns_g = ctx.createDArrayLite(
weight.elemType(), DArrayShape(columns.dim(1), columns.dim(2)));
copy(ctx, columns_g, columns[g]);
auto weight_g = weight[g].view(
{weight.dim(1), weight.dim(2) * weight.dim(3) * weight.dim(4)});
weight_g = transpose(ctx, weight_g, 0, 1);
auto grad_output_bg = ctx.createDArrayLite(
grad_output.elemType(),
DArrayShape(grad_output.dim(2), grad_output.dim(3),
grad_output.dim(4)));
copy(ctx, grad_output_bg, grad_output[b][g]);
grad_output_bg =
grad_output_bg.view({grad_output_bg.dim(0),
grad_output_bg.dim(1) * grad_output_bg.dim(2)});
columns_g =
parrots::op::addmm(ctx, columns[g], weight_g, grad_output_bg, 0, 1);
auto columns_out = columns[g];
copy(ctx, columns_out, columns_g);
}
columns = columns.view({columns.dim(0) * columns.dim(1), columns.dim(2)});
Expand Down
1 change: 1 addition & 0 deletions mmcv/ops/csrc/parrots_cuda_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <float.h>

#include <parrots/darray/darraymath.hpp>
#include <parrots/darray/mathfunctions.hpp>
#include <parrots/extension.hpp>
#include <parrots/foundation/darrayutil.hpp>
#include <parrots/foundation/exceptions.hpp>
Expand Down
4 changes: 1 addition & 3 deletions mmcv/ops/modulated_deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@ def forward(ctx,
ctx.with_bias = bias is not None
if not ctx.with_bias:
bias = input.new_empty(0) # fake tensor
if weight.requires_grad or mask.requires_grad or offset.requires_grad \
or input.requires_grad:
ctx.save_for_backward(input, offset, mask, weight, bias)
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(
ModulatedDeformConv2dFunction._output_size(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
Expand Down

0 comments on commit 5e3f56f

Please sign in to comment.