[Feature] Add support for mps (#2092)
* [Feature] Add support for MPS
* fix import error
* update ut
* fix error
* trigger CI
* use a unique basename for test file modules
* avoid bc-breaking
Showing 15 changed files with 315 additions and 72 deletions.
`mmcv/device/__init__.py` (file paths here are inferred from the imports, since the diff view dropped the file headers):

```diff
@@ -1,4 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from . import ipu, mlu
+from . import ipu, mlu, mps
+from .scatter_gather import scatter, scatter_kwargs
+from .utils import get_device
 
-__all__ = ['mlu', 'ipu']
+__all__ = ['mlu', 'ipu', 'mps', 'get_device', 'scatter', 'scatter_kwargs']
```
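With this, the helpers are importable straight from `mmcv.device`. A quick sanity check (illustrative, not part of the commit):

```python
import mmcv.device

# The package root now re-exports the device-agnostic helpers.
assert {'get_device', 'scatter', 'scatter_kwargs'} <= set(mmcv.device.__all__)
```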
New file `mmcv/device/_functions.py`, a device-agnostic version of the scatter primitives (the path matches the test import further down):

```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch

from mmcv.utils import deprecated_api_warning
from .utils import get_device


def scatter(input: Union[List, torch.Tensor], devices: List) -> List:
    """scatter copies tensor to devices directly."""
    current_device = get_device()
    if isinstance(input, list):
        outputs = [scatter(_input, devices) for _input in input]
        return outputs
    elif isinstance(input, torch.Tensor):
        output = input.contiguous()
        return output.to(current_device) if devices != [-1] else output
    else:
        raise Exception(f'Unknown type {type(input)}.')


class Scatter:

    @staticmethod
    @deprecated_api_warning({'target_mlus': 'target_devices'},
                            cls_name='Scatter')
    def forward(target_devices, input):
        outputs = scatter(input, target_devices)
        return tuple(outputs) if isinstance(outputs, list) else (outputs, )
```
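The `devices == [-1]` convention means "stay on the host". A minimal sketch of the passthrough behavior, mirroring the unit tests further down:

```python
import torch
from mmcv.device._functions import Scatter, scatter

x = torch.zeros(1, 3, 3, 3)

# With devices=[-1] the tensor is only made contiguous; no
# host-to-device copy takes place.
out = scatter(input=x, devices=[-1])
assert torch.allclose(x, out)

# Scatter.forward normalizes the result to a tuple, even for one tensor.
outs = Scatter.forward([-1], x)
assert isinstance(outs, tuple) and torch.allclose(x, outs[0])
```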
`mmcv/device/mlu/__init__.py`: the MLU-specific `scatter`/`scatter_kwargs` re-exports are dropped in favor of the shared `mmcv.device.scatter_gather` module:

```diff
@@ -1,9 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .data_parallel import MLUDataParallel
 from .distributed import MLUDistributedDataParallel
-from .scatter_gather import scatter, scatter_kwargs
 
-__all__ = [
-    'MLUDataParallel', 'MLUDistributedDataParallel', 'scatter',
-    'scatter_kwargs'
-]
+__all__ = ['MLUDataParallel', 'MLUDistributedDataParallel']
```
New file `mmcv/device/mps/__init__.py`:

```python
# Copyright (c) OpenMMLab. All rights reserved.
from .data_parallel import MPSDataParallel

__all__ = ['MPSDataParallel']
```
New file `mmcv/device/mps/data_parallel.py`:

```python
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from mmcv.parallel import MMDataParallel
from ..scatter_gather import scatter_kwargs


class MPSDataParallel(MMDataParallel):
    """The MPSDataParallel module that supports DataContainer.

    MPSDataParallel is a class inherited from MMDataParallel, which
    supports MPS training and inference only.

    The main differences with MMDataParallel:

    - It only supports single-card MPS, and only uses the first card to
      run training and inference.
    - It uses direct host-to-device copy instead of stream-background
      scatter.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, *args, dim=0, **kwargs):
        super().__init__(*args, dim=dim, **kwargs)
        self.device_ids = [0]
        self.src_device_obj = torch.device('mps:0')

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
```
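A hypothetical usage sketch; the toy model and the availability guard are illustrative and assume a PyTorch build with MPS support:

```python
import torch
import torch.nn as nn
from mmcv.device.mps import MPSDataParallel

if torch.backends.mps.is_available():
    # Single-device wrapper: device_ids is pinned to [0], and inputs
    # are moved with the direct-copy scatter shown above.
    model = MPSDataParallel(nn.Conv2d(3, 8, 3).to('mps'))
    out = model(torch.rand(1, 3, 8, 8))
```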
New file `mmcv/device/scatter_gather.py`, the `DataContainer`-aware scatter shared by all device backends:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import torch

from mmcv.parallel.data_container import DataContainer
from mmcv.utils import deprecated_api_warning
from ._functions import Scatter
from .utils import get_device


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter(inputs, target_devices, dim=0):
    """Scatter inputs to target devices.

    The only difference from original :func:`scatter` is to add support for
    :type:`~mmcv.parallel.DataContainer`.
    """
    current_device = get_device()

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            if target_devices != [-1]:
                obj = obj.to(current_device)
                return [obj]
            else:
                # for CPU inference we use self-implemented scatter
                return Scatter.forward(target_devices, obj)
        if isinstance(obj, DataContainer):
            if obj.cpu_only:
                return obj.data
            else:
                return Scatter.forward(target_devices, obj.data)
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            out = list(map(list, zip(*map(scatter_map, obj))))
            return out
        if isinstance(obj, dict) and len(obj) > 0:
            out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
            return out
        return [obj for _ in target_devices]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None


@deprecated_api_warning({'target_mlus': 'target_devices'})
def scatter_kwargs(inputs, kwargs, target_devices, dim=0):
    """Scatter with support for kwargs dictionary."""
    inputs = scatter(inputs, target_devices, dim) if inputs else []
    kwargs = scatter(kwargs, target_devices, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
```
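A minimal sketch of the `DataContainer` handling (the field names are made up for illustration):

```python
import torch
from mmcv.parallel import DataContainer
from mmcv.device.scatter_gather import scatter_kwargs

img = torch.zeros(1, 3, 4, 4)
metas = DataContainer([{'filename': 'demo.jpg'}], cpu_only=True)

# Tensors follow get_device(); cpu_only containers keep their payload
# on the host and are unwrapped to their underlying data.
inputs, kwargs = scatter_kwargs((img,), {'img_metas': metas}, [-1])
```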
New file `mmcv/device/utils.py`:

```python
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def get_device() -> str:
    """Returns the currently existing device type.

    Returns:
        str: cuda | mlu | mps | cpu.
    """
    if IS_CUDA_AVAILABLE:
        return 'cuda'
    elif IS_MLU_AVAILABLE:
        return 'mlu'
    elif IS_MPS_AVAILABLE:
        return 'mps'
    else:
        return 'cpu'
```
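The checks are ordered, so CUDA wins over MLU, which wins over MPS. Typical use (illustrative):

```python
import torch
from mmcv.device import get_device

device = get_device()  # falls back to 'cpu' when no accelerator exists
x = torch.ones(2, 2).to(device)
```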
New test module for `get_device`:

```python
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.device import get_device
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def test_get_device():
    current_device = get_device()
    if IS_CUDA_AVAILABLE:
        assert current_device == 'cuda'
    elif IS_MLU_AVAILABLE:
        assert current_device == 'mlu'
    elif IS_MPS_AVAILABLE:
        assert current_device == 'mps'
    else:
        assert current_device == 'cpu'
```
New test module for the scatter helpers:

```python
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.device._functions import Scatter, scatter
from mmcv.utils import IS_MLU_AVAILABLE, IS_MPS_AVAILABLE


def test_scatter():
    # if the device is CPU, just return the input
    input = torch.zeros([1, 3, 3, 3])
    output = scatter(input=input, devices=[-1])
    assert torch.allclose(input, output)

    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = scatter(input=inputs, devices=[-1])
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mlu'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output)

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        input = torch.zeros([1, 3, 3, 3])
        output = scatter(input=input, devices=[0])
        assert torch.allclose(input.to('mps'), output)

        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = scatter(input=inputs, devices=[0])
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output)

    # input should be a tensor or list of tensor
    with pytest.raises(Exception):
        scatter(5, [-1])


def test_Scatter():
    # if the device is CPU, just return the input
    target_devices = [-1]
    input = torch.zeros([1, 3, 3, 3])
    outputs = Scatter.forward(target_devices, input)
    assert isinstance(outputs, tuple)
    assert torch.allclose(input, outputs[0])

    target_devices = [-1]
    inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
    outputs = Scatter.forward(target_devices, inputs)
    assert isinstance(outputs, tuple)
    for input, output in zip(inputs, outputs):
        assert torch.allclose(input, output)

    # if the device is MLU, copy the input from CPU to MLU
    if IS_MLU_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mlu'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mlu'), output[0])

    # if the device is MPS, copy the input from CPU to MPS
    if IS_MPS_AVAILABLE:
        target_devices = [0]
        input = torch.zeros([1, 3, 3, 3])
        outputs = Scatter.forward(target_devices, input)
        assert isinstance(outputs, tuple)
        assert torch.allclose(input.to('mps'), outputs[0])

        target_devices = [0]
        inputs = [torch.zeros([1, 3, 3, 3]), torch.zeros([1, 4, 4, 4])]
        outputs = Scatter.forward(target_devices, inputs)
        assert isinstance(outputs, tuple)
        for input, output in zip(inputs, outputs):
            assert torch.allclose(input.to('mps'), output[0])
```
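Assuming a standard pytest setup, the new tests can be selected by name (the invocation itself is illustrative):

```python
import pytest

# Run only the tests added in this commit.
pytest.main(['-k', 'test_get_device or test_scatter or test_Scatter'])
```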