diff --git a/torch_xla/_dynamo/dynamo_bridge.py b/torch_xla/_dynamo/dynamo_bridge.py
index 3cf65a98fae..c0599f0ad96 100644
--- a/torch_xla/_dynamo/dynamo_bridge.py
+++ b/torch_xla/_dynamo/dynamo_bridge.py
@@ -161,8 +161,8 @@ def _maybe_move_tensors_to_device(tensors: tuple,
       moved_tensors.append(tensor)
       continue
 
-    if dynamo_debug:
-      print("Moving Tensor {} to device {}".format(tensor, target_device))
+    # if dynamo_debug:
+    #   print("Moving Tensor {} to device {}".format(tensor, target_device))
 
     zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
     if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla':
@@ -479,11 +479,14 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
 
   with alias_with_buffer_donor_config() as saved_config:
     # calculate graph hash
-    graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
+    if len(args_and_out_tensor_only) == 0:
+      graph_hash = None
+    else:
+      graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
+      # compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
+      torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
     if dynamo_debug:
       print("Graph Hash: ", graph_hash)
-    # compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
-    torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
 
   # Restore the origional `xla_args`. Dynamo passed the real tensor as
   # `xla_args`` and we performend the tracing on them. During the tracing,
@@ -503,10 +506,17 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
   # mistakenlly update the input tensors.
   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
+  xla_args_dtype = []
+  for arg in xla_args:
+    if isinstance(arg, torch.Tensor):
+      xla_args_dtype.append(arg.dtype)
+    else:
+      xla_args_dtype.append(None)
+
   vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
                     arg_index_to_need_update_index, none_remover,
                     graph_input_matcher, special_return_handler,
-                    xla_args_need_update)
+                    xla_args_need_update, xla_args_dtype)
   # populate the cache
   sym_constants_to_graph_vars[sym_constants] = vars_to_return
 
@@ -539,7 +549,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
  (xla_args_sharding_spec, len_args_and_out, graph_hash,
   arg_index_to_need_update_index, none_remover, graph_input_matcher,
   special_return_handler,
-  xla_args_need_update) = extract_graph_helper(xla_model,
+  xla_args_need_update, xla_args_dtype) = extract_graph_helper(xla_model,
                                                sym_constants_to_graph_vars)
  skip_checking_input_sharding_threashold = xu.getenv_as(
      'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
@@ -548,14 +558,16 @@ def optimized_mod(*args: tuple):
     nonlocal xla_model
     nonlocal skip_checking_input_sharding_threashold
     nonlocal sym_constants_to_graph_vars
+    nonlocal graph_hash
+
+    if graph_hash is None:
+      return xla_model(*args)
 
     original_device: torch.device = _get_input_arg_device(args)
     is_cuda_args: bool = False
     if original_device:
       is_cuda_args = original_device.type == "cuda"
 
-    if is_cuda_args:
-      args = _maybe_move_tensors_to_device(args, xm.xla_device())
 
     # See [Note: Dynamo real-time input-shape cache look-up] above.
     xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
@@ -564,16 +576,24 @@ def optimized_mod(*args: tuple):
       (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
        special_return_handler,
-       xla_args_need_update) = sym_constants_to_graph_vars[sym_constants]
+       xla_args_need_update, xla_args_dtype) = sym_constants_to_graph_vars[sym_constants]
     else:
       xla_model.xla_args = args
       (xla_args_sharding_spec, len_args_and_out, graph_hash,
        arg_index_to_need_update_index, none_remover, graph_input_matcher,
-       special_return_handler, xla_args_need_update) = extract_graph_helper(
+       special_return_handler, xla_args_need_update, xla_args_dtype) = extract_graph_helper(
            xla_model, sym_constants_to_graph_vars)
       if hasattr(xla_model, 'xla_args'):
         delattr(xla_model, 'xla_args')
 
+    args = list(args)
+    for index, arg in enumerate(args):
+      if isinstance(arg, torch.Tensor) and arg.dtype != xla_args_dtype[index]:
+        args[index] = arg.to(xla_args_dtype[index])
+    if is_cuda_args:
+      args = _maybe_move_tensors_to_device(args, xm.xla_device())
+    xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
+        args)
     if not config.skip_input_data_check:
       # mark_step needs to be blocking since we want to access args's XLADatas
       # and they can't be placeholder.
@@ -627,8 +647,11 @@ def optimized_mod(*args: tuple):
 
     result = res[len(xla_args_need_update):]
     none_remover.add_nones(result)
-    if is_cuda_args:
-      result = _maybe_move_tensors_to_device(tuple(result), original_device)
+
+    # TODO: fix this properly; the input is not a CUDA tensor but the output is a CUDA tensor
+    # if is_cuda_args:
+    original_device = torch.device(torch.cuda.current_device())
+    result = _maybe_move_tensors_to_device(tuple(result), original_device)
 
     if len(result) == 1:
       return result[0]
@@ -673,6 +696,14 @@ def all_tensors_on_xla_device(value):
       # Not a tensor nor a container.
       return True
 
+    def have_any_tensor(value):
+      if isinstance(value, torch.Tensor):
+        return True
+      if isinstance(value, (list, tuple)):
+        return any(have_any_tensor(v) for v in value)
+      # Not a tensor nor a container.
+      return False
+
     # Check whether the current node is supported or not.
     #
     # A supported node has the following characteristics:
@@ -689,8 +720,14 @@ def all_tensors_on_xla_device(value):
 
     # If the current node is NOT supported, we add it to
     # the _unsupported_nodes list.
+    result_have_tensor = have_any_tensor(result)
+    args_have_tensor = any(
+        have_any_tensor(v)
+        for v in itertools.chain(args, kwargs.values()))
     if not (result_is_supported and args_are_supported):
       self._unsupported_nodes.append(n)
+    elif not (result_have_tensor or args_have_tensor):
+      self._unsupported_nodes.append(n)
 
     # Restore this metric counter
     torch_xla._XLAC._xla_increment_counter(
@@ -739,7 +776,6 @@ def allow_cpu_device(self, node: torch.fx.Node):
   def move_cuda_to_xla(self, graph: torch.fx.Graph):
     constructors = []
     for node in graph.nodes:
-      # print(f'node.kwargs: {node.kwargs.get("device")} {node.kwargs.get("device") == torch.device("cuda")}')
       device = node.kwargs.get("device")
       if device is None or device.type != "cuda":
         continue
@@ -750,10 +786,27 @@ def move_cuda_to_xla(self, graph: torch.fx.Graph):
       kwargs = node.kwargs.copy()
       kwargs["device"] = self.target
       node.kwargs = kwargs
-
-  def __call__(self, graph: torch.fx.Graph) -> None:
-    self.move_cuda_to_xla(graph)
-    super().__call__(graph)
+
+  def move_xla_to_cuda(self, graph: torch.fx.Graph):
+    constructors = []
+    for node in graph.nodes:
+      device = node.kwargs.get("device")
+      if device is None or device != self.target:
+        continue
+      constructors.append(node)
+
+    for node in constructors:
+      kwargs = node.kwargs.copy()
+      kwargs["device"] = "cuda"
+      node.kwargs = kwargs
+
+
+  def __call__(self, graph: torch.fx.Graph, move_xla_to_cuda=False) -> None:
+    if move_xla_to_cuda:
+      self.move_xla_to_cuda(graph)
+    else:
+      self.move_cuda_to_xla(graph)
+    super().__call__(graph)
 
 
 def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
@@ -831,6 +884,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
       node.replace_all_uses_with(new_node)
       partitioned_graph.graph.erase_node(node)
 
+  XLAConstructorMoverPass()(partitioned_graph.graph, move_xla_to_cuda=True)
   partitioned_graph.recompile()
   return partitioned_graph
 
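For review context, the dtype bookkeeping introduced in `extract_graph_helper` / `optimized_mod` boils down to the pattern below. This is a minimal standalone sketch, not torch_xla code; the helper names `record_arg_dtypes` and `cast_args_to_recorded_dtypes` are made up for illustration.

```python
# Illustrative sketch (not part of torch_xla): record each tensor argument's
# dtype at trace time, then cast later calls back to the recorded dtypes
# before running the cached graph, mirroring the `xla_args_dtype` handling
# added in this patch.
import torch


def record_arg_dtypes(args):
  # None marks non-tensor arguments (e.g. SymInt constants).
  return [a.dtype if isinstance(a, torch.Tensor) else None for a in args]


def cast_args_to_recorded_dtypes(args, recorded_dtypes):
  casted = list(args)
  for i, dtype in enumerate(recorded_dtypes):
    arg = casted[i]
    if isinstance(arg, torch.Tensor) and dtype is not None and arg.dtype != dtype:
      casted[i] = arg.to(dtype)
  return casted


if __name__ == "__main__":
  traced_args = (torch.randn(2, 2, dtype=torch.float32), 3)
  dtypes = record_arg_dtypes(traced_args)  # [torch.float32, None]
  later_args = (torch.randn(2, 2, dtype=torch.float64), 3)
  fixed = cast_args_to_recorded_dtypes(later_args, dtypes)
  print([a.dtype if isinstance(a, torch.Tensor) else a for a in fixed])
  # -> [torch.float32, 3]
```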