diff --git a/mmcv/parallel/distributed.py b/mmcv/parallel/distributed.py
index e339ab8f57..bf34cb5906 100644
--- a/mmcv/parallel/distributed.py
+++ b/mmcv/parallel/distributed.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List, Tuple
+from typing import Any, List, Tuple
 
 import torch
 from torch.nn.parallel.distributed import (DistributedDataParallel,
@@ -140,3 +140,28 @@ def val_step(self, *inputs, **kwargs):
                 and digit_version(TORCH_VERSION) > digit_version('1.2')):
             self.require_forward_param_sync = False
         return output
+
+    def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
+        """Processes inputs and runs ``self.module.forward``.
+
+        PyTorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward``
+        and deprecates using ``DistributedDataParallel.to_kwargs`` to
+        process inputs, which means inputs can no longer be processed by
+        :meth:`MMDistributedDataParallel.to_kwargs`. Therefore,
+        ``MMDistributedDataParallel`` overrides this method to call
+        :meth:`to_kwargs` explicitly.
+
+        See more information in ``_.  # noqa: E501
+
+        Returns:
+            Any: Forward result of :attr:`module`.
+        """
+        module_to_run = self._replicated_tensor_module if \
+            self._use_replicated_tensor_module else self.module
+
+        if self.device_ids:
+            inputs, kwargs = self.to_kwargs(  # type: ignore
+                inputs, kwargs, self.device_ids[0])
+            return module_to_run(*inputs[0], **kwargs[0])  # type: ignore
+        else:
+            return module_to_run(*inputs, **kwargs)
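
For context, the sketch below is not part of the patch; it is a minimal, hedged example of the code path the override protects: wrapping a module in `MMDistributedDataParallel` and calling it, so that on PyTorch >= 1.12 inputs are still routed through `to_kwargs` (and hence MMCV's scatter logic) before `self.module.forward` runs. It assumes the script is launched with `torchrun` or a similar launcher that sets `LOCAL_RANK` and the usual rendezvous environment variables; the `nn.Linear` model and tensor shapes are placeholders chosen for illustration only.

```python
# Minimal sketch (assumptions: launched via torchrun, NCCL backend available,
# one GPU per process). The toy nn.Linear model stands in for a real model.
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from mmcv.parallel import MMDistributedDataParallel


def main() -> None:
    # The launcher is assumed to have set LOCAL_RANK, MASTER_ADDR, etc.
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    dist.init_process_group(backend='nccl')
    torch.cuda.set_device(local_rank)

    model = nn.Linear(8, 2).cuda(local_rank)
    ddp_model = MMDistributedDataParallel(
        model, device_ids=[local_rank], broadcast_buffers=False)

    x = torch.randn(4, 8).cuda(local_rank)
    # DistributedDataParallel.forward() dispatches to _run_ddp_forward();
    # with the override in this patch, inputs first pass through
    # MMDistributedDataParallel.to_kwargs() before reaching model.forward().
    out = ddp_model(x)
    print(out.shape)  # torch.Size([4, 2]) on each process


if __name__ == '__main__':
    main()
```

Without the override, PyTorch 1.12's stock `_run_ddp_forward` would bypass the subclass's `to_kwargs`, so inputs such as MMCV `DataContainer` objects would no longer be scattered by MMCV's own logic before the wrapped module's forward call.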