[Feature] Add support for mps (#2092)
* [Feature] Add support for MPS

* fix import error

* update ut

* fix error

* trigger CI

* use a unique basename for test file modules

* avoid bc-breaking
zhouzaida authored Jul 7, 2022
1 parent 357b484 commit 6a03918
Showing 15 changed files with 315 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -61,7 +61,7 @@ jobs:
--ignore=tests/test_utils/test_parrots_jit.py \
--ignore=tests/test_utils/test_trace.py \
--ignore=tests/test_utils/test_hub.py \
- --ignore=tests/test_device/test_mlu/test_mlu_parallel.py \
+ --ignore=tests/test_device \
--ignore=tests/test_utils/test_torch_ops.py
build_without_ops:
6 changes: 4 additions & 2 deletions mmcv/device/__init__.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from . import ipu, mlu
+from . import ipu, mlu, mps
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import get_device

-__all__ = ['mlu', 'ipu']
+__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
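With this change, the device-agnostic helpers become importable from mmcv.device directly. A minimal sketch of the new public surface (assuming this commit is installed):

from mmcv.device import get_device, scatter, scatter_kwargs

device = get_device()  # one of 'cuda' | 'mlu' | 'mps' | 'cpu'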
30 changes: 30 additions & 0 deletions mmcv/device/_functions.py
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch

from mmcv.utils import deprecated_api_warning
from .utils import get_device


def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
    """Scatter copies the tensor(s) to the target devices directly."""
    current_device = get_device()
    if isinstance(input, list):
        outputs = [scatter(_input, devices) for _input in input]
        return outputs
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        return output.to(current_device) if devices != [-1] else output
    else:
        raise Exception(f'Unknown type {type(input)}.')


class Scatter:

    @staticmethod
    @deprecated_api_warning({'target_mlus': 'target_devices'},
                            cls_name='Scatter')
    def forward(target_devices, input):
        outputs = scatter(input, target_devices)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
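For orientation, a small CPU-only sketch of how these helpers behave; passing devices=[-1] means "keep the data on CPU", which mirrors the tests added later in this commit:

import torch

from mmcv.device._functions import Scatter, scatter

# A tensor passes through unchanged when the target device list is [-1].
x = torch.zeros(1, 3, 3, 3)
assert torch.allclose(scatter(input=x, devices=[-1]), x)

# Scatter.forward always returns a tuple, one entry per input tensor.
outputs = Scatter.forward([-1], [torch.zeros(2), torch.ones(2)])
assert isinstance(outputs, tuple) and len(outputs) == 2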
6 changes: 1 addition & 5 deletions mmcv/device/mlu/__init__.py
@@ -1,9 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MLUDataParallel
from .distributed import MLUDistributedDataParallel
-from .scatter_gather import scatter, scatter_kwargs

-__all__ = [
-    'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
-    'scatter_kwargs'
-]
+__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
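Since scatter and scatter_kwargs are no longer re-exported from the MLU subpackage, a plausible migration sketch for downstream code (assuming only the import path needs to change):

# Before this commit:
#     from mmcv.device.mlu import scatter, scatter_kwargs
# After this commit, use the device-agnostic versions instead:
from mmcv.device import scatter, scatter_kwargs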
4 changes: 4 additions & 0 deletions mmcv/device/mps/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel

__all__ = ['MPSDataParallel']
34 changes: 34 additions & 0 deletions mmcv/device/mps/data_parallel.py
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs


class MPSDataParallel(MMDataParallel):
    """The MPSDataParallel module that supports DataContainer.

    MPSDataParallel is a class inherited from MMDataParallel, which
    supports MPS training and inference only.

    The main differences with MMDataParallel:

    - It supports only a single MPS device, and uses only the first
      device to run training and inference.
    - It uses direct host-to-device copy instead of stream-background
      scatter.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        self.device_ids = [0]
        self.src_device_obj = torch.device('mps:0')

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
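A hypothetical usage sketch; the convolution module is illustrative, and an MPS-enabled torch build (torch.backends.mps.is_available() returning True) is assumed:

import torch
import torch.nn as nn

from mmcv.device.mps import MPSDataParallel

# Any nn.Module works here; the wrapper pins execution to mps:0.
model = nn.Conv2d(3, 8, kernel_size=3).to('mps')
parallel_model = MPSDataParallel(model, dim=0)

# Inputs are moved host-to-device by the overridden scatter.
out = parallel_model(torch.rand(1, 3, 32, 32))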
64 changes: 64 additions & 0 deletions mmcv/device/scatter_gather.py
@@ -0,0 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
    """Scatter inputs to target devices.

    The only difference from the original :func:`scatter` is that it adds
    support for :class:`~mmcv.parallel.DataContainer`.
    """
    current_device = get_device()

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_devices, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_devices, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for _ in target_devices]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell.
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    """Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_devices, dim) if inputs else []
    kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
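A CPU-only sketch of the padding behavior: whichever of inputs/kwargs scatters to fewer groups is padded with empty tuples or dicts so the pair stays aligned:

import torch

from mmcv.device.scatter_gather import scatter_kwargs

# With no positional inputs, inputs is padded to match the kwargs length.
inputs, kwargs = scatter_kwargs((), {'flag': True}, target_devices=[-1])
assert inputs == ((),)
assert kwargs == ({'flag': True},)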
18 changes: 18 additions & 0 deletions mmcv/device/utils.py
@@ -0,0 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def get_device() -> str:
    """Return the type of the currently available device.

    Returns:
        str: cuda | mlu | mps | cpu.
    """
    if IS_CUDA_AVAILABLE:
        return 'cuda'
    elif IS_MLU_AVAILABLE:
        return 'mlu'
    elif IS_MPS_AVAILABLE:
        return 'mps'
    else:
        return 'cpu'
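A short usage sketch of the fallback chain (CUDA, then MLU, then MPS, then CPU); moving a tensor to 'mlu' additionally requires the torch_mlu extension, hence the guard:

import torch

from mmcv.device import get_device

device = get_device()
x = torch.zeros(2, 3)
if device != 'mlu':  # 'mlu' transfers need the torch_mlu extension
    x = x.to(device)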
2 changes: 1 addition & 1 deletion mmcv/parallel/data_parallel.py
@@ -14,7 +14,7 @@ class MMDataParallel(DataParallel):
    - It supports a custom type :class:`DataContainer` which allows more
      flexible control of input data during both GPU and CPU inference.
-    - It implement two more APIs ``train_step()`` and ``val_step()``.
+    - It implements two more APIs ``train_step()`` and ``val_step()``.

    .. warning::
        MMDataParallel only supports single GPU training, if you need to
5 changes: 3 additions & 2 deletions mmcv/utils/__init__.py
@@ -36,7 +36,8 @@
        'is_method_overridden', 'has_method'
    ]
else:
-    from .device_type import IS_IPU_AVAILABLE, IS_MLU_AVAILABLE
+    from .device_type import (IS_IPU_AVAILABLE, IS_MLU_AVAILABLE,
+                              IS_MPS_AVAILABLE)
    from .env import collect_env
    from .hub import load_url
    from .logging import get_logger, print_log
@@ -76,5 +77,5 @@
    'is_method_overridden', 'is_jit_tracing', 'is_rocm_pytorch',
    '_get_cuda_home', 'load_url', 'has_method', 'IS_CUDA_AVAILABLE',
    'worker_init_fn', 'IS_MLU_AVAILABLE', 'IS_IPU_AVAILABLE',
-    'torch_meshgrid'
+    'IS_MPS_AVAILABLE', 'torch_meshgrid'
]
16 changes: 16 additions & 0 deletions mmcv/utils/device_type.py
@@ -22,3 +22,19 @@ def is_mlu_available() -> bool:


IS_MLU_AVAILABLE = is_mlu_available()
+
+
+def is_mps_available() -> bool:
+    """Return True if mps devices exist.
+
+    It's specialized for Mac M1 chips and requires torch version 1.12 or
+    higher.
+    """
+    try:
+        import torch
+        return hasattr(torch.backends,
+                       'mps') and torch.backends.mps.is_available()
+    except Exception:
+        return False
+
+
+IS_MPS_AVAILABLE = is_mps_available()
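These flags are typically consumed as module-level constants to gate device-specific tests; a sketch following the skipif pattern used with the other availability flags:

import pytest
import torch

from mmcv.utils import IS_MPS_AVAILABLE


@pytest.mark.skipif(not IS_MPS_AVAILABLE, reason='requires MPS support')
def test_mps_transfer():
    # Hypothetical test: move a tensor to the MPS device and read it back.
    x = torch.zeros(2, 3).to('mps')
    assert x.cpu().sum() == 0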
15 changes: 15 additions & 0 deletions tests/test_device/test_device_utils.py
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def test_get_device():
    current_device = get_device()
    if IS_CUDA_AVAILABLE:
        assert current_device == 'cuda'
    elif IS_MLU_AVAILABLE:
        assert current_device == 'mlu'
    elif IS_MPS_AVAILABLE:
        assert current_device == 'mps'
    else:
        assert current_device == 'cpu'
90 changes: 90 additions & 0 deletions tests/test_device/test_functions.py
@@ -0,0 +1,90 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mlu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mps'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)

    # input should be a tensor or a list of tensors
    with pytest.raises(Exception):
        scatter(5, [-1])


def test_Scatter():
    # if the device is CPU, just return the input
    target_devices = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_devices, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_devices = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_devices, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mlu'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mps'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)