fix dtype and some minor bugs
yitongh committed Dec 11, 2024
1 parent feb8a27 commit 6304bb7
Showing 1 changed file with 72 additions and 18 deletions.
90 changes: 72 additions & 18 deletions torch_xla/_dynamo/dynamo_bridge.py
@@ -161,8 +161,8 @@ def _maybe_move_tensors_to_device(tensors: tuple,
moved_tensors.append(tensor)
continue

if dynamo_debug:
print("Moving Tensor {} to device {}".format(tensor, target_device))
# if dynamo_debug:
# print("Moving Tensor {} to device {}".format(tensor, target_device))

zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla':
@@ -479,11 +479,14 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,

with alias_with_buffer_donor_config() as saved_config:
# calculate graph hash
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
if len(args_and_out_tensor_only) == 0:
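# No XLA tensors were produced by tracing, so there is nothing to hash or
# compile; leave graph_hash as None so optimized_mod can fall back to running
# the FX module directly.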
graph_hash = None
else:
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
# compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])
if dynamo_debug:
print("Graph Hash: ", graph_hash)
# compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])

# Restore the original `xla_args`. Dynamo passed the real tensors as
# `xla_args` and we performed the tracing on them. During the tracing,
@@ -503,10 +506,17 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
# mistakenly update the input tensors.
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

xla_args_dtype = []
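# Record the dtype of each tensor argument at trace time; optimized_mod later
# casts call-time args back to these dtypes before running the cached graph.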
for arg in xla_args:
if isinstance(arg, torch.Tensor):
xla_args_dtype.append(arg.dtype)
else:
xla_args_dtype.append(None)

vars_to_return = (xla_args_sharding_spec, len(args_and_out), graph_hash,
arg_index_to_need_update_index, none_remover,
graph_input_matcher, special_return_handler,
xla_args_need_update)
xla_args_need_update, xla_args_dtype)
# populate the cache
sym_constants_to_graph_vars[sym_constants] = vars_to_return

@@ -539,7 +549,7 @@ def extract_internal(xla_model: torch.fx.GraphModule):
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model,
xla_args_need_update, xla_args_dtype) = extract_graph_helper(xla_model,
sym_constants_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
@@ -548,14 +558,16 @@ def optimized_mod(*args: tuple):
nonlocal xla_model
nonlocal skip_checking_input_sharding_threashold
nonlocal sym_constants_to_graph_vars
nonlocal graph_hash

if graph_hash is None:
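# Tracing found no XLA tensors (see extract_graph_helper), so run the FX
# module eagerly instead of replaying a compiled graph.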
return xla_model(*args)

original_device: torch.device = _get_input_arg_device(args)
is_cuda_args: bool = False
if original_device:
is_cuda_args = original_device.type == "cuda"

if is_cuda_args:
args = _maybe_move_tensors_to_device(args, xm.xla_device())

# See [Note: Dynamo real-time input-shape cache look-up] above.
xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
@@ -564,16 +576,24 @@ def optimized_mod(*args: tuple):
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler,
xla_args_need_update) = sym_constants_to_graph_vars[sym_constants]
xla_args_need_update, xla_args_dtype) = sym_constants_to_graph_vars[sym_constants]
else:
xla_model.xla_args = args
(xla_args_sharding_spec, len_args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
special_return_handler, xla_args_need_update) = extract_graph_helper(
special_return_handler, xla_args_need_update, xla_args_dtype) = extract_graph_helper(
xla_model, sym_constants_to_graph_vars)
if hasattr(xla_model, 'xla_args'):
delattr(xla_model, 'xla_args')

args = list(args)
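# Cast any tensor arg whose dtype drifted from the dtype recorded at trace
# time back to that dtype so the args match the compiled graph's signature.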
for index, arg in enumerate(args):
if isinstance(arg, torch.Tensor) and arg.dtype != xla_args_dtype[index]:
args[index] = arg.to(xla_args_dtype[index])
if is_cuda_args:
args = _maybe_move_tensors_to_device(args, xm.xla_device())
xla_args_tensor_only, sym_constants = _split_xla_args_tensor_sym_constant(
args)
if not config.skip_input_data_check:
# mark_step needs to be blocking since we want to access args's XLADatas
# and they can't be placeholder.
@@ -627,8 +647,11 @@ def optimized_mod(*args: tuple):
result = res[len(xla_args_need_update):]

none_remover.add_nones(result)
if is_cuda_args:
result = _maybe_move_tensors_to_device(tuple(result), original_device)

# TODO: fix this properly: the input may not be a CUDA tensor while the output is a CUDA tensor
# if is_cuda_args:
original_device = torch.device(torch.cuda.current_device())
result = _maybe_move_tensors_to_device(tuple(result), original_device)

if len(result) == 1:
return result[0]
@@ -673,6 +696,14 @@ def all_tensors_on_xla_device(value):
# Not a tensor nor a container.
return True

def have_any_tensor(value):
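# Returns True if value is a tensor or a (possibly nested) list/tuple that
# contains one.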
if isinstance(value, torch.Tensor):
return True
if isinstance(value, (list, tuple)):
return any(have_any_tensor(v) for v in value)
# Not a tensor nor a container.
return False

# Check whether the current node is supported or not.
#
# A supported node has the following characteristics:
@@ -689,8 +720,14 @@ def all_tensors_on_xla_device(value):

# If the current node is NOT supported, we add it to
# the _unsupported_nodes list.
result_have_tensor = have_any_tensor(result)
args_have_tensor = any(
have_any_tensor(v)
for v in itertools.chain(args, kwargs.values()))
if not (result_is_supported and args_are_supported):
self._unsupported_nodes.append(n)
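# Nodes that neither consume nor produce any tensor are also marked
# unsupported so they are left outside the XLA portion of the graph.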
elif not (result_have_tensor or args_have_tensor):
self._unsupported_nodes.append(n)

# Restore this metric counter
torch_xla._XLAC._xla_increment_counter(
@@ -739,7 +776,6 @@ def allow_cpu_device(self, node: torch.fx.Node):
def move_cuda_to_xla(self, graph: torch.fx.Graph):
constructors = []
for node in graph.nodes:
# print(f'node.kwargs: {node.kwargs.get("device")} {node.kwargs.get("device") == torch.device("cuda")}')
device = node.kwargs.get("device")
if device is None or device.type != "cuda":
continue
@@ -750,10 +786,27 @@ def move_cuda_to_xla(self, graph: torch.fx.Graph):
kwargs = node.kwargs.copy()
kwargs["device"] = self.target
node.kwargs = kwargs

def __call__(self, graph: torch.fx.Graph) -> None:
self.move_cuda_to_xla(graph)
super().__call__(graph)

def move_xla_to_cuda(self, graph: torch.fx.Graph):
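# Inverse of move_cuda_to_xla: constructor nodes whose device kwarg is the
# XLA target are rewritten to construct on CUDA instead.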
constructors = []
for node in graph.nodes:
device = node.kwargs.get("device")
if device is None or device != self.target:
continue
constructors.append(node)

for node in constructors:
kwargs = node.kwargs.copy()
kwargs["device"] = "cuda"
node.kwargs = kwargs


def __call__(self, graph: torch.fx.Graph, move_xla_to_cuda=False) -> None:
if move_xla_to_cuda:
self.move_xla_to_cuda(graph)
else:
self.move_cuda_to_xla(graph)
super().__call__(graph)


def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
Expand Down Expand Up @@ -831,6 +884,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
node.replace_all_uses_with(new_node)
partitioned_graph.graph.erase_node(node)

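# Presumably rewrites constructor device kwargs in the partitioned graph back
# to CUDA before recompiling, so the fallback subgraphs construct on CUDA.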
XLAConstructorMoverPass()(partitioned_graph.graph, move_xla_to_cuda=True)
partitioned_graph.recompile()

return partitioned_graph