fix: Error with aten.view across Tensor memory (#2464)
gs-olive authored Nov 27, 2023
1 parent 867dc7b commit 60bfa04
Showing 3 changed files with 110 additions and 1 deletion.
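For context on the fix: torch.ops.aten.view.default requires its output to alias the input's memory, so it can fail when the underlying tensor's layout is incompatible, whereas torch.ops.aten.reshape.default returns the same result and falls back to a copy when needed. A minimal plain-PyTorch illustration of the difference, independent of this commit:

import torch

t = torch.rand(3, 4).t()  # transpose yields a non-contiguous tensor

try:
    t.view(-1)  # view needs compatible strides and raises here
except RuntimeError as err:
    print(f"view failed: {err}")

print(t.reshape(-1).shape)  # reshape copies instead: torch.Size([12])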
@@ -11,6 +11,7 @@
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
+from .view_to_reshape import view_to_reshape
 
 ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
     [
@@ -21,6 +22,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        view_to_reshape,
     ]
 )

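The pass manager applies the entries of ATEN_LOWERING_PASSES in list order, so view_to_reshape runs as the final lowering pass here. Every pass in the list follows the same calling convention: it takes the GraphModule plus the sample inputs and returns the (possibly modified) GraphModule. A minimal sketch of a pass with that shape; the pass name and body below are hypothetical, not part of this commit:

from typing import Sequence

import torch


def remove_detach_ops(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Hypothetical pass: drop aten.detach nodes, which are no-ops for inference."""
    for node in list(gm.graph.nodes):
        if node.target == torch.ops.aten.detach.default:
            # Reroute consumers to the detach input, then remove the node
            node.replace_all_uses_with(node.args[0])
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm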
41 changes: 41 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
@@ -0,0 +1,41 @@
import logging
from typing import Callable, List, Sequence, Tuple

import torch
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def view_to_reshape(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Replace aten.view with an equivalent implementation which avoids Tensor memory issues"""
    orig, replacement = view_replacement()

    if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
        gm = clean_up_graph_after_modifications(gm)
        logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

    return gm


def view_replacement() -> (
    Tuple[
        torch.fx.GraphModule,
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
    ]
):
    """Constructs the original and replacement functions for view"""

    # Original graph
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)

    # Replacement graph
    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.reshape.default(input, shape)

    return orig, replacement
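torch.fx.subgraph_rewriter.replace_pattern traces both callables, finds every subgraph matching orig (the shape placeholder also matches literal shape arguments, as exercised by this commit's test), swaps in replacement, and returns the list of matches, which is truthy only when something was rewritten and so gates the cleanup above. A standalone sketch of the same mechanism outside Torch-TensorRT:

import torch
from torch.fx import subgraph_rewriter, symbolic_trace


class ViewModule(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.view.default(x, (1, 1, -1))


def pattern(x, shape):
    return torch.ops.aten.view.default(x, shape)


def replacement(x, shape):
    return torch.ops.aten.reshape.default(x, shape)


gm = symbolic_trace(ViewModule())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(f"{len(matches)} view call(s) rewritten")
print(gm.graph)  # the graph now calls aten.reshape.default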
68 changes: 67 additions & 1 deletion tests/py/dynamo/lowering/test_aten_lowering_passes.py
@@ -1,7 +1,8 @@
 import torch
-import torch_tensorrt
 from torch.testing._internal.common_utils import TestCase, run_tests
+
+import torch_tensorrt
 
 from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
 
@@ -375,5 +376,70 @@ def forward(self, input, weight, bias):
torch._dynamo.reset()


class TestLowerViewToReshape(TestCase):
    def test_view_to_reshape(self):
        class ViewToReshape(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.view.default(input, (1, 1, -1))
                return out

        inputs = [
            torch.rand((3, 4, 5, 32)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
        expected_ops = {torch.ops.aten.reshape.default}
        unexpected_ops = {
            torch.ops.aten.view.default,
        }

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg="ViewToReshape TRT outputs don't match the original model.",
        )
        torch._dynamo.reset()


if __name__ == "__main__":
    run_tests()
