From 6a03918f55c1cc860bcb4d64a0e868ca13cf08ba Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Thu, 7 Jul 2022 16:05:49 +0800 Subject: [PATCH] [Feature] Add support for mps (#2092) * [Feature] Add support for MPS * fix import error * update ut * fix error * trigger CI * use a unique basename for test file modules * avoid bc-breaking --- .github/workflows/build.yml | 2 +- mmcv/device/__init__.py | 6 +- mmcv/device/_functions.py | 30 +++++++ mmcv/device/mlu/__init__.py | 6 +- mmcv/device/mps/__init__.py | 4 + mmcv/device/mps/data_parallel.py | 34 +++++++ mmcv/device/scatter_gather.py | 64 +++++++++++++ mmcv/device/utils.py | 18 ++++ mmcv/parallel/data_parallel.py | 2 +- mmcv/utils/__init__.py | 5 +- mmcv/utils/device_type.py | 16 ++++ tests/test_device/test_device_utils.py | 15 ++++ tests/test_device/test_functions.py | 90 +++++++++++++++++++ .../test_device/test_mlu/test_mlu_parallel.py | 61 ------------- .../test_device/test_mps/test_mps_parallel.py | 34 +++++++ 15 files changed, 315 insertions(+), 72 deletions(-) create mode 100644 mmcv/device/_functions.py create mode 100644 mmcv/device/mps/__init__.py create mode 100644 mmcv/device/mps/data_parallel.py create mode 100644 mmcv/device/scatter_gather.py create mode 100644 mmcv/device/utils.py create mode 100644 tests/test_device/test_device_utils.py create mode 100644 tests/test_device/test_functions.py create mode 100644 tests/test_device/test_mps/test_mps_parallel.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4a393a2049..e2ec9d8796 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -61,7 +61,7 @@ jobs: --ignore=tests/test_utils/test_parrots_jit.py \ --ignore=tests/test_utils/test_trace.py \ --ignore=tests/test_utils/test_hub.py \ - --ignore=tests/test_device/test_mlu/test_mlu_parallel.py \ + --ignore=tests/test_device \ --ignore=tests/test_utils/test_torch_ops.py build_without_ops: diff --git 
a/mmcv/device/__init__.py b/mmcv/device/__init__.py index 6ac55e63b9..ba217b0771 100644 --- a/mmcv/device/__init__.py +++ b/mmcv/device/__init__.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from . import ipu, mlu +from . import ipu, mlu, mps +from .scatter_gather import scatter, scatter_kwargs +from .utils import get_device -__all__ = ['mlu', 'ipu'] +__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs'] diff --git a/mmcv/device/_functions.py b/mmcv/device/_functions.py new file mode 100644 index 0000000000..462a7e4ddc --- /dev/null +++ b/mmcv/device/_functions.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Union + +import torch + +from mmcv.utils import deprecated_api_warning +from .utils import get_device + + +def scatter(input: Union[List, torch.Tensor], devices: List) -> List: + """scatter copies tensor to devices directly.""" + current_device = get_device() + if isinstance(input, list): + outputs = [scatter(_input, devices) for _input in input] + return outputs + elif isinstance(input, torch.Tensor): + output = input.contiguous() + return output.to(current_device) if devices != [-1] else output + else: + raise Exception(f'Unknown type {type(input)}.') + + +class Scatter: + + @staticmethod + @deprecated_api_warning({'target_mlus': 'target_devices'}, + cls_name='Scatter') + def forward(target_devices, input): + outputs = scatter(input, target_devices) + return tuple(outputs) if isinstance(outputs, list) else (outputs, ) diff --git a/mmcv/device/mlu/__init__.py b/mmcv/device/mlu/__init__.py index 572c4da7ee..77c71ccf3c 100644 --- a/mmcv/device/mlu/__init__.py +++ b/mmcv/device/mlu/__init__.py @@ -1,9 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from .data_parallel import MLUDataParallel from .distributed import MLUDistributedDataParallel -from .scatter_gather import scatter, scatter_kwargs -__all__ = [ - 'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter', - 'scatter_kwargs' -] +__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel'] diff --git a/mmcv/device/mps/__init__.py b/mmcv/device/mps/__init__.py new file mode 100644 index 0000000000..e28144ef0a --- /dev/null +++ b/mmcv/device/mps/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .data_parallel import MPSDataParallel + +__all__ = ['MPSDataParallel'] diff --git a/mmcv/device/mps/data_parallel.py b/mmcv/device/mps/data_parallel.py new file mode 100644 index 0000000000..7ae5396d24 --- /dev/null +++ b/mmcv/device/mps/data_parallel.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch + +from mmcv.parallel import MMDataParallel +from ..scatter_gather import scatter_kwargs + + +class MPSDataParallel(MMDataParallel): + """The MPSDataParallel module that supports DataContainer. + + MPSDataParallel is a class inherited from MMDataParallel, which supports + MPS training and inference only. + + The main differences with MMDataParallel: + + - It only supports single-card of MPS, and only uses the first card to + run training and inference. + + - It uses direct host-to-device copy instead of stream-background + scatter. + + Args: + module (:class:`nn.Module`): Module to be encapsulated. + dim (int): Dimension used to scatter the data. Defaults to 0.
+ """ + + def __init__(self, *args, dim=0, **kwargs): + super().__init__(*args, dim=dim, **kwargs) + self.device_ids = [0] + self.src_device_obj = torch.device('mps:0') + + def scatter(self, inputs, kwargs, device_ids): + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) diff --git a/mmcv/device/scatter_gather.py b/mmcv/device/scatter_gather.py new file mode 100644 index 0000000000..744b0ca51e --- /dev/null +++ b/mmcv/device/scatter_gather.py @@ -0,0 +1,64 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmcv.parallel.data_container import DataContainer +from mmcv.utils import deprecated_api_warning +from ._functions import Scatter +from .utils import get_device + + +@deprecated_api_warning({'target_mlus': 'target_devices'}) +def scatter(inputs, target_devices, dim=0): + """Scatter inputs to target devices. + + The only difference from original :func:`scatter` is to add support for + :type:`~mmcv.parallel.DataContainer`. + """ + current_device = get_device() + + def scatter_map(obj): + if isinstance(obj, torch.Tensor): + if target_devices != [-1]: + obj = obj.to(current_device) + return [obj] + else: + # for CPU inference we use self-implemented scatter + return Scatter.forward(target_devices, obj) + if isinstance(obj, DataContainer): + if obj.cpu_only: + return obj.data + else: + return Scatter.forward(target_devices, obj.data) + if isinstance(obj, tuple) and len(obj) > 0: + return list(zip(*map(scatter_map, obj))) + if isinstance(obj, list) and len(obj) > 0: + out = list(map(list, zip(*map(scatter_map, obj)))) + return out + if isinstance(obj, dict) and len(obj) > 0: + out = list(map(type(obj), zip(*map(scatter_map, obj.items())))) + return out + return [obj for _ in target_devices] + + # After scatter_map is called, a scatter_map cell will exist. 
This cell + # has a reference to the actual function scatter_map, which has references + # to a closure that has a reference to the scatter_map cell (because the + # fn is recursive). To avoid this reference cycle, we set the function to + # None, clearing the cell + try: + return scatter_map(inputs) + finally: + scatter_map = None + + +@deprecated_api_warning({'target_mlus': 'target_devices'}) +def scatter_kwargs(inputs, kwargs, target_devices, dim=0): + """Scatter with support for kwargs dictionary.""" + inputs = scatter(inputs, target_devices, dim) if inputs else [] + kwargs = scatter(kwargs, target_devices, dim) if kwargs else [] + if len(inputs) < len(kwargs): + inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) + elif len(kwargs) < len(inputs): + kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) + inputs = tuple(inputs) + kwargs = tuple(kwargs) + return inputs, kwargs diff --git a/mmcv/device/utils.py b/mmcv/device/utils.py new file mode 100644 index 0000000000..e2adec08dd --- /dev/null +++ b/mmcv/device/utils.py @@ -0,0 +1,18 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE + + +def get_device() -> str: + """Returns the currently existing device type. + + Returns: + str: cuda | mlu | mps | cpu. + """ + if IS_CUDA_AVAILABLE: + return 'cuda' + elif IS_MLU_AVAILABLE: + return 'mlu' + elif IS_MPS_AVAILABLE: + return 'mps' + else: + return 'cpu' diff --git a/mmcv/parallel/data_parallel.py b/mmcv/parallel/data_parallel.py index 22c6b2feff..eea088fa0c 100644 --- a/mmcv/parallel/data_parallel.py +++ b/mmcv/parallel/data_parallel.py @@ -14,7 +14,7 @@ class MMDataParallel(DataParallel): - It supports a custom type :class:`DataContainer` which allows more flexible control of input data during both GPU and CPU inference. - - It implement two more APIs ``train_step()`` and ``val_step()``. + - It implements two more APIs ``train_step()`` and ``val_step()``. .. 
warning:: MMDataParallel only supports single GPU training, if you need to diff --git a/mmcv/utils/__init__.py b/mmcv/utils/__init__.py index 059ae746da..8bb5a8173d 100644 --- a/mmcv/utils/__init__.py +++ b/mmcv/utils/__init__.py @@ -36,7 +36,8 @@ 'is_method_overridden', 'has_method' ] else: - from .device_type import IS_IPU_AVAILABLE, IS_MLU_AVAILABLE + from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE, + IS_MPS_AVAILABLE) from .env import collect_env from .hub import load_url from .logging import get_logger, print_log @@ -76,5 +77,5 @@ 'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch', '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE', 'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE', - 'torch_meshgrid' + 'IS_MPS_AVAILABLE', 'torch_meshgrid' ] diff --git a/mmcv/utils/device_type.py b/mmcv/utils/device_type.py index c66052c2e1..d42ff72e9f 100644 --- a/mmcv/utils/device_type.py +++ b/mmcv/utils/device_type.py @@ -22,3 +22,19 @@ def is_mlu_available() -> bool: IS_MLU_AVAILABLE = is_mlu_available() + + +def is_mps_available() -> bool: + """Return True if mps devices exist. + + It's specialized for Mac M1 chips and requires torch 1.12 or higher. + """ + try: + import torch + return hasattr(torch.backends, + 'mps') and torch.backends.mps.is_available() + except Exception: + return False + + +IS_MPS_AVAILABLE = is_mps_available() diff --git a/tests/test_device/test_device_utils.py b/tests/test_device/test_device_utils.py new file mode 100644 index 0000000000..6597efa5a3 --- /dev/null +++ b/tests/test_device/test_device_utils.py @@ -0,0 +1,15 @@ +# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.device import get_device +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE + + +def test_get_device(): + current_device = get_device() + if IS_CUDA_AVAILABLE: + assert current_device == 'cuda' + elif IS_MLU_AVAILABLE: + assert current_device == 'mlu' + elif IS_MPS_AVAILABLE: + assert current_device == 'mps' + else: + assert current_device == 'cpu' diff --git a/tests/test_device/test_functions.py b/tests/test_device/test_functions.py new file mode 100644 index 0000000000..dbbb8978b5 --- /dev/null +++ b/tests/test_device/test_functions.py @@ -0,0 +1,90 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmcv.device._functions import Scatter, scatter +from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE + + +def test_scatter(): + # if the device is CPU, just return the input + input = torch.zeros([1, 3, 3, 3]) + output = scatter(input=input, devices=[-1]) + assert torch.allclose(input, output) + + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = scatter(input=inputs, devices=[-1]) + for input, output in zip(inputs, outputs): + assert torch.allclose(input, output) + + # if the device is MLU, copy the input from CPU to MLU + if IS_MLU_AVAILABLE: + input = torch.zeros([1, 3, 3, 3]) + output = scatter(input=input, devices=[0]) + assert torch.allclose(input.to('mlu'), output) + + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = scatter(input=inputs, devices=[0]) + for input, output in zip(inputs, outputs): + assert torch.allclose(input.to('mlu'), output) + + # if the device is MPS, copy the input from CPU to MPS + if IS_MPS_AVAILABLE: + input = torch.zeros([1, 3, 3, 3]) + output = scatter(input=input, devices=[0]) + assert torch.allclose(input.to('mps'), output) + + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = scatter(input=inputs, devices=[0]) + for input, output in zip(inputs, outputs): + assert 
torch.allclose(input.to('mps'), output) + + # input should be a tensor or list of tensor + with pytest.raises(Exception): + scatter(5, [-1]) + + +def test_Scatter(): + # if the device is CPU, just return the input + target_devices = [-1] + input = torch.zeros([1, 3, 3, 3]) + outputs = Scatter.forward(target_devices, input) + assert isinstance(outputs, tuple) + assert torch.allclose(input, outputs[0]) + + target_devices = [-1] + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = Scatter.forward(target_devices, inputs) + assert isinstance(outputs, tuple) + for input, output in zip(inputs, outputs): + assert torch.allclose(input, output) + + # if the device is MLU, copy the input from CPU to MLU + if IS_MLU_AVAILABLE: + target_devices = [0] + input = torch.zeros([1, 3, 3, 3]) + outputs = Scatter.forward(target_devices, input) + assert isinstance(outputs, tuple) + assert torch.allclose(input.to('mlu'), outputs[0]) + + target_devices = [0] + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = Scatter.forward(target_devices, inputs) + assert isinstance(outputs, tuple) + for input, output in zip(inputs, outputs): + assert torch.allclose(input.to('mlu'), output[0]) + + # if the device is MPS, copy the input from CPU to MPS + if IS_MPS_AVAILABLE: + target_devices = [0] + input = torch.zeros([1, 3, 3, 3]) + outputs = Scatter.forward(target_devices, input) + assert isinstance(outputs, tuple) + assert torch.allclose(input.to('mps'), outputs[0]) + + target_devices = [0] + inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] + outputs = Scatter.forward(target_devices, inputs) + assert isinstance(outputs, tuple) + for input, output in zip(inputs, outputs): + assert torch.allclose(input.to('mps'), output[0]) diff --git a/tests/test_device/test_mlu/test_mlu_parallel.py b/tests/test_device/test_mlu/test_mlu_parallel.py index cecf148e08..4d04fb6551 100644 --- a/tests/test_device/test_mlu/test_mlu_parallel.py +++ 
b/tests/test_device/test_mlu/test_mlu_parallel.py @@ -1,12 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from unittest.mock import MagicMock, patch -import pytest -import torch import torch.nn as nn from mmcv.device.mlu import MLUDataParallel, MLUDistributedDataParallel -from mmcv.device.mlu._functions import Scatter, scatter from mmcv.parallel import is_module_wrapper from mmcv.utils import IS_MLU_AVAILABLE @@ -38,61 +35,3 @@ def forward(self, x): mluddp = MLUDistributedDataParallel(model, process_group=MagicMock()) assert is_module_wrapper(mluddp) - - -def test_scatter(): - # if the device is CPU, just return the input - input = torch.zeros([1, 3, 3, 3]) - output = scatter(input=input, devices=[-1]) - assert torch.allclose(input, output) - - inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] - outputs = scatter(input=inputs, devices=[-1]) - for input, output in zip(inputs, outputs): - assert torch.allclose(input, output) - - # if the device is MLU, copy the input from CPU to MLU - if IS_MLU_AVAILABLE: - input = torch.zeros([1, 3, 3, 3]) - output = scatter(input=input, devices=[0]) - assert torch.allclose(input.to('mlu'), output) - - inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] - outputs = scatter(input=inputs, devices=[0]) - for input, output in zip(inputs, outputs): - assert torch.allclose(input.to('mlu'), output) - - # input should be a tensor or list of tensor - with pytest.raises(Exception): - scatter(5, [-1]) - - -def test_Scatter(): - # if the device is CPU, just return the input - target_mlus = [-1] - input = torch.zeros([1, 3, 3, 3]) - outputs = Scatter.forward(target_mlus, input) - assert isinstance(outputs, tuple) - assert torch.allclose(input, outputs[0]) - - target_mlus = [-1] - inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] - outputs = Scatter.forward(target_mlus, inputs) - assert isinstance(outputs, tuple) - for input, output in zip(inputs, outputs): - assert torch.allclose(input, output) - 
- # if the device is MLU, copy the input from CPU to MLU - if IS_MLU_AVAILABLE: - target_mlus = [0] - input = torch.zeros([1, 3, 3, 3]) - outputs = Scatter.forward(target_mlus, input) - assert isinstance(outputs, tuple) - assert torch.allclose(input.to('mlu'), outputs[0]) - - target_mlus = [0] - inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])] - outputs = Scatter.forward(target_mlus, inputs) - assert isinstance(outputs, tuple) - for input, output in zip(inputs, outputs): - assert torch.allclose(input.to('mlu'), output[0]) diff --git a/tests/test_device/test_mps/test_mps_parallel.py b/tests/test_device/test_mps/test_mps_parallel.py new file mode 100644 index 0000000000..4b4e0b86e1 --- /dev/null +++ b/tests/test_device/test_mps/test_mps_parallel.py @@ -0,0 +1,34 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +import torch.nn as nn + +from mmcv.device.mps import MPSDataParallel +from mmcv.parallel import is_module_wrapper +from mmcv.utils import IS_MPS_AVAILABLE + + +def mock(*args, **kwargs): + pass + + +@patch('torch.distributed._broadcast_coalesced', mock) +@patch('torch.distributed.broadcast', mock) +@patch('torch.nn.parallel.DistributedDataParallel._ddp_init_helper', mock) +def test_is_module_wrapper(): + + class Model(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(2, 2, 1) + + def forward(self, x): + return self.conv(x) + + model = Model() + assert not is_module_wrapper(model) + + if IS_MPS_AVAILABLE: + mpsdp = MPSDataParallel(model) + assert is_module_wrapper(mpsdp)