Commit

[Fix] Make MMDistributedDataParallel compatible with Pytorch1.12 (#2107)

* make compatible with pytorch 1.12

* override _run_ddp_forward

* overwrite _run_ddp_forward

* refactor docstring
HAOCHENYE authored Jul 8, 2022
1 parent 6a03918 commit da2df84
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion mmcv/parallel/distributed.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
from typing import Any, List, Tuple

import torch
from torch.nn.parallel.distributed import (DistributedDataParallel,
@@ -140,3 +140,28 @@ def val_step(self, *inputs, **kwargs):
                and digit_version(TORCH_VERSION) > digit_version('1.2')):
            self.require_forward_param_sync = False
        return output

    def _run_ddp_forward(self, *inputs, **kwargs) -> Any:
        """Processes inputs and runs ``self.module.forward``.

        PyTorch 1.12.0 performs ``self.module.forward`` in
        ``_run_ddp_forward`` and deprecates using
        ``DistributedDataParallel.to_kwargs`` to process the inputs, so the
        inputs can no longer be processed by
        :meth:`MMDistributedDataParallel.to_kwargs`. Therefore,
        ``MMDistributedDataParallel`` overrides this method to call
        :meth:`to_kwargs` explicitly.

        See more information in `<https://github.com/open-mmlab/mmsegmentation/issues/1742>`_.  # noqa: E501

        Returns:
            Any: Forward result of :attr:`module`.
        """
        module_to_run = self._replicated_tensor_module if \
            self._use_replicated_tensor_module else self.module

        if self.device_ids:
            inputs, kwargs = self.to_kwargs(  # type: ignore
                inputs, kwargs, self.device_ids[0])
            return module_to_run(*inputs[0], **kwargs[0])  # type: ignore
        else:
            return module_to_run(*inputs, **kwargs)
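
Below is a minimal usage sketch (not part of the commit) showing how the override is reached: with PyTorch >= 1.12, calling the wrapped model goes through ``DistributedDataParallel.forward``, which delegates to the ``_run_ddp_forward`` defined above, so the inputs are still processed by :meth:`to_kwargs`. The NCCL backend, the ``nn.Linear`` stand-in model, and the tensor shapes are illustrative assumptions.

import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.parallel import MMDistributedDataParallel

# Assumes a single-node job launched with torch.distributed (env:// init).
dist.init_process_group(backend='nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = nn.Linear(8, 4).cuda()  # stand-in for a real model
ddp_model = MMDistributedDataParallel(
    model, device_ids=[rank], broadcast_buffers=False)

# On PyTorch >= 1.12 this call routes through the overridden
# _run_ddp_forward, which processes the inputs via to_kwargs before
# running the wrapped module.
out = ddp_model(torch.randn(2, 8).cuda())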
