
Commit feb8a27

fix memory issue and opt sync time
1 parent bdd8cb9 commit feb8a27

File tree

3 files changed (+21 -21 lines):

  torch_xla/_dynamo/dynamo_bridge.py
  torch_xla/csrc/dl_convertor.cpp
  torch_xla/csrc/runtime/pjrt_computation_client.h


torch_xla/_dynamo/dynamo_bridge.py

+19 -19
@@ -149,7 +149,6 @@ def _maybe_move_tensors_to_device(tensors: tuple,
                                    target_device: torch.device) -> tuple:
   assert target_device, "Moving tensors to None device not supported"

-  already_mark_step = False
   device_id = None

   moved_tensors = []
@@ -171,12 +170,10 @@ def _maybe_move_tensors_to_device(tensors: tuple,
       device_type, device_id = tensor.__dlpack_device__()
       moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
     elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
-      # mark_step is need to make sure the pjrt buffer is valid.
-      if not already_mark_step:
-        xm.mark_step()
-        already_mark_step = True
-      device_id = tensor.device.index
       moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
+      # HACK: The `torch_xla._XLAC._get_stream_for_cuda_device` requires a local device index, while the device index for xla tensors is always 0.
+      # Meanwhile, dlpack uses the actual device index, so we use the device index of the converted CUDA tensor.
+      device_id = moved_tensor.device.index
     else:
       # Have to move to CPU before moving it to target device.
       cpu_device: torch.device = torch.device("cpu")
@@ -189,9 +186,6 @@ def _maybe_move_tensors_to_device(tensors: tuple,
     moved_tensors.append(moved_tensor)

   if zero_copy_enabled and device_id is not None:
-    # device_id = tensor.device.index
-    # print(f"device_id: {device_id}")
-    device_id = 0
     stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
     stream = 1 if stream == 0 else stream
     assert stream is None or type(stream) is int
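Note: the three hunks above drop the per-call xm.mark_step() and the hard-coded device_id = 0 from the XLA-to-CUDA zero-copy path. Because an XLA tensor always reports device index 0, the CUDA stream is now looked up with the index of the converted CUDA tensor instead. A minimal sketch of the resulting path, assuming torch_xla_dlpack is torch_xla.utils.dlpack as imported by dynamo_bridge.py; the helper name below is illustrative, not the bridge's own:

import torch
import torch_xla
import torch_xla.utils.dlpack as torch_xla_dlpack


def move_xla_cuda_to_cuda(tensor: torch.Tensor) -> torch.Tensor:
  # Zero-copy: the returned CUDA tensor aliases the underlying PJRT buffer.
  moved = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
  # Use the local CUDA device index reported by the converted tensor, not
  # tensor.device.index, which is always 0 for XLA tensors.
  device_id = moved.device.index
  stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
  stream = 1 if stream == 0 else stream  # 0 means the legacy default stream
  assert stream is None or type(stream) is int
  return moved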
@@ -274,17 +268,15 @@ class SpecialReturnHandler:

   def __init__(self, trace_inputs, trace_outputs,
                trace_inputs_inplace_update_bool, constant_outputs_and_indexes):
-    self.trace_inputs = trace_inputs
-    self.trace_outputs = trace_outputs
     self.constant_outputs_and_indexes = constant_outputs_and_indexes

     # dedup the traced outputs first
     self.deduper = Deduper()
-    self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
+    self.deduped_trace_outputs = self.deduper.dedup(trace_outputs)

     # record the output that is also a input
     trace_inputs_id2pos = {
-        id(x): pos for pos, x in enumerate(self.trace_inputs)
+        id(x): pos for pos, x in enumerate(trace_inputs)
     }
     self.trace_outputs_pos_to_inputs_pos = []
     for out_pos, out in enumerate(self.deduped_trace_outputs):
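Note: SpecialReturnHandler no longer stores trace_inputs and trace_outputs as attributes; it keeps only what it derives from them (the deduped outputs and the input positions), so the handler does not add long-lived references to the traced inputs. A generic illustration of that reference-retention effect, not code from the bridge:

import gc
import weakref


class Traced:
  """Stand-in for a traced tensor."""


def build_handler(keep_reference: bool):
  traced = Traced()
  ref = weakref.ref(traced)
  # Storing the object on the handler pins it for the handler's lifetime;
  # storing only derived data (here, a count) does not.
  handler = {"inputs": traced} if keep_reference else {"num_inputs": 1}
  del traced
  gc.collect()
  return handler, ref


_, ref = build_handler(keep_reference=True)
print(ref() is None)  # False: the handler still holds the traced object

_, ref = build_handler(keep_reference=False)
print(ref() is None)  # True: the traced object was freed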
@@ -511,7 +503,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
   # mistakenlly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

-  vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
+  vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
                     arg_index_to_need_update_index, none_remover,
                     graph_input_matcher, special_return_handler,
                     xla_args_need_update)
@@ -544,7 +536,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   sym_constants_to_graph_vars: Dict[Tuple[Union[int, float], ...],
                                     Tuple[Any, ...]] = {}

-  (xla_args_sharding_spec, args_and_out, graph_hash,
+  (xla_args_sharding_spec, len_args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    special_return_handler,
    xla_args_need_update) = extract_graph_helper(xla_model,
@@ -569,16 +561,18 @@ def optimized_mod(*args: tuple):
     xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
         args)
     if sym_constants in sym_constants_to_graph_vars:
-      (xla_args_sharding_spec, args_and_out, graph_hash,
+      (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        special_return_handler,
        xla_args_need_update) = sym_constants_to_graph_vars[sym_constants]
     else:
       xla_model.xla_args = args
-      (xla_args_sharding_spec, args_and_out, graph_hash,
+      (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        special_return_handler, xla_args_need_update) = extract_graph_helper(
            xla_model, sym_constants_to_graph_vars)
+      if hasattr(xla_model, 'xla_args'):
+        delattr(xla_model, 'xla_args')

     if not config.skip_input_data_check:
       # mark_step needs to be blocking since we want to access args's XLADatas
@@ -614,15 +608,16 @@ def optimized_mod(*args: tuple):
       else:
         skip_checking_input_sharding_threashold -= 1

-    if len(args_and_out) == 0:
+    if len_args_and_out == 0:
       return ()

     # graph input should be tensor only
     graph_input = graph_input_matcher(xla_args_tensor_only)
     res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
+    xm.wait_device_ops()
     res = special_return_handler.addDumbReturn(xla_args_tensor_only, res)

-    assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
+    assert len(res) == len_args_and_out, f"{len(res)} v.s. {len_args_and_out}"
     ncopy = 0

     for arg_index, res_index in arg_index_to_need_update_index.items():
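Note: extract_graph_helper now returns len(args_and_out) (bound as len_args_and_out) rather than the args_and_out tuple itself, so the cached graph variables and the optimized_mod closure hold a plain int instead of references to the traced tensors. The added xm.wait_device_ops() blocks after _run_cached_graph until the device buffers are ready, which appears to stand in for the Await() removed from toDLPack in dl_convertor.cpp below. A generic sketch of the closure-capture point, not the bridge code:

import torch


def make_runner_keeping_tensors(args_and_out):

  def run():
    # The closure keeps args_and_out, and every tensor in it, alive for as
    # long as the compiled callable stays cached.
    return len(args_and_out)

  return run


def make_runner_keeping_len(args_and_out):
  len_args_and_out = len(args_and_out)  # capture only the count

  def run():
    return len_args_and_out

  return run


traced = [torch.empty(1024, 1024) for _ in range(4)]
runner = make_runner_keeping_len(traced)
del traced  # the roughly 16 MB of traced buffers can now be released
assert runner() == 4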
@@ -640,6 +635,11 @@ def optimized_mod(*args: tuple):
     else:
       return result

+  if hasattr(xla_model, 'xla_args'):
+    delattr(xla_model, 'xla_args')
+
+  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
+
   if dynamo_debug:
     print(
         '=================== OpenXLA Dynamo Compile Debug End =====================\n'
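Note: both code paths of extract_internal now end tracing with the same cleanup: the sample inputs attached as xla_model.xla_args are deleted and any leftover pending IR on the device is discarded. A minimal sketch of that cleanup step using only the calls that appear in this diff; the wrapper function name is ours:

import torch_xla
import torch_xla.core.xla_model as xm


def release_tracing_state(xla_model):
  # Drop the sample inputs that were attached only for tracing.
  if hasattr(xla_model, 'xla_args'):
    delattr(xla_model, 'xla_args')
  # Discard IR still pending on the current device; the compiled graph is
  # already cached by its hash, so nothing needed at runtime is lost.
  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))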

torch_xla/csrc/dl_convertor.cpp

+1 -1
@@ -138,7 +138,7 @@ DLManagedTensor* toDLPack(const at::Tensor& input) {
     auto external_ref = pjrt_buffer->AcquireExternalReference();
     XLA_CHECK_OK(external_ref.status());
     pack->external_reference = std::move(external_ref.value());
-    XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
+    // XLA_CHECK_OK(pjrt_buffer->GetReadyFuture().Await());
   }
   pack->buffer_reference = pjrt_buffer;
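Note: commenting out GetReadyFuture().Await() removes a blocking wait from every toDLPack call, which is the sync-time part of the commit; buffer readiness is instead enforced on the Python side (see the xm.wait_device_ops() added in dynamo_bridge.py above). A sketch of the caller-side ordering this appears to rely on, assuming the public torch_xla.utils.dlpack helpers:

import torch
import torch.utils.dlpack
import torch_xla.core.xla_model as xm
import torch_xla.utils.dlpack as torch_xla_dlpack

t = torch.ones(8, device=xm.xla_device()) * 2
xm.mark_step()         # materialize the pending computation
xm.wait_device_ops()   # make sure the PJRT buffer is ready before export
capsule = torch_xla_dlpack.to_dlpack(t)  # no longer awaits the buffer itself
cuda_view = torch.utils.dlpack.from_dlpack(capsule)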

torch_xla/csrc/runtime/pjrt_computation_client.h

+1 -1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class PjRtComputationClient : public ComputationClient {
110110
xla::PjRtLocalDeviceId(local_device_id));
111111
XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device.";
112112
absl::StatusOr<std::intptr_t> stream =
113-
pjrt_device.value()->GetStreamForExternalReadyEvents();
113+
pjrt_device.value()->GetLocalComputeStream();
114114
XLA_CHECK(stream.ok()) << "Failed to get a stream.";
115115
return stream.value();
116116
}
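Note: GetCudaStreamForDevice now returns the device's local compute stream instead of the stream used for external ready events; this is the handle that torch_xla._XLAC._get_stream_for_cuda_device hands to the Python bridge. A sketch of how such a raw stream handle is typically consumed on the PyTorch side; the ExternalStream wrapping is an assumption and is not part of this diff:

import torch
import torch_xla

device_id = 0  # local CUDA device index of the converted tensor
stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
stream = 1 if stream == 0 else stream  # 0 denotes the legacy default stream
ext_stream = torch.cuda.ExternalStream(stream)
# Order work between the XLA compute stream and the current PyTorch stream
# before touching memory shared through DLPack.
torch.cuda.current_stream().wait_stream(ext_stream)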
