From 088fde35419522730248c540bd25942118acde05 Mon Sep 17 00:00:00 2001
From: Shilong Zhang <61961338+jshilong@users.noreply.github.com>
Date: Wed, 16 Jun 2021 13:48:12 +0800
Subject: [PATCH] Avoid bc-breaking of importing
 `MultiScaleDeformableAttention` (#1100)

* avoid bc-breaking
* fix function name
* fix typo
* fix import
* add import warning
* remove wrapper
* remove unit test
* add dep warning
* add dep warning
---
 mmcv/cnn/bricks/transformer.py | 32 +++++++++++++++++++++++++-------
 1 file changed, 25 insertions(+), 7 deletions(-)

diff --git a/mmcv/cnn/bricks/transformer.py b/mmcv/cnn/bricks/transformer.py
index 34ec444c02..06715cde60 100644
--- a/mmcv/cnn/bricks/transformer.py
+++ b/mmcv/cnn/bricks/transformer.py
@@ -12,6 +12,22 @@ from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
                        TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
 
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
+    warnings.warn(
+        ImportWarning(
+            '``MultiScaleDeformableAttention`` has been moved to '
+            '``mmcv.ops.multi_scale_deform_attn``, please change the original path '  # noqa E501
+            '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
+            'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
+        ))
+
+except ImportError:
+    warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
+                  '``mmcv.ops.multi_scale_deform_attn``; '
+                  'you should install ``mmcv-full`` if you need this module.')
+
 
 def build_positional_encoding(cfg, default_args=None):
     """Builder for Position Encoding."""
     return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
@@ -56,9 +72,9 @@ class MultiheadAttention(BaseModule):
             when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
-        batch_first (bool): Key, Query and Value are shape of
-            (batch, n, embed_dim)
-            or (n, batch, embed_dim). Default to False.
+        batch_first (bool): If True, Key, Query and Value are in shape of
+            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+            Defaults to False.
     """
 
     def __init__(self,
@@ -88,10 +104,12 @@ def __init__(self,
         if self.batch_first:
 
             def _bnc_to_nbc(forward):
-                """This function can adjust the shape of dataflow('key',
-                'query', 'value') from batch_first (batch, num_query,
-                embed_dims) to num_query_first (num_query ,batch,
-                embed_dims)."""
+                """Because the dataflow ('key', 'query', 'value') of
+                ``torch.nn.MultiheadAttention`` is (num_query, batch,
+                embed_dims), we should adjust the shape of the dataflow from
+                batch_first (batch, num_query, embed_dims) to num_query_first
+                (num_query, batch, embed_dims), and recover ``attn_output``
+                from num_query_first to batch_first."""
 
                 def forward_wrapper(**kwargs):
                     convert_keys = ('key', 'query', 'value')
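
For reference, a minimal usage sketch (not part of the patch) of the two
import paths after this change; it assumes ``mmcv-full`` is installed so the
deformable-attention op is available:

    # Canonical location of the op after this patch (requires mmcv-full).
    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

    # Legacy path, kept only for backward compatibility; importing
    # mmcv.cnn.bricks.transformer now emits an ImportWarning that points to
    # the canonical location above. The alias is only to avoid shadowing.
    from mmcv.cnn.bricks.transformer import \
        MultiScaleDeformableAttention as LegacyMSDA  # noqa: F401

And a sketch of the ``batch_first`` behaviour of ``MultiheadAttention``
described by the docstring change above; the tensor sizes and hyper-parameters
below are illustrative only, and ``query``/``key``/``value`` are passed by
name because the wrapper forwards keyword arguments:

    import torch
    from mmcv.cnn.bricks.transformer import MultiheadAttention

    # With batch_first=True the module accepts (batch, num_query, embed_dims)
    # tensors, transposes them to (num_query, batch, embed_dims) for the
    # underlying torch.nn.MultiheadAttention, and transposes the output back.
    attn = MultiheadAttention(embed_dims=256, num_heads=8, batch_first=True)
    x = torch.rand(2, 100, 256)
    out = attn(query=x, key=x, value=x)  # out.shape == (2, 100, 256)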