Skip to content

Commit bdba823

Browse files
ys950902 and delock authored
[XPU] Support XCCL on deepspeed side (#7299)
XCCL will be used for XPU devices starting with PyTorch 2.8. This change adds XCCL support on the DeepSpeed side so that torch-ccl is no longer required on XPU devices, while keeping the old torch-ccl path available when it is installed. --------- Signed-off-by: yisheng <[email protected]> Co-authored-by: Ma, Guokai <[email protected]>
1 parent 0e3209a commit bdba823

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

accelerator/real_accelerator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,21 @@ def get_accelerator():
136136
accelerator_name = "xpu"
137137
except ImportError as e:
138138
pass
139+
if accelerator_name is None:
140+
try:
141+
import torch
142+
143+
# torch.xpu will be supported in upstream pytorch-2.8.
144+
# Currently we can run on xpu device only using pytorch,
145+
# also reserve the old path using ipex when the torch version is old.
146+
if hasattr(torch, 'xpu'):
147+
if torch.cuda.device_count() == 0: #ignore-cuda
148+
if torch.xpu.device_count() > 0 and torch.xpu.is_available():
149+
accelerator_name = "xpu"
150+
else:
151+
pass
152+
except ImportError as e:
153+
pass
139154
if accelerator_name is None:
140155
try:
141156
import torch_npu # noqa: F401,F811 # type: ignore

accelerator/xpu_accelerator.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,32 @@
55

66
import torch
77
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
8-
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
9-
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
108
import functools
11-
129
import importlib
1310
import inspect
1411

12+
try:
13+
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
14+
oneccl_imported_p = True
15+
except ImportError as e:
16+
oneccl_imported_p = False
17+
18+
try:
19+
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
20+
ipex_imported_p = True
21+
except ImportError as e:
22+
ipex_imported_p = False
23+
1524

1625
class XPU_Accelerator(DeepSpeedAccelerator):
1726

1827
def __init__(self):
1928
self._name = 'xpu'
20-
self._communication_backend_name = 'ccl'
29+
if oneccl_imported_p:
30+
self._communication_backend_name = 'ccl'
31+
else:
32+
# changed to xccl if not using torch-CCL on XPU device
33+
self._communication_backend_name = 'xccl'
2134
self._compile_backend = "inductor"
2235
self.aligned_tensors = []
2336
self.class_dict = None
@@ -26,11 +39,14 @@ def is_synchronized_device(self):
2639
return False
2740

2841
def use_host_timers(self):
29-
# WA XPU event will be consolidated in 2.6
30-
if ipex.__version__ < '2.6':
31-
return True
32-
else:
42+
if not ipex_imported_p:
3343
return self.is_synchronized_device()
44+
else:
45+
# WA XPU event will be consolidated in 2.6
46+
if ipex.__version__ < '2.6':
47+
return True
48+
else:
49+
return self.is_synchronized_device()
3450

3551
def resolves_data_dependency(self):
3652
return self.is_synchronized_device()
@@ -290,10 +306,13 @@ def get_op_builder(self, class_name):
290306
return self.class_dict['NotImplementedBuilder']
291307

292308
def build_extension(self):
293-
try:
294-
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
295-
except ImportError:
296-
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
309+
if ipex_imported_p:
310+
try:
311+
from intel_extension_for_pytorch.xpu.cpp_extension import DpcppBuildExtension
312+
except ImportError:
313+
from intel_extension_for_pytorch.xpu.utils import DpcppBuildExtension
314+
else:
315+
from torch.utils.cpp_extension import DpcppBuildExtension
297316
return DpcppBuildExtension
298317

299318
def export_envs(self):

0 commit comments

Comments
 (0)