diff --git a/mmcv/onnx/symbolic.py b/mmcv/onnx/symbolic.py index 5dee10c6e0..4a301b7274 100644 --- a/mmcv/onnx/symbolic.py +++ b/mmcv/onnx/symbolic.py @@ -305,6 +305,50 @@ def softmax(g, input, dim, dtype=None): return softmax +def _adaptive_pool(name, type, tuple_fn, fn=None): + + @parse_args('v', 'is') + def symbolic_fn(g, input, output_size): + if output_size == [1] * len(output_size) and type == 'AveragePool': + return g.op('GlobalAveragePool', input) + if not input.isCompleteTensor(): + if output_size == [1] * len(output_size): + return g.op('GlobalMaxPool', input), None + raise NotImplementedError( + '[Adaptive pool]:input size not accessible') + dim = input.type().sizes()[2:] + if output_size == [1] * len(output_size) and type == 'MaxPool': + return g.op('GlobalMaxPool', input), None + + # compute stride = floor(input_size / output_size) + s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))] + + # compute kernel_size = input_size - (output_size - 1) * stride + k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))] + + # call max_poolxd_with_indices to get indices in the output + if type == 'MaxPool': + return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim), + False) + output = g.op( + type, + input, + kernel_shape_i=tuple_fn(k), + strides_i=tuple_fn(s), + ceil_mode_i=False) + return output + + return symbolic_fn + + +adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool', + _single) +adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool', + _pair) +adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool', + _triple) + + def register_extra_symbolics(opset=11): register_op('one_hot', one_hot, '', opset) register_op('im2col', im2col, '', opset) @@ -317,6 +361,9 @@ def register_extra_symbolics(opset=11): register_op('avg_pool1d', avg_pool1d, '', opset) register_op('avg_pool2d', avg_pool2d, '', opset) register_op('avg_pool3d', avg_pool3d, '', opset) + register_op('adaptive_avg_pool1d', adaptive_avg_pool1d, '', opset) + register_op('adaptive_avg_pool2d', adaptive_avg_pool2d, '', opset) + register_op('adaptive_avg_pool3d', adaptive_avg_pool3d, '', opset) register_op('masked_select', masked_select, '', opset) register_op('upsample_nearest1d', upsample_nearest1d, '', opset) register_op('upsample_nearest2d', upsample_nearest2d, '', opset) diff --git a/mmcv/ops/csrc/parrots/tin_shift.cpp b/mmcv/ops/csrc/parrots/tin_shift.cpp index a31444bfdd..17b48af41c 100644 --- a/mmcv/ops/csrc/parrots/tin_shift.cpp +++ b/mmcv/ops/csrc/parrots/tin_shift.cpp @@ -39,4 +39,4 @@ PARROTS_EXTENSION_REGISTER(tin_shift_backward) .input(2) .output(1) .apply(tin_shift_backward_cuda) - .done(); \ No newline at end of file + .done();