diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 65e22d82ede..2cd56c9c3ed 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -1647,8 +1647,8 @@ def forward(self, x): offset = npx.convolution(x, self.offset_weight.data(ctx), self.offset_bias.data(ctx), cudnn_off=True, **self._kwargs_offset) - offset_t = npx.slice_axis(offset, axis=1, begin=0, end=self.offset_split_index) - mask = npx.slice_axis(offset, axis=1, begin=self.offset_split_index, end=None) + offset_t = offset[:,0:self.offset_split_index,:, :] + mask = offset[:,self.offset_split_index:,:, :] mask = npx.sigmoid(mask) * 2 if self.deformable_conv_bias is None: