Skip to content

Commit 57d05f2

Browse files
guangyeypytorchmergebot
authored andcommitted
[RELAND] Add xpu to getAccelerator (#129205)
# Motivation Add `xpu` support to `getAccelerator`. Pull Request resolved: pytorch/pytorch#129205 Approved by: https://github.com/albanD, https://github.com/gujinghui ghstack dependencies: #129463
1 parent 551f3b9 commit 57d05f2

File tree

3 files changed

+34
-35
lines changed

3 files changed

+34
-35
lines changed

aten/src/ATen/DeviceAccelerator.cpp

+30-32
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,37 @@
1-
#include <ATen/DeviceAccelerator.h>
21
#include <ATen/Context.h>
3-
2+
#include <ATen/DeviceAccelerator.h>
43
namespace at {
54

65
C10_API std::optional<DeviceType> getAccelerator(bool checked) {
7-
#define CHECK_NO_CUDA \
8-
TORCH_CHECK(!at::hasCUDA(), "Cannot have both CUDA and PrivateUse1");
9-
10-
#define CHECK_NO_PU1 \
11-
TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");
12-
13-
#define CHECK_NO_MTIA \
14-
TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices");
15-
16-
if (is_privateuse1_backend_registered()) {
17-
// We explicitly allow PrivateUse1 and another device at the same time
18-
// as we use this for testing.
19-
// Whenever a PrivateUse1 device is registered, use it first.
20-
return kPrivateUse1;
21-
} else if (at::hasCUDA()) {
22-
CHECK_NO_PU1
23-
CHECK_NO_MTIA
24-
return kCUDA;
25-
} else if (at::hasMTIA()) {
26-
CHECK_NO_CUDA
27-
CHECK_NO_PU1
28-
return kMTIA;
29-
} else {
30-
TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
31-
return std::nullopt;
32-
}
33-
34-
#undef CHECK_NO_CUDA
35-
#undef CHECK_NO_PU1
6+
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
7+
if (at::has##device_name()) { \
8+
device_type = k##device_name; \
9+
TORCH_CHECK( \
10+
!is_accelerator_detected, \
11+
"Cannot have ", \
12+
device_type.value(), \
13+
" with other accelerators."); \
14+
is_accelerator_detected = true; \
15+
}
16+
17+
if (is_privateuse1_backend_registered()) {
18+
// We explicitly allow PrivateUse1 and another device at the same time as we
19+
// use this for testing. Whenever a PrivateUse1 device is registered, use it
20+
// first.
21+
return kPrivateUse1;
22+
}
23+
std::optional<DeviceType> device_type = std::nullopt;
24+
bool is_accelerator_detected = false;
25+
DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
26+
DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
27+
DETECT_AND_ASSIGN_ACCELERATOR(XPU)
28+
if (checked) {
29+
TORCH_CHECK(
30+
device_type, "Cannot access accelerator device when none is available.")
31+
}
32+
return device_type;
33+
34+
#undef DETECT_AND_ASSIGN_ACCELERATOR
3635
}
3736

38-
3937
} // namespace at

aten/src/ATen/DeviceAccelerator.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
// - It provides a set of common APIs as defined by AcceleratorHooksInterface
1414
//
1515
// As of today, accelerator devices are (in no particular order):
16-
// CUDA, MTIA, PrivateUse1
16+
// CUDA, MTIA, XPU, PrivateUse1
1717
// We want to add once all the proper APIs are supported and tested:
18-
// HIP, MPS, XPU
18+
// HIP, MPS
1919

2020
namespace at {
2121

test/test_cpp_extensions_stream_and_event.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
skipIfTorchDynamo,
1616
TEST_CUDA,
1717
TEST_PRIVATEUSE1,
18+
TEST_XPU,
1819
)
1920
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
2021

@@ -36,7 +37,7 @@ def remove_build_path():
3637
# Since we use a fake MTIA device backend to test generic Stream/Event, device backends are mutual exclusive to each other.
3738
# The test will be skipped if any of the following conditions are met:
3839
@unittest.skipIf(
39-
IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM,
40+
IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_XPU or TEST_PRIVATEUSE1 or TEST_ROCM,
4041
"Only on linux platform and mutual exclusive to other backends",
4142
)
4243
@torch.testing._internal.common_utils.markDynamoStrictTest

0 commit comments

Comments
 (0)