Skip to content

Commit b968cca

Browse files
committed
Replace GetComputationClientOrDie() with macro for throwing. (Part 2)
1 parent 8274f94 commit b968cca

File tree

8 files changed

+172
-113
lines changed

8 files changed

+172
-113
lines changed

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ class AtenXlaDeviceMapper {
5757
devices_.emplace_back(ParseDeviceString("SPMD:0"));
5858
devices_ordinals_[devices_.back()] = 0;
5959
} else {
60-
for (auto& device_str :
61-
torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
60+
XLA_ASSIGN_OR_THROW(
61+
runtime::ComputationClient * absl_nonnull const client,
62+
runtime::GetComputationClient());
63+
for (auto& device_str : client->GetLocalDevices()) {
6264
devices_.emplace_back(ParseDeviceString(device_str));
6365
devices_ordinals_[devices_.back()] = devices_.size() - 1;
6466
}
@@ -398,11 +400,15 @@ std::string ToXlaString(const c10::Device& device) {
398400
}
399401

400402
const torch::lazy::BackendDevice* GetDefaultDevice() {
401-
static std::string default_device_spec =
402-
UseVirtualDevice()
403-
? "SPMD:0"
404-
: runtime::GetComputationClientOrDie()->GetDefaultDevice();
405-
XLA_CHECK(!default_device_spec.empty());
403+
static std::string default_device_spec = []() -> std::string {
404+
if (UseVirtualDevice()) {
405+
return "SPMD:0";
406+
}
407+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
408+
runtime::GetComputationClient());
409+
return client->GetDefaultDevice();
410+
}();
411+
ABSL_CHECK(!default_device_spec.empty());
406412
static const torch::lazy::BackendDevice default_device =
407413
ParseDeviceString(default_device_spec);
408414
return &default_device;

torch_xla/csrc/cross_replica_reduces.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,9 @@ at::Tensor all_to_all_single(const at::Tensor& input,
333333
bool pin_layout = false;
334334
const torch::lazy::Value& token =
335335
GetAllReduceToken(bridge::GetCurrentDevice());
336-
int64_t split_count =
337-
runtime::GetComputationClientOrDie()->GetAllDevices().size();
336+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
337+
runtime::GetComputationClient());
338+
int64_t split_count = client->GetAllDevices().size();
338339
std::vector<int64_t> all_groups(split_count);
339340
std::iota(all_groups.begin(), all_groups.end(), 0);
340341

torch_xla/csrc/dl_convertor.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
125125
ABSL_CHECK(handle != nullptr)
126126
<< "Could not extract a valid data handle from the input tensor";
127127

128-
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
129-
runtime::GetComputationClientOrDie()->GetPjRtBuffer(handle);
128+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
129+
runtime::GetComputationClient());
130+
std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = client->GetPjRtBuffer(handle);
130131
ABSL_CHECK(pjrt_buffer != nullptr) << "Could not get a valid pjrt_buffer";
131132

132133
ABSL_CHECK(!pjrt_buffer->IsTuple())
@@ -169,11 +170,13 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
169170
// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
170171
absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
171172
switch (context.device_type) {
172-
case DLDeviceType::kDLCPU:
173-
XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(),
174-
xla::CpuId());
175-
return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
176-
context.device_id);
173+
case DLDeviceType::kDLCPU: {
174+
XLA_ASSIGN_OR_RETURN(
175+
runtime::ComputationClient * absl_nonnull const client,
176+
runtime::GetComputationClient());
177+
XLA_CHECK_EQ(client->GetPlatformID(), xla::CpuId());
178+
return client->LookupAddressableDevice(context.device_id);
179+
}
177180
default:
178181
return tsl::errors::InvalidArgument(
179182
"Unknown/unsupported DLPack device type %d", context.device_type);
@@ -330,10 +333,11 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
330333
shape, *device->default_memory_space(), on_delete_callback));
331334
ABSL_CHECK(pjrt_buffer.get() != nullptr) << "pjrt buffer is null.";
332335

