fix dynamo inplace copy (pytorch#7933)
zpcore authored and yitongh committed Dec 11, 2024
1 parent 088b845 commit 4c87fa4
Showing 2 changed files with 125 additions and 12 deletions.
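For context, a minimal sketch of the pattern this commit fixes, adapted from the new tests below: a collective that mutates its input in place, compiled through Dynamo with the openxla backend. The `_mutate` name and the standalone setup are illustrative, not part of the commit, and assume a multi-device XLA environment with torch_xla installed.

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def _mutate(t):
  dist.all_reduce(t, dist.ReduceOp.SUM)  # mutates `t` in place
  return t

dist.init_process_group("xla", init_method='xla://')
t = torch.tensor([xr.global_ordinal()],
                 dtype=torch.float,
                 device=xm.xla_device())
compiled = torch.compile(_mutate, backend='openxla')
compiled(t)  # the compiled graph must write the reduced value back into t
torch_xla.sync()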
108 changes: 108 additions & 0 deletions test/pjrt/test_collective_ops_tpu.py
@@ -130,5 +130,113 @@ def test_all_to_all(self, pin_layout):
                                              list(range(world_size))]])


+# Test for collective ops from torch.distributed
+class TestDistCollectiveOpsTpu(parameterized.TestCase):
+
+  @staticmethod
+  def _all_reduce(use_dynamo: bool):
+    met.clear_all()
+
+    def callable(input):
+      dist.all_reduce(input, dist.ReduceOp.SUM)
+      return input
+
+    dist.init_process_group("xla", init_method='xla://')
+    device = xm.xla_device()
+    input = torch.tensor([xr.global_ordinal()],
+                         dtype=torch.float,
+                         device=device)
+
+    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
+    f(input)
+    torch_xla.sync()
+    if not use_dynamo:
+      assert 'xla::AllReduceInPlace' in met.counter_names(
+      ) or 'xla::AllReduce' in met.counter_names()
+    else:
+      assert 'xla::all_reduce' in met.counter_names()
+    return input.cpu()
+
+  @staticmethod
+  def _all_gather_into_tensor(use_dynamo: bool):
+    met.clear_all()
+
+    def callable(output, input):
+      dist.all_gather_into_tensor(output, input, None)
+      return output
+
+    dist.init_process_group("xla", init_method='xla://')
+    device = xm.xla_device()
+    input = torch.tensor([xr.global_ordinal()],
+                         dtype=torch.float,
+                         device=device)
+    output_tensor = torch.empty((1, xr.world_size()), device=device)
+    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
+    f(output_tensor, input)
+    torch_xla.sync()
+    if not use_dynamo:
+      assert 'xla::AllGather' in met.counter_names(
+      ) or 'xla::AllGatherOut' in met.counter_names()
+    else:
+      assert 'xla::all_gather_into_tensor' in met.counter_names()
+    return output_tensor.cpu()
+
+  @staticmethod
+  def _all_gather(use_dynamo: bool):
+    met.clear_all()
+    dist.init_process_group("xla", init_method='xla://')
+    device = xm.xla_device()
+
+    def callable(input):
+      output_tensor = [
+          torch.tensor([0], dtype=torch.float).to(device)
+          for _ in range(xr.world_size())
+      ]
+      dist.all_gather(output_tensor, input, None)
+      return output_tensor
+
+    input = torch.tensor([xr.global_ordinal()],
+                         dtype=torch.float,
+                         device=device)
+
+    f = torch.compile(callable, backend='openxla') if use_dynamo else callable
+    output = f(input)
+    torch_xla.sync()
+    if not use_dynamo:
+      assert 'xla::AllGather' in met.counter_names(
+      ) or 'xla::AllGatherOut' in met.counter_names()
+    else:
+      assert 'xla::all_gather_into_tensor' in met.counter_names()
+    # output is a list of tensors
+    return pytree.tree_map(lambda x: x.cpu(), output)
+
+  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
+  def test_all_reduce(self, use_dynamo):
+    results = pjrt.run_multiprocess(self._all_reduce, use_dynamo=use_dynamo)
+    expected = torch.tensor([sum(range(tpu.num_expected_global_devices()))],
+                            dtype=torch.float)
+    for index, val in results.items():
+      torch.testing.assert_close(val, expected)
+
+  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
+  def test_all_gather_into_tensor(self, use_dynamo):
+    results = pjrt.run_multiprocess(
+        self._all_gather_into_tensor, use_dynamo=use_dynamo)
+    expected = torch.arange(
+        tpu.num_expected_global_devices(), dtype=torch.float).unsqueeze(0)
+    for index, val in results.items():
+      torch.testing.assert_close(val, expected)
+
+  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
+  def test_all_gather(self, use_dynamo):
+    results = pjrt.run_multiprocess(self._all_gather, use_dynamo=use_dynamo)
+    expected = [
+        torch.tensor([i], dtype=torch.float)
+        for i in range(tpu.num_expected_global_devices())
+    ]
+    for index, val in results.items():
+      torch.testing.assert_close(val, expected)
+
+
 if __name__ == '__main__':
   absltest.main()
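A usage note on the assertions above: pjrt.run_multiprocess runs the given function once per device and returns a mapping from each process's global ordinal to that process's return value, which is why every test iterates results.items(). A hedged sketch of the expected shape, assuming four global devices:

# Hypothetical illustration; assumes tpu.num_expected_global_devices() == 4.
results = pjrt.run_multiprocess(
    TestDistCollectiveOpsTpu._all_reduce, use_dynamo=False)
# Every replica holds the same all-reduced sum:
# {0: tensor([6.]), 1: tensor([6.]), 2: tensor([6.]), 3: tensor([6.])}
for ordinal, value in results.items():
  torch.testing.assert_close(value, torch.tensor([6.]))  # sum(range(4)) == 6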
29 changes: 17 additions & 12 deletions torch_xla/_dynamo/dynamo_bridge.py
@@ -723,6 +723,18 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
   return extract_compiled_graph_helper(xla_model, xla_args)


+def _clear_pending_irs_on_args(args_tensor_only, cloned_args):
+  # If an arg in args_tensor_only has pending IRs, an in-place operation has
+  # happened on it. We don't want to execute that operation yet, so we replace
+  # the pending IR with the cloned arg.
+  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
+      args_tensor_only)
+
+  for i, need_update in enumerate(args_need_update_bool):
+    if need_update and isinstance(args_tensor_only[i], torch.Tensor):
+      args_tensor_only[i].copy_(cloned_args[i])
+
+
 def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
                                         all_xla_args_tensor_only):
   # The logic below tries to partition the fx graph based on the fallback ops.
@@ -739,18 +751,8 @@ def partition_fx_graph_for_cpu_fallback(xla_model, xla_args, all_xla_args,
     print('Dynamo fallback ops are ' + str(unsupported_nodes) +
           '. Please open a GitHub issue with the above op lowering requests.')

-  # This logic, needed for supporting in-place operations, is a duplicate of
-  # the one in the main `extract_internal` function above. We need to do this
-  # check for fetching fallback ops as well.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
-      all_xla_args_tensor_only)
-
-  # Again, same logic in the `extract_internal` above to support in-place operations.
-  # TODO (@wonjoo): Make this duplicate code a bit cleaner.
-  for i, need_update in enumerate(args_need_update_bool):
-    if need_update and isinstance(all_xla_args_tensor_only[i], torch.Tensor):
-      all_xla_args_tensor_only[i].copy_(cloned_args[i])
+  # UnsupportedNodesCollector might trigger in-place ops; clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)

   torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))

@@ -775,6 +777,9 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
   partitioned_graph = partitioner.fuse_partitions(partitions)
   InputCollector(partitioned_graph).run(*xla_args)

+  # InputCollector might trigger in-place ops; clear them here.
+  _clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
+
   # compile each submodule and replace it with a call
   for node in partitioned_graph.graph.nodes:
     if node.op == "call_module" and "fused_" in node.name:
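Taken together, the bridge change says: both graph-inspection passes can run in-place ops on the example args while tracing, so each pass is now followed by the same restore step via the new _clear_pending_irs_on_args helper, instead of duplicating the loop inline. A schematic of the pattern, paraphrased from the hunks above (the clone step is not shown in this diff, so its exact form here is an assumption):

# Schematic sketch, not verbatim from dynamo_bridge.py.
cloned_args = [
    arg.clone() if isinstance(arg, torch.Tensor) else arg
    for arg in all_xla_args
]
# Pass 1: UnsupportedNodesCollector probes the graph for fallback ops and may
# execute in-place ops; restore any mutated args from their clones, then drop
# the leftover pending IR.
_clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
# Pass 2: InputCollector replays the partitioned graph and can mutate the args
# the same way, so it gets the same restore step.
InputCollector(partitioned_graph).run(*xla_args)
_clear_pending_irs_on_args(all_xla_args_tensor_only, cloned_args)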