@@ -125,8 +125,9 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
125
125
ABSL_CHECK (handle != nullptr )
126
126
<< " Could not extract a valid data handle from the input tensor" ;
127
127
128
- std::shared_ptr<xla::PjRtBuffer> pjrt_buffer =
129
- runtime::GetComputationClientOrDie ()->GetPjRtBuffer (handle);
128
+ XLA_ASSIGN_OR_THROW (runtime::ComputationClient * absl_nonnull const client,
129
+ runtime::GetComputationClient ());
130
+ std::shared_ptr<xla::PjRtBuffer> pjrt_buffer = client->GetPjRtBuffer (handle);
130
131
ABSL_CHECK (pjrt_buffer != nullptr ) << " Could not get a valid pjrt_buffer" ;
131
132
132
133
ABSL_CHECK (!pjrt_buffer->IsTuple ())
@@ -169,11 +170,13 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
169
170
// Reference: https://github.com/openxla/xla/blob/main/xla/python/dlpack.cc
170
171
absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice (const DLDevice& context) {
171
172
switch (context.device_type ) {
172
- case DLDeviceType::kDLCPU :
173
- XLA_CHECK_EQ (runtime::GetComputationClientOrDie ()->GetPlatformID (),
174
- xla::CpuId ());
175
- return runtime::GetComputationClientOrDie ()->LookupAddressableDevice (
176
- context.device_id );
173
+ case DLDeviceType::kDLCPU : {
174
+ XLA_ASSIGN_OR_RETURN (
175
+ runtime::ComputationClient * absl_nonnull const client,
176
+ runtime::GetComputationClient ());
177
+ XLA_CHECK_EQ (client->GetPlatformID (), xla::CpuId ());
178
+ return client->LookupAddressableDevice (context.device_id );
179
+ }
177
180
default :
178
181
return tsl::errors::InvalidArgument (
179
182
" Unknown/unsupported DLPack device type %d" , context.device_type );
@@ -330,10 +333,11 @@ at::Tensor fromDLPack(DLManagedTensor* dlmt) {
330
333
shape, *device->default_memory_space (), on_delete_callback));
331
334
ABSL_CHECK (pjrt_buffer.get () != nullptr ) << " pjrt buffer is null." ;
332
335
336
+ XLA_ASSIGN_OR_THROW (runtime::ComputationClient * absl_nonnull const client,
337
+ runtime::GetComputationClient ());
333
338
runtime::ComputationClient::DataPtr data =
334
339
runtime::PjRtComputationClient::CreateData (
335
- runtime::GetComputationClientOrDie ()->PjRtDeviceToString (device),
336
- shape, std::move (pjrt_buffer));
340
+ client->PjRtDeviceToString (device), shape, std::move (pjrt_buffer));
337
341
338
342
at::ScalarType tensor_type = at::toScalarType (dlmt->dl_tensor .dtype );
339
343
XLATensorPtr xla_tensor = XLATensor::Create (data, tensor_type);
0 commit comments