diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index a883018c5e..93ad4655b5 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -10,6 +10,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -19,6 +20,7 @@
         lower_efficient_attention,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )
 
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
new file mode 100644
index 0000000000..efc836814f
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -0,0 +1,41 @@
+import logging
+from typing import Callable, List, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def view_to_reshape(
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
+) -> torch.fx.GraphModule:
+    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
+    orig, replacement = view_replacement()
+
+    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
+
+    return gm
+
+
+def view_replacement() -> (
+    Tuple[
+        torch.fx.GraphModule,
+        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+    ]
+):
+    """Constructs the original and replacement functions for view"""
+
+    # Original graph
+    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.view.default(input, shape)
+
+    # Replacement graph
+    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
+        return torch.ops.aten.reshape.default(input, shape)
+
+    return orig, replacement
diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
index 1bbb54192c..edbe93eddd 100644
--- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py
+++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -267,5 +267,70 @@ def forward(self, q, k, v):
         torch._dynamo.reset()
 
 
+class TestLowerViewToReshape(TestCase):
+    def test_view_to_reshape(self):
+        class ViewToReshape(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.view.default(input, (1, 1, -1))
+                return out
+
+        inputs = [
+            torch.rand((3, 4, 5, 32)).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
+        expected_ops = {torch.ops.aten.reshape.default}
+        unexpected_ops = {
+            torch.ops.aten.view.default,
+        }
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEquals(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
+        )
+        torch_model_results = torch.cat(
+            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
+        )
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            msg=f"ViewToReshape TRT outputs don't match with the original model.",
+        )
+        torch._dynamo.reset()
+
+
 if __name__ == "__main__":
     run_tests()
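For context on the "Tensor memory issues" mentioned in the `view_to_reshape` docstring: `aten.view` can only reinterpret memory whose strides are compatible with the requested shape, while `aten.reshape` falls back to a copy when a zero-copy view is impossible. A minimal standalone sketch of that difference (illustrative only, not part of this diff):

```python
import torch

x = torch.arange(6).reshape(2, 3)
y = x.t()  # transposing yields a non-contiguous tensor

try:
    y.view(-1)  # fails: view cannot reinterpret non-contiguous memory
except RuntimeError as err:
    print(f"view failed: {err}")

print(y.reshape(-1))  # succeeds: reshape copies when a view is impossible
```

Rewriting `view` to `reshape` therefore trades a possible extra copy for an op that is valid regardless of the input's contiguity, which is presumably why the pass prefers it.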
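The pass itself is a thin wrapper around `torch.fx.subgraph_rewriter.replace_pattern`. A self-contained sketch of that mechanism on a toy module (names like `ToyModel` are illustrative, not from the codebase; the pattern/replacement pair mirrors `orig`/`replacement` in view_to_reshape.py):

```python
import torch
import torch.fx


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.view.default(x, (1, -1))


# Pattern to match and its replacement, expressed as traceable functions
def pattern(input, shape):
    return torch.ops.aten.view.default(input, shape)


def replacement(input, shape):
    return torch.ops.aten.reshape.default(input, shape)


gm = torch.fx.symbolic_trace(ToyModel())
matches = torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(f"Rewrote {len(matches)} occurrence(s)")
print(gm.graph)  # graph now calls aten.reshape.default instead of aten.view.default
```

`replace_pattern` returns the list of matches it rewrote, which is why the pass can gate `clean_up_graph_after_modifications` on its truthiness and skip the cleanup when nothing changed.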