|
1 |
| -#include <ATen/DeviceAccelerator.h> |
2 | 1 | #include <ATen/Context.h>
|
3 |
| - |
| 2 | +#include <ATen/DeviceAccelerator.h> |
4 | 3 | namespace at {
|
5 | 4 |
|
6 | 5 | C10_API std::optional<DeviceType> getAccelerator(bool checked) {
|
7 |
| -#define CHECK_NO_CUDA \ |
8 |
| - TORCH_CHECK(!at::hasCUDA(), "Cannot have both CUDA and PrivateUse1"); |
9 |
| - |
10 |
| -#define CHECK_NO_PU1 \ |
11 |
| - TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1"); |
12 |
| - |
13 |
| -#define CHECK_NO_MTIA \ |
14 |
| - TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices"); |
15 |
| - |
16 |
| - if (is_privateuse1_backend_registered()) { |
17 |
| - // We explicitly allow PrivateUse1 and another device at the same time |
18 |
| - // as we use this for testing. |
19 |
| - // Whenever a PrivateUse1 device is registered, use it first. |
20 |
| - return kPrivateUse1; |
21 |
| - } else if (at::hasCUDA()) { |
22 |
| - CHECK_NO_PU1 |
23 |
| - CHECK_NO_MTIA |
24 |
| - return kCUDA; |
25 |
| - } else if (at::hasMTIA()) { |
26 |
| - CHECK_NO_CUDA |
27 |
| - CHECK_NO_PU1 |
28 |
| - return kMTIA; |
29 |
| - } else { |
30 |
| - TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.") |
31 |
| - return std::nullopt; |
32 |
| - } |
33 |
| - |
34 |
| -#undef CHECK_NO_CUDA |
35 |
| -#undef CHECK_NO_PU1 |
| 6 | +#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \ |
| 7 | + if (at::has##device_name()) { \ |
| 8 | + device_type = k##device_name; \ |
| 9 | + TORCH_CHECK( \ |
| 10 | + !is_accelerator_detected, \ |
| 11 | + "Cannot have ", \ |
| 12 | + device_type.value(), \ |
| 13 | + " with other accelerators."); \ |
| 14 | + is_accelerator_detected = true; \ |
| 15 | + } |
| 16 | + |
| 17 | + if (is_privateuse1_backend_registered()) { |
| 18 | + // We explicitly allow PrivateUse1 and another device at the same time as we |
| 19 | + // use this for testing. Whenever a PrivateUse1 device is registered, use it |
| 20 | + // first. |
| 21 | + return kPrivateUse1; |
| 22 | + } |
| 23 | + std::optional<DeviceType> device_type = std::nullopt; |
| 24 | + bool is_accelerator_detected = false; |
| 25 | + DETECT_AND_ASSIGN_ACCELERATOR(CUDA) |
| 26 | + DETECT_AND_ASSIGN_ACCELERATOR(MTIA) |
| 27 | + DETECT_AND_ASSIGN_ACCELERATOR(XPU) |
| 28 | + if (checked) { |
| 29 | + TORCH_CHECK( |
| 30 | + device_type, "Cannot access accelerator device when none is available.") |
| 31 | + } |
| 32 | + return device_type; |
| 33 | + |
| 34 | +#undef DETECT_AND_ASSIGN_ACCELERATOR |
36 | 35 | }
|
37 | 36 |
|
38 |
| - |
39 | 37 | } // namespace at
|
0 commit comments