@@ -149,7 +149,6 @@ def _maybe_move_tensors_to_device(tensors: tuple,
                                   target_device: torch.device) -> tuple:
   assert target_device, "Moving tensors to None device not supported"
 
-  already_mark_step = False
   device_id = None
 
   moved_tensors = []
@@ -171,12 +170,10 @@ def _maybe_move_tensors_to_device(tensors: tuple,
       device_type, device_id = tensor.__dlpack_device__()
       moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
     elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
-      # mark_step is need to make sure the pjrt buffer is valid.
-      if not already_mark_step:
-        xm.mark_step()
-        already_mark_step = True
-      device_id = tensor.device.index
       moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
+      # HACK: `torch_xla._XLAC._get_stream_for_cuda_device` expects a local device index, but the device index of an XLA tensor is always 0.
+      # dlpack reports the actual device index, so use the device index of the converted CUDA tensor instead.
+      device_id = moved_tensor.device.index
     else:
       # Have to move to CPU before moving it to target device.
       cpu_device: torch.device = torch.device("cpu")
@@ -189,9 +186,6 @@ def _maybe_move_tensors_to_device(tensors: tuple,
     moved_tensors.append(moved_tensor)
 
   if zero_copy_enabled and device_id is not None:
-    # device_id = tensor.device.index
-    # print(f"device_id: {device_id}")
-    device_id = 0
     stream = torch_xla._XLAC._get_stream_for_cuda_device(device_id)
     stream = 1 if stream == 0 else stream
     assert stream is None or type(stream) is int
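The zero-copy branches above lean on the DLPack protocol: `__dlpack_device__()` reports which physical device backs a tensor, and `from_dlpack` wraps the same buffer in a new tensor without copying. Below is a minimal sketch of that contract using PyTorch's public `torch.utils.dlpack` API; the diff's `torch_xla_dlpack` is torch_xla's wrapper around the same protocol, and `share_without_copy` is a hypothetical name, not a function from this patch.

```python
import torch
from torch.utils.dlpack import from_dlpack


def share_without_copy(src: torch.Tensor) -> torch.Tensor:
  # (device_type, device_id) of the buffer backing `src`; this is the pair
  # the patch unpacks when it later needs to pick a CUDA stream.
  device_type, device_id = src.__dlpack_device__()
  # Wrap the same memory in a new tensor: no copy is made, so writes through
  # either tensor are visible through the other.
  dst = from_dlpack(src.detach())
  assert dst.data_ptr() == src.data_ptr()  # same underlying buffer
  return dst
```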
@@ -274,17 +268,15 @@ class SpecialReturnHandler:
 
   def __init__(self, trace_inputs, trace_outputs,
                trace_inputs_inplace_update_bool, constant_outputs_and_indexes):
-    self.trace_inputs = trace_inputs
-    self.trace_outputs = trace_outputs
     self.constant_outputs_and_indexes = constant_outputs_and_indexes
 
     # dedup the traced outputs first
     self.deduper = Deduper()
-    self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
+    self.deduped_trace_outputs = self.deduper.dedup(trace_outputs)
 
     # record the output that is also a input
     trace_inputs_id2pos = {
-        id(x): pos for pos, x in enumerate(self.trace_inputs)
+        id(x): pos for pos, x in enumerate(trace_inputs)
     }
     self.trace_outputs_pos_to_inputs_pos = []
     for out_pos, out in enumerate(self.deduped_trace_outputs):
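With `trace_inputs`/`trace_outputs` no longer stored on `self`, the handler keeps only the deduped outputs and the position mappings, so the traced tensors themselves can be garbage-collected once tracing ends. For context, here is a minimal sketch of the dedup-by-identity behavior `Deduper.dedup` is assumed to provide; the real class is defined elsewhere in this file and may differ in detail.

```python
class Deduper:
  """Sketch: keep the first occurrence of each object, remember the mapping."""

  def __init__(self):
    self.permute_for_orig = None  # deduped index for each original slot

  def dedup(self, objs):
    first_pos = {}  # id(obj) -> index in the deduped list
    permute, deduped = [], []
    for obj in objs:
      if id(obj) not in first_pos:
        first_pos[id(obj)] = len(deduped)
        deduped.append(obj)
      permute.append(first_pos[id(obj)])
    self.permute_for_orig = permute
    return deduped
```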
@@ -511,7 +503,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
   # mistakenlly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
-  vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
+  vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
                     arg_index_to_need_update_index, none_remover,
                     graph_input_matcher, special_return_handler,
                     xla_args_need_update)
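Returning `len(args_and_out)` instead of the tuple itself means the cached graph variables no longer hold references to the traced tensors; the callers only ever need the count, for the empty-graph early exit and the result-count assert further down. A hypothetical illustration of why caching the length is enough to let the tensors die:

```python
import gc
import weakref

import torch

t = torch.zeros(1024, 1024)
ref = weakref.ref(t)

cached = len((t,))  # cache only the count, as the patch now does
del t
gc.collect()
assert ref() is None  # the traced tensor was freed despite the cached count
```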
@@ -544,7 +536,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
   sym_constants_to_graph_vars: Dict[Tuple[Union[int, float], ...],
                                     Tuple[Any, ...]] = {}
 
-  (xla_args_sharding_spec, args_and_out, graph_hash,
+  (xla_args_sharding_spec, len_args_and_out, graph_hash,
    arg_index_to_need_update_index, none_remover, graph_input_matcher,
    special_return_handler,
    xla_args_need_update) = extract_graph_helper(xla_model,
@@ -569,16 +561,18 @@ def optimized_mod(*args: tuple):
     xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
         args)
     if sym_constants in sym_constants_to_graph_vars:
-      (xla_args_sharding_spec, args_and_out, graph_hash,
+      (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        special_return_handler,
        xla_args_need_update) = sym_constants_to_graph_vars[sym_constants]
     else:
       xla_model.xla_args = args
-      (xla_args_sharding_spec, args_and_out, graph_hash,
+      (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        special_return_handler, xla_args_need_update) = extract_graph_helper(
            xla_model, sym_constants_to_graph_vars)
+      if hasattr(xla_model, 'xla_args'):
+        delattr(xla_model, 'xla_args')
 
     if not config.skip_input_data_check:
       # mark_step needs to be blocking since we want to access args's XLADatas
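`xla_model.xla_args` is only needed while `extract_graph_helper` traces the graph; the new `delattr` here (and again after the debug block in the last hunk) drops the reference so the `GraphModule` does not keep the example inputs alive across calls. The same intent, sketched as a hypothetical try/finally wrapper rather than the patch's inline cleanup:

```python
def trace_with_temp_args(xla_model, args, sym_constants_to_graph_vars):
  # Hypothetical helper: attach the args only for the duration of tracing,
  # then drop them even if tracing raises.
  xla_model.xla_args = args
  try:
    return extract_graph_helper(xla_model, sym_constants_to_graph_vars)
  finally:
    if hasattr(xla_model, 'xla_args'):
      delattr(xla_model, 'xla_args')
```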
@@ -614,15 +608,16 @@ def optimized_mod(*args: tuple):
       else:
         skip_checking_input_sharding_threashold -= 1
 
-    if len(args_and_out) == 0:
+    if len_args_and_out == 0:
       return ()
 
     # graph input should be tensor only
     graph_input = graph_input_matcher(xla_args_tensor_only)
     res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
+    xm.wait_device_ops()
     res = special_return_handler.addDumbReturn(xla_args_tensor_only, res)
 
-    assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
+    assert len(res) == len_args_and_out, f"{len(res)} v.s. {len_args_and_out}"
     ncopy = 0
 
     for arg_index, res_index in arg_index_to_need_update_index.items():
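`xm.wait_device_ops()` blocks the host until pending XLA device work has completed, which here presumably ensures the buffers behind `res` are materialized before they reach the zero-copy DLPack path (replacing the `mark_step` the first hunk removed). Called with no arguments it waits on every device:

```python
import torch_xla.core.xla_model as xm

# Block until all async device operations are done; a list of device strings
# can be passed to wait on a subset instead.
xm.wait_device_ops()
```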
@@ -640,6 +635,11 @@ def optimized_mod(*args: tuple):
     else:
       return result
 
+  if hasattr(xla_model, 'xla_args'):
+    delattr(xla_model, 'xla_args')
+
+  torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
+
   if dynamo_debug:
     print(
         '=================== OpenXLA Dynamo Compile Debug End =====================\n'