diff --git a/torch_xla/csrc/runtime/disc_computation_client.cc b/torch_xla/csrc/runtime/disc_computation_client.cc index dbf5ca065c4..6465551dbde 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.cc +++ b/torch_xla/csrc/runtime/disc_computation_client.cc @@ -17,6 +17,7 @@ #include "mlir/Pass/Pass.h" // from @llvm-project #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/disc/disc_compile.h" +#include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/sys_util.h" #include "xla/hlo/ir/hlo_module.h" @@ -97,6 +98,10 @@ DISCComputationClient::DISCComputationClient() { world_size_ = sys_util::GetEnvInt("WORLD_SIZE", 1); local_rank_ = sys_util::GetEnvInt("LOCAL_RANK", 0); global_rank_ = sys_util::GetEnvInt("RANK", local_rank_); + device_type_ = sys_util::GetEnvString(env::kEnvDISCDevice, "CUDA"); + if (device_type_ != "CUDA") { + XLA_ERROR() << "Only CUDA device is supported by DISC backend"; + } } DISCComputationClient::~DISCComputationClient() {} @@ -362,7 +367,7 @@ std::map DISCComputationClient::GetMetrics() const { } std::string DISCComputationClient::GetDefaultDevice() const { - return absl::StrCat(DefaultDevicePrefix, std::to_string(local_rank_)); + return absl::StrCat(device_type_, ":", std::to_string(local_rank_)); } std::vector DISCComputationClient::GetLocalDevices() const { @@ -390,8 +395,7 @@ std::vector DISCComputationClient::GetAllDevices() const { std::vector all_devices; int device_count = world_size_; for (int idx = 0; idx < device_count; idx++) { - all_devices.push_back( - absl::StrCat(DefaultDevicePrefix, std::to_string(idx))); + all_devices.push_back(absl::StrCat(device_type_, ":", std::to_string(idx))); } return all_devices; } diff --git a/torch_xla/csrc/runtime/disc_computation_client.h b/torch_xla/csrc/runtime/disc_computation_client.h index 8df3a6c0e73..0701d7b3591 100644 --- a/torch_xla/csrc/runtime/disc_computation_client.h +++ b/torch_xla/csrc/runtime/disc_computation_client.h @@ -11,8 +11,6 @@ namespace runtime { class DISCComputationClient : public ComputationClient { public: - const std::string DefaultDevicePrefix = "GPU:"; - DISCComputationClient(); ~DISCComputationClient(); @@ -142,6 +140,7 @@ class DISCComputationClient : public ComputationClient { int world_size_; int local_rank_; int global_rank_; + std::string device_type_; struct DISCData : public Data { DISCData(std::string device, xla::Shape device_shape) : Data(std::move(device), std::move(device_shape)) {}