Avoid bc-breaking of importing MultiScaleDeformableAttention (#1100)
* avoid bc-breaking

* fix function name

* fix typo

* fix import

* add import warning

* remove wrapper

* remove unit test

* add dep warning

* add dep warning
jshilong authored Jun 16, 2021
1 parent a5d4c65 commit 088fde3
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions mmcv/cnn/bricks/transformer.py
@@ -12,6 +12,22 @@
 from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
                        TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
 
+# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
+try:
+    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa: F401,E501
+    warnings.warn(
+        ImportWarning(
+            '``MultiScaleDeformableAttention`` has been moved to '
+            '``mmcv.ops.multi_scale_deform_attn``; please change the import '  # noqa: E501
+            '``from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa: E501
+            'to ``from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention``'))  # noqa: E501
+
+except ImportError:
+    warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
+                  '``mmcv.ops.multi_scale_deform_attn``. '
+                  'Install ``mmcv-full`` if you need this module.')
 
 
 def build_positional_encoding(cfg, default_args=None):
     """Builder for Position Encoding."""
@@ -56,9 +72,9 @@ class MultiheadAttention(BaseModule):
             when adding the shortcut.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
-        batch_first (bool): Key, Query and Value are shape of
-            (batch, n, embed_dim)
-            or (n, batch, embed_dim). Default to False.
+        batch_first (bool): If True, key, query and value tensors have shape
+            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
+            Defaults to False.
     """
 
     def __init__(self,
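A short sketch of the two layouts the clarified docstring describes; a hedged example assuming mmcv's ``MultiheadAttention(embed_dims, num_heads, ...)`` constructor and keyword ``query``/``key``/``value`` forward arguments:

import torch
from mmcv.cnn.bricks.transformer import MultiheadAttention

batch, n, embed_dims = 2, 5, 16

# batch_first=True: inputs and output are (batch, n, embed_dim)
attn = MultiheadAttention(embed_dims, num_heads=4, batch_first=True)
x = torch.rand(batch, n, embed_dims)
out = attn(query=x, key=x, value=x)
assert out.shape == (batch, n, embed_dims)

# batch_first=False (the default): inputs and output are (n, batch, embed_dim)
attn = MultiheadAttention(embed_dims, num_heads=4, batch_first=False)
x = torch.rand(n, batch, embed_dims)
out = attn(query=x, key=x, value=x)
assert out.shape == (n, batch, embed_dims)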
@@ -88,10 +104,12 @@ def __init__(self,
         if self.batch_first:
 
             def _bnc_to_nbc(forward):
-                """This function can adjust the shape of dataflow('key',
-                'query', 'value') from batch_first (batch, num_query,
-                embed_dims) to num_query_first (num_query ,batch,
-                embed_dims)."""
+                """Because the dataflow ('key', 'query', 'value') of
+                ``torch.nn.MultiheadAttention`` is (num_query, batch,
+                embed_dims), adjust the shape of the dataflow from
+                batch_first (batch, num_query, embed_dims) to num_query_first
+                (num_query, batch, embed_dims), and recover ``attn_output``
+                from num_query_first back to batch_first."""
 
                 def forward_wrapper(**kwargs):
                     convert_keys = ('key', 'query', 'value')
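The diff is truncated below this point, but the pattern the wrapper sets up can be sketched standalone; a hedged reconstruction under the names above, not the exact committed body: transpose the batch-first inputs to sequence-first before calling the wrapped forward, then transpose ``attn_output`` back:

import torch


def _bnc_to_nbc(forward):
    """Let callers pass (batch, num_query, embed_dims) tensors to an
    attention whose forward expects (num_query, batch, embed_dims)."""

    def forward_wrapper(**kwargs):
        convert_keys = ('key', 'query', 'value')
        # (batch, num_query, embed_dims) -> (num_query, batch, embed_dims)
        for k in convert_keys:
            if kwargs.get(k) is not None:
                kwargs[k] = kwargs[k].transpose(0, 1)
        attn_output, attn_weights = forward(**kwargs)
        # recover attn_output from num_query_first to batch_first
        return attn_output.transpose(0, 1), attn_weights

    return forward_wrapper


# usage sketch with torch.nn.MultiheadAttention (sequence-first by default)
mha = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4)
wrapped = _bnc_to_nbc(mha.forward)
x = torch.rand(2, 5, 16)  # (batch, num_query, embed_dims)
out, _ = wrapped(query=x, key=x, value=x)
assert out.shape == (2, 5, 16)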