336+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
337+
runtime::GetComputationClient());
333338
runtime::ComputationClient::DataPtr data =
334339
runtime::PjRtComputationClient::CreateData(
335-
runtime::GetComputationClientOrDie()->PjRtDeviceToString(device),
336-
shape, std::move(pjrt_buffer));
340+
client->PjRtDeviceToString(device), shape, std::move(pjrt_buffer));
337341

338342
at::ScalarType tensor_type = at::toScalarType(dlmt->dl_tensor.dtype);
339343
XLATensorPtr xla_tensor = XLATensor::Create(data, tensor_type);

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 125 additions & 70 deletions
Large diffs are not rendered by default.

torch_xla/csrc/ir_dump_util.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -274,15 +274,14 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
274274
xla::Shape shape = MakeShapeWithDeviceLayout(
275275
program_shape.result(), static_cast<XlaDeviceType>(device.type()));
276276
std::vector<runtime::ComputationClient::CompileInstance> instances;
277-
instances.push_back(
278-
{std::move(computation), device.toString(),
279-
runtime::GetComputationClientOrDie()->GetCompilationDevices(
280-
device.toString(), {}),
281-
&shape,
282-
/*parameter_is_tupled_arguments=*/false, is_sharded});
277+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
278+
runtime::GetComputationClient());
279+
instances.push_back({std::move(computation), device.toString(),
280+
client->GetCompilationDevices(device.toString(), {}),
281+
&shape,
282+
/*parameter_is_tupled_arguments=*/false, is_sharded});
283283
std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
284-
computations =
285-
runtime::GetComputationClientOrDie()->Compile(std::move(instances));
284+
computations = client->Compile(std::move(instances));
286285
computation = std::move(computations[0]->move_computation());
287286
}
288287

torch_xla/csrc/ops/device_data.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data)
1616
/*num_outputs=*/1,
1717
/*hash_seed=*/(uint32_t)101),
1818
data_(std::move(data)) {
19-
std::optional<xla::OpSharding> op_sharding =
20-
torch_xla::runtime::GetComputationClientOrDie()->GetDataSharding(
21-
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_));
19+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
20+
runtime::GetComputationClient());
21+
std::optional<xla::OpSharding> op_sharding = client->GetDataSharding(
22+
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_));
2223
if (op_sharding.has_value()) {
2324
// DeviceData Node only has 1 output.
2425
SetSharding(op_sharding.value(), 0);

torch_xla/csrc/runtime/runtime.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {
6060
return maybe_client;
6161
}
6262

63-
ComputationClient* absl_nonnull GetComputationClientOrDie() {
64-
XLA_ASSIGN_OR_THROW(ComputationClient * client, GetComputationClient());
65-
return client;
66-
}
67-
6863
ComputationClient* GetComputationClientIfInitialized() {
69-
return g_computation_client_initialized ? GetComputationClientOrDie()
70-
: nullptr;
64+
if (!g_computation_client_initialized) {
65+
return nullptr;
66+
}
67+
const absl::StatusOr<ComputationClient* absl_nonnull>& client =
68+
GetComputationClient();
69+
XLA_CHECK_OK(client);
70+
return client.value();
7171
}
7272

7373
} // namespace torch_xla::runtime

torch_xla/csrc/runtime/runtime.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,6 @@ namespace torch_xla::runtime {
1010
// Returns the ComputationClient singleton.
1111
const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient();
1212

13-
ABSL_DEPRECATED(
14-
"Use GetComputationClient(), instead. "
15-
"This function throws an exception on error, instead of "
16-
"actually handling the StatusOr return value, which is "
17-
"safer.")
18-
ComputationClient* absl_nonnull GetComputationClientOrDie();
19-
2013
// Returns the ComputationClient singleton if it was successfully initialized.
2114
// Returns a nullptr if the ComputationClient wasn't initialized yet.
2215
// Throws an exception if the ComputationClient was initialized but the

0 commit comments

Comments (0)