Skip to content

Commit

Permalink
support ONNX adaptive average pooling (#504)
Browse files Browse the repository at this point in the history
* support ONNX adaptive average pooling

* fix double quotes

Co-authored-by: Kai Chen <[email protected]>
  • Loading branch information
drcut and hellock authored Aug 19, 2020
1 parent 5e3f56f commit 11d8dd5
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
47 changes: 47 additions & 0 deletions mmcv/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,50 @@ def softmax(g, input, dim, dtype=None):
return softmax


def _adaptive_pool(name, type, tuple_fn, fn=None):

@parse_args('v', 'is')
def symbolic_fn(g, input, output_size):
if output_size == [1] * len(output_size) and type == 'AveragePool':
return g.op('GlobalAveragePool', input)
if not input.isCompleteTensor():
if output_size == [1] * len(output_size):
return g.op('GlobalMaxPool', input), None
raise NotImplementedError(
'[Adaptive pool]:input size not accessible')
dim = input.type().sizes()[2:]
if output_size == [1] * len(output_size) and type == 'MaxPool':
return g.op('GlobalMaxPool', input), None

# compute stride = floor(input_size / output_size)
s = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]

# compute kernel_size = input_size - (output_size - 1) * stride
k = [dim[i] - (output_size[i] - 1) * s[i] for i in range(0, len(dim))]

# call max_poolxd_with_indices to get indices in the output
if type == 'MaxPool':
return fn(g, input, k, k, (0, ) * len(dim), (1, ) * len(dim),
False)
output = g.op(
type,
input,
kernel_shape_i=tuple_fn(k),
strides_i=tuple_fn(s),
ceil_mode_i=False)
return output

return symbolic_fn


adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', 'AveragePool',
_single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', 'AveragePool',
_pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', 'AveragePool',
_triple)


def register_extra_symbolics(opset=11):
register_op('one_hot', one_hot, '', opset)
register_op('im2col', im2col, '', opset)
Expand All @@ -317,6 +361,9 @@ def register_extra_symbolics(opset=11):
register_op('avg_pool1d', avg_pool1d, '', opset)
register_op('avg_pool2d', avg_pool2d, '', opset)
register_op('avg_pool3d', avg_pool3d, '', opset)
register_op('adaptive_avg_pool1d', adaptive_avg_pool1d, '', opset)
register_op('adaptive_avg_pool2d', adaptive_avg_pool2d, '', opset)
register_op('adaptive_avg_pool3d', adaptive_avg_pool3d, '', opset)
register_op('masked_select', masked_select, '', opset)
register_op('upsample_nearest1d', upsample_nearest1d, '', opset)
register_op('upsample_nearest2d', upsample_nearest2d, '', opset)
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/parrots/tin_shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ PARROTS_EXTENSION_REGISTER(tin_shift_backward)
.input(2)
.output(1)
.apply(tin_shift_backward_cuda)
.done();
.done();

0 comments on commit 11d8dd5

Please sign in to comment.