torch_xla/csrc/aten_xla_bridge.cpp (13 additions, 7 deletions)

@@ -57,8 +57,10 @@ class AtenXlaDeviceMapper {
       devices_.emplace_back(ParseDeviceString("SPMD:0"));
       devices_ordinals_[devices_.back()] = 0;
     } else {
-      for (auto& device_str :
-           torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
+      XLA_ASSIGN_OR_THROW(
+          runtime::ComputationClient * absl_nonnull const client,
+          runtime::GetComputationClient());
+      for (auto& device_str : client->GetLocalDevices()) {
        devices_.emplace_back(ParseDeviceString(device_str));
        devices_ordinals_[devices_.back()] = devices_.size() - 1;
      }
@@ -398,11 +400,15 @@ std::string ToXlaString(const c10::Device& device) {
 }

 const torch::lazy::BackendDevice* GetDefaultDevice() {
-  static std::string default_device_spec =
-      UseVirtualDevice()
-          ? "SPMD:0"
-          : runtime::GetComputationClientOrDie()->GetDefaultDevice();
-  XLA_CHECK(!default_device_spec.empty());
+  static std::string default_device_spec = []() -> std::string {
+    if (UseVirtualDevice()) {
+      return "SPMD:0";
+    }
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    return client->GetDefaultDevice();
+  }();
+  ABSL_CHECK(!default_device_spec.empty());
   static const torch::lazy::BackendDevice default_device =
       ParseDeviceString(default_device_spec);
   return &default_device;
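Reviewer note: the same mechanical rewrite recurs in every file below. As a minimal before/after sketch (the wrapper functions here are hypothetical; the macro spelling matches the hunks above):

// Before: the deprecated helper threw from inside the runtime layer,
// hiding the error path at every call site.
std::string DefaultDeviceSpecBefore() {
  return runtime::GetComputationClientOrDie()->GetDefaultDevice();
}

// After: surface the StatusOr at the call site via XLA_ASSIGN_OR_THROW,
// bind the client once, and reuse it for later calls in the same scope.
std::string DefaultDeviceSpecAfter() {
  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
                      runtime::GetComputationClient());
  return client->GetDefaultDevice();
}

The GetDefaultDevice() hunk additionally moves the initializer into an immediately invoked lambda, so the status check runs exactly once, when the function-local static is first initialized.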
torch_xla/csrc/cross_replica_reduces.cpp (3 additions, 2 deletions)

@@ -333,8 +333,9 @@ at::Tensor all_to_all_single(const at::Tensor& input,
   bool pin_layout = false;
   const torch::lazy::Value& token =
       GetAllReduceToken(bridge::GetCurrentDevice());
-  int64_t split_count =
-      runtime::GetComputationClientOrDie()->GetAllDevices().size();
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
+  int64_t split_count = client->GetAllDevices().size();
   std::vector<int64_t> all_groups(split_count);
   std::iota(all_groups.begin(), all_groups.end(), 0);

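For context on the unchanged lines: std::iota fills all_groups with consecutive device ordinals starting at 0, so the collective addresses a single replica group spanning every device. An illustrative expansion, assuming a hypothetical run with four devices:

#include <numeric>  // std::iota
#include <vector>

std::vector<int64_t> all_groups(4);  // split_count == 4 in this scenario
std::iota(all_groups.begin(), all_groups.end(), 0);
// all_groups now holds {0, 1, 2, 3}.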
torch_xla/csrc/dl_convertor.cpp (13 additions, 9 deletions)

@@ -125,8 +125,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
   ABSL_CHECK(handle != nullptr)
       << "Could not extract a valid data handle from the input tensor";

-  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
-      runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle);
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
+  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = client->GetPjRtBuffer(handle);
   ABSL_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";

   ABSL_CHECK(!pjrt_buffer->IsTuple())
@@ -169,11 +170,13 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
 // Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
 absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
   switch (context.device_type) {
-    case DLDeviceType::kDLCPU:
-      XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(),
-                   xla::CpuId());
-      return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
-          context.device_id);
+    case DLDeviceType::kDLCPU: {
+      XLA_ASSIGN_OR_RETURN(
+          runtime::ComputationClient * absl_nonnull const client,
+          runtime::GetComputationClient());
+      XLA_CHECK_EQ(client->GetPlatformID(), xla::CpuId());
+      return client->LookupAddressableDevice(context.device_id);
+    }
     default:
       return tsl::errors::InvalidArgument(
           "Unknown/unsupported DLPack device type %d", context.device_type);
@@ -330,10 +333,11 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
       shape, *device->default_memory_space(), on_delete_callback));
   ABSL_CHECK(pjrt_buffer.get() != nullptr) << "pjrt buffer is null.";

+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
   runtime::ComputationClient::DataPtr data =
       runtime::PjRtComputationClient::CreateData(
-          runtime::GetComputationClientOrDie()->PjRtDeviceToString(device),
-          shape, std::move(pjrt_buffer));
+          client->PjRtDeviceToString(device), shape, std::move(pjrt_buffer));

   at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
   XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
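Because DeviceForDLDevice() returns absl::StatusOr<xla::PjRtDevice*>, the switch to XLA_ASSIGN_OR_RETURN means a failure to obtain the client now flows back to the caller as a status instead of aborting the process. A hedged sketch of the same propagation shape (the wrapper is hypothetical):

absl::StatusOr<std::string> DeviceNameForDLDevice(const DLDevice& context) {
  // On success this unwraps the PjRtDevice*; on failure it returns the
  // error status (from GetComputationClient or the device lookup) upward.
  XLA_ASSIGN_OR_RETURN(xla::PjRtDevice* const device,
                       DeviceForDLDevice(context));
  return std::string(device->DebugString());
}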
torch_xla/csrc/init_python_bindings.cpp (125 additions, 70 deletions)

Large diffs are not rendered by default.

torch_xla/csrc/ir_dump_util.cpp (7 additions, 8 deletions)

@@ -274,15 +274,14 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
     xla::Shape shape = MakeShapeWithDeviceLayout(
         program_shape.result(), static_cast<XlaDeviceType>(device.type()));
     std::vector<runtime::ComputationClient::CompileInstance> instances;
-    instances.push_back(
-        {std::move(computation), device.toString(),
-         runtime::GetComputationClientOrDie()->GetCompilationDevices(
-             device.toString(), {}),
-         &shape,
-         /*parameter_is_tupled_arguments=*/false, is_sharded});
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    instances.push_back({std::move(computation), device.toString(),
+                         client->GetCompilationDevices(device.toString(), {}),
+                         &shape,
+                         /*parameter_is_tupled_arguments=*/false, is_sharded});
     std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
-        computations =
-            runtime::GetComputationClientOrDie()->Compile(std::move(instances));
+        computations = client->Compile(std::move(instances));
     computation = std::move(computations[0]->move_computation());
   }

torch_xla/csrc/ops/device_data.cpp (4 additions, 3 deletions)

@@ -16,9 +16,10 @@ DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data)
               /*num_outputs=*/1,
               /*hash_seed=*/(uint32_t)101),
       data_(std::move(data)) {
-  std::optional<xla::OpSharding> op_sharding =
-      torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding(
-          std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_));
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
+  std::optional<xla::OpSharding> op_sharding = client->GetDataSharding(
+      std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_));
   if (op_sharding.has_value()) {
     // DeviceData Node only has 1 output.
     SetSharding(op_sharding.value(), 0);
torch_xla/csrc/runtime/runtime.cpp (7 additions, 7 deletions)

@@ -60,14 +60,14 @@ const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {
   return maybe_client;
 }

-ComputationClient* absl_nonnull GetComputationClientOrDie() {
-  XLA_ASSIGN_OR_THROW(ComputationClient * client, GetComputationClient());
-  return client;
-}
-
 ComputationClient* GetComputationClientIfInitialized() {
-  return g_computation_client_initialized ? GetComputationClientOrDie()
-                                          : nullptr;
+  if (!g_computation_client_initialized) {
+    return nullptr;
+  }
+  const absl::StatusOr<ComputationClient* absl_nonnull>& client =
+      GetComputationClient();
+  XLA_CHECK_OK(client);
+  return client.value();
 }

 }  // namespace torch_xla::runtime
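With GetComputationClientOrDie() removed (see runtime.h below), call sites pick one of two macros depending on whether they can propagate a status. A hedged sketch with hypothetical wrappers:

// In status-returning code, propagate the failure to the caller.
absl::StatusOr<size_t> CountLocalDevices() {
  XLA_ASSIGN_OR_RETURN(ComputationClient * absl_nonnull const client,
                       GetComputationClient());
  return client->GetLocalDevices().size();
}

// At a boundary that cannot return a status, convert it to an exception.
size_t CountLocalDevicesOrThrow() {
  XLA_ASSIGN_OR_THROW(ComputationClient * absl_nonnull const client,
                      GetComputationClient());
  return client->GetLocalDevices().size();
}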
torch_xla/csrc/runtime/runtime.h (0 additions, 7 deletions)

@@ -10,13 +10,6 @@ namespace torch_xla::runtime {
 // Returns the ComputationClient singleton.
 const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient();

-ABSL_DEPRECATED(
-    "Use GetComputationClient(), instead. "
-    "This function throws an exception on error, instead of "
-    "actually handling the StatusOr return value, which is "
-    "safer.")
-ComputationClient* absl_nonnull GetComputationClientOrDie();
-
 // Returns the ComputationClient singleton if it was successfully initialized.
 // Returns a nullptr if the ComputationClient wasn't initialized yet.
 // Throws an exception if the ComputationClient was initialized but the