From b968ccad105e7eb9080f3a0126c589ab5e876a65 Mon Sep 17 00:00:00 2001
From: Yukio Siraichi
Date: Wed, 3 Sep 2025 12:40:37 -0300
Subject: [PATCH] Replace `GetComputationClientOrDie()` with macro for throwing. (Part 2)

---
 torch_xla/csrc/aten_xla_bridge.cpp       |  20 ++-
 torch_xla/csrc/cross_replica_reduces.cpp |   5 +-
 torch_xla/csrc/dl_convertor.cpp          |  22 +--
 torch_xla/csrc/init_python_bindings.cpp  | 195 +++++++++++++++--------
 torch_xla/csrc/ir_dump_util.cpp          |  15 +-
 torch_xla/csrc/ops/device_data.cpp       |   7 +-
 torch_xla/csrc/runtime/runtime.cpp       |  14 +-
 torch_xla/csrc/runtime/runtime.h         |   7 -
 8 files changed, 172 insertions(+), 113 deletions(-)

diff --git a/torch_xla/csrc/aten_xla_bridge.cpp b/torch_xla/csrc/aten_xla_bridge.cpp
index 8af6f5816756..42d396e9ac2c 100644
--- a/torch_xla/csrc/aten_xla_bridge.cpp
+++ b/torch_xla/csrc/aten_xla_bridge.cpp
@@ -57,8 +57,10 @@ class AtenXlaDeviceMapper {
       devices_.emplace_back(ParseDeviceString("SPMD:0"));
       devices_ordinals_[devices_.back()] = 0;
     } else {
-      for (auto& device_str :
-           torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
+      XLA_ASSIGN_OR_THROW(
+          runtime::ComputationClient * absl_nonnull const client,
+          runtime::GetComputationClient());
+      for (auto& device_str : client->GetLocalDevices()) {
         devices_.emplace_back(ParseDeviceString(device_str));
         devices_ordinals_[devices_.back()] = devices_.size() - 1;
       }
@@ -398,11 +400,15 @@ std::string ToXlaString(const c10::Device& device) {
 }
 
 const torch::lazy::BackendDevice* GetDefaultDevice() {
-  static std::string default_device_spec =
-      UseVirtualDevice()
-          ? "SPMD:0"
-          : runtime::GetComputationClientOrDie()->GetDefaultDevice();
-  XLA_CHECK(!default_device_spec.empty());
+  static std::string default_device_spec = []() -> std::string {
+    if (UseVirtualDevice()) {
+      return "SPMD:0";
+    }
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    return client->GetDefaultDevice();
+  }();
+  ABSL_CHECK(!default_device_spec.empty());
   static const torch::lazy::BackendDevice default_device =
       ParseDeviceString(default_device_spec);
   return &default_device;
diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp
index 77519c03cfc2..6d8abd33ad6f 100644
--- a/torch_xla/csrc/cross_replica_reduces.cpp
+++ b/torch_xla/csrc/cross_replica_reduces.cpp
@@ -333,8 +333,9 @@ at::Tensor all_to_all_single(const at::Tensor& input,
   bool pin_layout = false;
   const torch::lazy::Value& token =
       GetAllReduceToken(bridge::GetCurrentDevice());
-  int64_t split_count =
-      runtime::GetComputationClientOrDie()->GetAllDevices().size();
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
+  int64_t split_count = client->GetAllDevices().size();
   std::vector<int64_t> all_groups(split_count);
   std::iota(all_groups.begin(), all_groups.end(), 0);
diff --git a/torch_xla/csrc/dl_convertor.cpp b/torch_xla/csrc/dl_convertor.cpp
index c6a68a65f609..274e30f017d5 100644
--- a/torch_xla/csrc/dl_convertor.cpp
+++ b/torch_xla/csrc/dl_convertor.cpp
@@ -125,8 +125,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
   ABSL_CHECK(handle != nullptr)
       << "Could not extract a valid data handle from the input tensor";
 
-  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
-      runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle);
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
+  std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = client->GetPjRtBuffer(handle);
   ABSL_CHECK(pjrt_buffer != nullptr)
       << "Could not get a valid pjrt_buffer";
   ABSL_CHECK(!pjrt_buffer->IsTuple())
@@ -169,11 +170,13 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
 // Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
 absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
   switch (context.device_type) {
-    case DLDeviceType::kDLCPU:
-      XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(),
-                   xla::CpuId());
-      return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
-          context.device_id);
+    case DLDeviceType::kDLCPU: {
+      XLA_ASSIGN_OR_RETURN(
+          runtime::ComputationClient * absl_nonnull const client,
+          runtime::GetComputationClient());
+      XLA_CHECK_EQ(client->GetPlatformID(), xla::CpuId());
+      return client->LookupAddressableDevice(context.device_id);
+    }
     default:
       return tsl::errors::InvalidArgument(
           "Unknown/unsupported DLPack device type %d", context.device_type);
@@ -330,10 +333,11 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
           shape, *device->default_memory_space(), on_delete_callback));
   ABSL_CHECK(pjrt_buffer.get() != nullptr) << "pjrt buffer is null.";
 
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
   runtime::ComputationClient::DataPtr data =
       runtime::PjRtComputationClient::CreateData(
-          runtime::GetComputationClientOrDie()->PjRtDeviceToString(device),
-          shape, std::move(pjrt_buffer));
+          client->PjRtDeviceToString(device), shape, std::move(pjrt_buffer));
 
   at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
   XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 1d409850b808..1d205dd86a7e 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -261,7 +261,9 @@ torch::lazy::BackendDevice GetDeviceOrCurrent(const std::string& device_str) {
 
 void WaitDeviceOps(absl::Span<const std::string> devices = {}) {
   XLAGraphExecutor::Get()->WaitDeviceOps(devices);
-  runtime::GetComputationClientOrDie()->WaitDeviceOps(devices);
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                      runtime::GetComputationClient());
+  client->WaitDeviceOps(devices);
 }
 
 void PrepareToExit() {
@@ -721,8 +723,10 @@ void StepMarker(const std::string& device_str,
   XLAGraphExecutor::Get()->MarkStep(device, reset_scope);
   bool debug_mode = runtime::sys_util::GetEnvBool("PT_XLA_DEBUG", false);
   if (TF_PREDICT_FALSE(debug_mode)) {
-    std::string report = runtime::metrics::CreatePerformanceReport(
-        runtime::GetComputationClientOrDie()->GetMetrics());
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    std::string report =
+        runtime::metrics::CreatePerformanceReport(client->GetMetrics());
     if (!report.empty()) {
       std::string fout =
           runtime::sys_util::GetEnvString("PT_XLA_DEBUG_FILE", "");
@@ -972,8 +976,9 @@ py::dict GetMemoryInfo(const std::string& device_str) {
   {
     NoGilSection nogil;
    torch::lazy::BackendDevice device = GetDeviceOrCurrent(device_str);
-    mem_info =
-        runtime::GetComputationClientOrDie()->GetMemoryInfo(device.toString());
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    mem_info = client->GetMemoryInfo(device.toString());
   }
   auto py_dict = py::dict();
   py_dict["bytes_used"] = mem_info.bytes_used;
@@ -1283,10 +1288,10 @@ class PyLoweringContext {
         lowering_ctx.GetParametersData();
 
     // Fetch this parameter data
-    XLA_ASSIGN_OR_THROW(
-        std::vector<xla::Literal> literals,
-        runtime::GetComputationClientOrDie()->TransferFromDevice(
-            UnwrapXlaData(device_data)));
+    XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                        runtime::GetComputationClient());
+    XLA_ASSIGN_OR_THROW(std::vector<xla::Literal> literals,
+                        client->TransferFromDevice(UnwrapXlaData(device_data)));
 
     // Create a mapping from paramater id to the tensor data
    std::unordered_map<int64_t, at::Tensor> results;
@@ -1527,10 +1532,11 @@ void InitXlaModuleBindings(py::module m) {
            xla::Shape global_shape =
                CreateComputationShapeFromTensor(tensor, nullptr);
            if (minibatch) {
-              int num_local_devices =
-                  runtime::GetComputationClientOrDie()->GetLocalDevices().size();
-              int num_global_devices =
-                  runtime::GetComputationClientOrDie()->GetAllDevices().size();
+              XLA_ASSIGN_OR_THROW(
+                  runtime::ComputationClient * absl_nonnull const client,
+                  runtime::GetComputationClient());
+              int num_local_devices = client->GetLocalDevices().size();
+              int num_global_devices = client->GetAllDevices().size();
              XLA_CHECK(tile_assignment.size() == num_global_devices)
                  << "Minibatch sharding only supports sharding along the batch "
                     "dimension";
@@ -1751,37 +1757,45 @@ void InitXlaModuleBindings(py::module m) {
        })
    .def("_xla_get_devices",
         []() {
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
          if (UseVirtualDevice()) {
            // Under SPMD context, there is only one virtual devices from
            // user perspective.
-            std::vector<std::string> all_devices =
-                runtime::GetComputationClientOrDie()->GetAllDevices();
+            std::vector<std::string> all_devices = client->GetAllDevices();
            all_devices.resize(1);
            return all_devices;
          } else {
-            return runtime::GetComputationClientOrDie()->GetLocalDevices();
+            return client->GetLocalDevices();
          }
        })
    .def("_xla_get_platform_version",
         []() {
-          return runtime::GetComputationClientOrDie()->GetPlatformVersion();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          return client->GetPlatformVersion();
        })
    .def("_xla_num_devices",
         []() -> int64_t {
          if (UseVirtualDevice()) {
            return 1;
          } else {
-            return runtime::GetComputationClientOrDie()->GetNumLocalDevices();
+            XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                                runtime::GetComputationClient());
+            return client->GetNumLocalDevices();
          }
        })
    .def("_xla_num_global_devices",
         []() -> int64_t {
-          return runtime::GetComputationClientOrDie()->GetNumDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          return client->GetNumDevices();
        })
    .def("_xla_get_all_devices",
         []() {
-          std::vector<std::string> all_devices =
-              runtime::GetComputationClientOrDie()->GetAllDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          std::vector<std::string> all_devices = client->GetAllDevices();
          if (UseVirtualDevice()) {
            // Under SPMD context, there is only one virtual devices from
            // user perspective.
@@ -1792,22 +1806,31 @@ void InitXlaModuleBindings(py::module m) {
          }
        })
    .def("_xla_get_runtime_devices",
-         []() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
+         []() {
+           XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                               runtime::GetComputationClient());
+           return client->GetLocalDevices();
+         })
    .def("_xla_num_runtime_devices",
         []() -> int64_t {
-          return runtime::GetComputationClientOrDie()->GetNumLocalDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          return client->GetNumLocalDevices();
        })
    .def("_xla_get_all_runtime_devices",
         []() {
-          std::vector<std::string> all_devices =
-              runtime::GetComputationClientOrDie()->GetAllDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          std::vector<std::string> all_devices = client->GetAllDevices();
          return all_devices;
        })
    .def(
        "_xla_real_devices",
        [](const std::optional<std::vector<std::string>> devices) {
          if (!devices) {
-            return runtime::GetComputationClientOrDie()->GetLocalDevices();
+            XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                                runtime::GetComputationClient());
+            return client->GetLocalDevices();
          }
 
          std::vector<std::string> xla_devices;
@@ -1822,27 +1845,33 @@ void InitXlaModuleBindings(py::module m) {
        "_xla_device_kind",
        [](const std::string& device) {
          auto xla_device = bridge::AtenDeviceToXlaDevice(device).toString();
-          return runtime::GetComputationClientOrDie()->GetDeviceKind(xla_device);
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          return client->GetDeviceKind(xla_device);
        },
        py::arg("device") = "")
    .def("_xla_set_replication_devices",
         [](const std::vector<std::string>& devices) {
          auto replication_devices =
              std::make_shared<std::vector<std::string>>(devices);
-          runtime::GetComputationClientOrDie()->SetReplicationDevices(
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          client->SetReplicationDevices(
              std::move(replication_devices));
        })
    .def("_xla_get_replication_devices",
         []() {
-          auto replication_devices =
-              runtime::GetComputationClientOrDie()->GetReplicationDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          auto replication_devices = client->GetReplicationDevices();
          return replication_devices != nullptr ? *replication_devices
                                                : std::vector<std::string>();
        })
    .def("_xla_get_replication_devices_count",
         []() {
-          auto replication_devices =
-              runtime::GetComputationClientOrDie()->GetReplicationDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          auto replication_devices = client->GetReplicationDevices();
          return replication_devices != nullptr ? replication_devices->size()
                                                : 0;
        })
@@ -2191,9 +2220,10 @@ void InitXlaModuleBindings(py::module m) {
        "_xla_create_placeholder_tensor",
        [](py::object py_shape) {
          xla::Shape shape = op_builder::PyShapeToShape(py_shape);
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
          auto xla_tensor =
-              XLATensor::Create(torch_xla::runtime::GetComputationClientOrDie()
-                                    ->CreateDataPlaceholder(
+              XLATensor::Create(client->CreateDataPlaceholder(
                  bridge::GetCurrentDevice().toString(), std::move(shape)));
          return bridge::AtenFromXlaTensor(xla_tensor);
@@ -2212,9 +2242,17 @@ void InitXlaModuleBindings(py::module m) {
          return device.ordinal();
        })
    .def("_xla_get_process_index",
-         []() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
+         []() {
+           XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                               runtime::GetComputationClient());
+           return client->GetProcessIndex();
+         })
    .def("_xla_get_num_processes",
-         []() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
+         []() {
+           XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                               runtime::GetComputationClient());
+           return client->GetNumProcesses();
+         })
    .def("_xla_get_num_cached_compilation_graph",
         []() -> int64_t {
          return XLAGraphExecutor::Get()->GetNumGraphHash();
@@ -2225,10 +2263,12 @@ void InitXlaModuleBindings(py::module m) {
        })
    .def("_xla_get_device_attributes",
         [](const std::string& device_str) {
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
          const absl::flat_hash_map<
              std::string, runtime::ComputationClient::DeviceAttribute>
              attributes =
-                  runtime::GetComputationClientOrDie()->GetDeviceAttributes(
+                  client->GetDeviceAttributes(
                      bridge::AtenDeviceToXlaDevice(device_str).toString());
 
          py::dict dict;
@@ -2239,14 +2279,15 @@ void InitXlaModuleBindings(py::module m) {
        })
    .def("_xla_get_all_device_attributes",
         []() {
-          std::vector<std::string> global_devices =
-              runtime::GetComputationClientOrDie()->GetAllDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          std::vector<std::string> global_devices = client->GetAllDevices();
          std::vector<py::dict> list;
          for (auto const& device : global_devices) {
            const absl::flat_hash_map<
                std::string, runtime::ComputationClient::DeviceAttribute>&
                attributes =
-                    runtime::GetComputationClientOrDie()->GetDeviceAttributes(device);
+                    client->GetDeviceAttributes(device);
            py::dict dict;
            for (auto const& [name, value] : attributes) {
              dict[py::str(name)] = py::cast(value);
@@ -2419,9 +2460,11 @@ void InitXlaModuleBindings(py::module m) {
          // cannot depend on PyTorch (as part of TensorFlow).
          // TODO(jwtan): Unify them once ComputationClient becomes a
          // standalone library.
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
          return torch::lazy::CreateMetricReport() +
                 runtime::metrics_reader::CreateMetricReport(
-                     runtime::GetComputationClientOrDie()->GetMetrics());
+                     client->GetMetrics());
        })
    .def("_short_xla_metrics_report",
         [](const py::list& counter_names, const py::list& metric_names) {
@@ -2689,8 +2732,9 @@ void InitXlaModuleBindings(py::module m) {
            std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
          XLA_CHECK(UseVirtualDevice())
              << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
-          auto local_devices =
-              runtime::GetComputationClientOrDie()->GetLocalDevices();
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          auto local_devices = client->GetLocalDevices();
          XLA_CHECK(local_devices.size() == shards.size())
              << "Must specify a shard for each local device";
          XLA_CHECK(!global_shape.has_value() ||
@@ -2764,6 +2808,8 @@ void InitXlaModuleBindings(py::module m) {
          std::vector<runtime::ComputationClient::DataPtr> handles;
          std::vector<at::ScalarType> element_types;
          // Find all shard handles for transfer
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
          for (auto& tensor : input) {
            XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor,
                                bridge::GetXlaTensor(tensor));
@@ -2775,7 +2821,7 @@ void InitXlaModuleBindings(py::module m) {
                std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
                    xtensor->GetXlaData());
            std::vector<runtime::ComputationClient::DataPtr> shard_handles =
-                runtime::GetComputationClientOrDie()->GetDataShards(handle);
+                client->GetDataShards(handle);
            handles.insert(handles.end(), shard_handles.begin(),
                           shard_handles.end());
            element_types.insert(
@@ -2788,8 +2834,7 @@ void InitXlaModuleBindings(py::module m) {
              XlaDataToTensors(WrapXlaData(handles), element_types));
          // Populate the resulting vector of shards and device strings
          std::vector<std::vector<std::pair<at::Tensor, std::string>>> result;
-          int shards_per_tensor =
-              runtime::GetComputationClientOrDie()->GetLocalDevices().size();
+          int shards_per_tensor = client->GetLocalDevices().size();
          result.reserve(cpu_shards.size() / shards_per_tensor);
          for (int i = 0; i < cpu_shards.size(); i += shards_per_tensor) {
            std::vector<std::pair<at::Tensor, std::string>> shard_devices;
@@ -2818,6 +2863,8 @@ void InitXlaModuleBindings(py::module m) {
        [](const std::vector<at::Tensor>& input_tensors)
            -> std::vector>> {
          std::vector>> result;
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
          for (auto& tensor : input_tensors) {
            XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor,
                                bridge::GetXlaTensor(tensor));
@@ -2827,7 +2874,7 @@ void InitXlaModuleBindings(py::module m) {
                std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
                    xtensor->GetXlaData());
            auto shards =
-                runtime::GetComputationClientOrDie()->GetDataShards(handle);
+                client->GetDataShards(handle);
            std::vector<std::string> shard_devices;
            for (auto& shard : shards) {
              shard_devices.push_back(shard->device());
@@ -2881,8 +2928,9 @@ void InitXlaModuleBindings(py::module m) {
                                bridge::GetXlaTensor(tensor));
          XLA_CHECK(xtensor->sharding_spec() != nullptr)
              << "Cannot load local shards into a non sharded tensor";
-          XLA_CHECK(devices.size() ==
-                    runtime::GetComputationClientOrDie()->GetLocalDevices().size())
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          XLA_CHECK(devices.size() == client->GetLocalDevices().size())
              << "Shards must be provided for all local devices";
          auto sharding = xtensor->sharding_spec()->sharding;
          auto sharding_spec = xtensor->sharding_spec();
@@ -2907,10 +2955,10 @@ void InitXlaModuleBindings(py::module m) {
        "_ensure_xla_coordinator_initialized",
        [](int global_rank, int world_size, std::string master_addr,
           std::string master_port) {
-          auto comp_client = runtime::GetComputationClientOrDie();
-          if (!comp_client->CoordinatorInitialized()) {
-            runtime::GetComputationClientOrDie()->InitializeCoordinator(
-                global_rank, world_size, master_addr, master_port);
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          if (!client->CoordinatorInitialized()) {
+            client->InitializeCoordinator(global_rank, world_size, master_addr, master_port);
          }
        },
        py::arg("global_rank"),  //
@@ -2924,10 +2972,11 @@ void InitXlaModuleBindings(py::module m) {
        // effect.
        "_activate_preemption_sync_manager",
        []() {
-          auto comp_client = runtime::GetComputationClientOrDie();
-          XLA_CHECK(comp_client->CoordinatorInitialized())
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          XLA_CHECK(client->CoordinatorInitialized())
              << "Coordinator must be initialized";
-          auto& coordinator = comp_client->GetCoordinator();
+          auto& coordinator = client->GetCoordinator();
          coordinator.ActivatePreemptionSyncManager();
        })
    .def(
@@ -2935,10 +2984,11 @@ void InitXlaModuleBindings(py::module m) {
        // is active
        "_deactivate_preemption_sync_manager",
        []() {
-          auto comp_client = runtime::GetComputationClientOrDie();
-          XLA_CHECK(comp_client->CoordinatorInitialized())
+          XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
+                              runtime::GetComputationClient());
+          XLA_CHECK(client->CoordinatorInitialized())
              << "Coordinator must be initialized";
-          auto& coordinator = comp_client->GetCoordinator();
+          auto& coordinator = client->GetCoordinator();
          coordinator.DeactivatePreemptionSyncManager();
        })
    .def(
@@ -2947,10 +2997,11 @@ void InitXlaModuleBindings(py::module m) {
        // PreemptionSyncManager activated.
"_sync_point_reached", [](int step) { - auto comp_client = runtime::GetComputationClientOrDie(); - XLA_CHECK(comp_client->CoordinatorInitialized()) + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + XLA_CHECK(client->CoordinatorInitialized()) << "Coordinator must be initialized"; - auto& coordinator = comp_client->GetCoordinator(); + auto& coordinator = client->GetCoordinator(); return coordinator.ReachedSyncPoint(step); }) .def("_is_placecholder", @@ -3058,8 +3109,9 @@ void InitXlaModuleBindings(py::module m) { .def("_xla_register_custom_call_target", [](const std::string& fn_name, const py::capsule& function_ptr, const std::string& platform) { - runtime::GetComputationClientOrDie()->RegisterCustomCall( - fn_name, function_ptr.get_pointer(), platform); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + client->RegisterCustomCall(fn_name, function_ptr.get_pointer(), platform); }) .def("_set_xla_custom_op_name_prefix", [](const at::Tensor& input, const std::string& op_name_prefix, @@ -3218,26 +3270,29 @@ void InitXlaModuleBindings(py::module m) { } XLA_ERROR() << "Could not get buffer for tensor"; } - runtime::GetComputationClientOrDie()->OnReadyCallback(data, - callback); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull client, + runtime::GetComputationClient()); + client->OnReadyCallback(data, callback); }) .def("_unsafe_buffer_pointer", [](const at::Tensor& input) -> std::uintptr_t { XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input)); + XLA_ASSIGN_OR_THROW( + runtime::ComputationClient * absl_nonnull client, + runtime::GetComputationClient()); if (xtensor->CurrentDataHandle() != nullptr) { std::shared_ptr data = std::dynamic_pointer_cast( xtensor->CurrentDataHandle()); - return runtime::GetComputationClientOrDie()->UnsafeBufferPointer( - data); + return client->UnsafeBufferPointer(data); } else if (xtensor->CurrentIrValue().node != nullptr) { DeviceData* device_data = DeviceData::Cast(xtensor->CurrentIrValue().node.get()); if (device_data != nullptr) { torch::lazy::BackendDataPtr data = device_data->data(); - return runtime::GetComputationClientOrDie() - ->UnsafeBufferPointer(UnwrapXlaData(data)); + return client->UnsafeBufferPointer(UnwrapXlaData(data)); } else { XLA_ERROR() << "Could not get the buffer pointer for XLATensor " diff --git a/torch_xla/csrc/ir_dump_util.cpp b/torch_xla/csrc/ir_dump_util.cpp index 3453bd642c38..8a8fcd73dd44 100644 --- a/torch_xla/csrc/ir_dump_util.cpp +++ b/torch_xla/csrc/ir_dump_util.cpp @@ -274,15 +274,14 @@ std::string DumpUtil::ToHlo(c10::ArrayRef values, xla::Shape shape = MakeShapeWithDeviceLayout( program_shape.result(), static_cast(device.type())); std::vector instances; - instances.push_back( - {std::move(computation), device.toString(), - runtime::GetComputationClientOrDie()->GetCompilationDevices( - device.toString(), {}), - &shape, - /*parameter_is_tupled_arguments=*/false, is_sharded}); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + instances.push_back({std::move(computation), device.toString(), + client->GetCompilationDevices(device.toString(), {}), + &shape, + /*parameter_is_tupled_arguments=*/false, is_sharded}); std::vector> - computations = - runtime::GetComputationClientOrDie()->Compile(std::move(instances)); + computations = client->Compile(std::move(instances)); computation = 
std::move(computations[0]->move_computation()); } diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index a5f5536b5b67..bc04820a8033 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -16,9 +16,10 @@ DeviceData::DeviceData(std::shared_ptr data) /*num_outputs=*/1, /*hash_seed=*/(uint32_t)101), data_(std::move(data)) { - std::optional op_sharding = - torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding( - std::dynamic_pointer_cast(data_)); + XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client, + runtime::GetComputationClient()); + std::optional op_sharding = client->GetDataSharding( + std::dynamic_pointer_cast(data_)); if (op_sharding.has_value()) { // DeviceData Node only has 1 output. SetSharding(op_sharding.value(), 0); diff --git a/torch_xla/csrc/runtime/runtime.cpp b/torch_xla/csrc/runtime/runtime.cpp index 3836f6975719..4dced2531bd5 100644 --- a/torch_xla/csrc/runtime/runtime.cpp +++ b/torch_xla/csrc/runtime/runtime.cpp @@ -60,14 +60,14 @@ const absl::StatusOr& GetComputationClient() { return maybe_client; } -ComputationClient* absl_nonnull GetComputationClientOrDie() { - XLA_ASSIGN_OR_THROW(ComputationClient * client, GetComputationClient()); - return client; -} - ComputationClient* GetComputationClientIfInitialized() { - return g_computation_client_initialized ? GetComputationClientOrDie() - : nullptr; + if (!g_computation_client_initialized) { + return nullptr; + } + const absl::StatusOr& client = + GetComputationClient(); + XLA_CHECK_OK(client); + return client.value(); } } // namespace torch_xla::runtime diff --git a/torch_xla/csrc/runtime/runtime.h b/torch_xla/csrc/runtime/runtime.h index 6a1588935e6f..9d9c3f59f29e 100644 --- a/torch_xla/csrc/runtime/runtime.h +++ b/torch_xla/csrc/runtime/runtime.h @@ -10,13 +10,6 @@ namespace torch_xla::runtime { // Returns the ComputationClient singleton. const absl::StatusOr& GetComputationClient(); -ABSL_DEPRECATED( - "Use GetComputationClient(), instead. " - "This function throws an exception on error, instead of " - "actually handling the StatusOr return value, which is " - "safer.") -ComputationClient* absl_nonnull GetComputationClientOrDie(); - // Returns the ComputationClient singleton if it was successfully initialized. // Returns a nullptr if the ComputationClient wasn't initialized yet. // Throws an exception if the ComputationClient was initialized but the