From 7d0f5407c97391a5db1d2f8ec27d11e3d1785e5e Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Wed, 21 Aug 2024 09:13:04 +0900
Subject: [PATCH] feat: lowering replace aten.full_like with aten.full (#3077)

Co-authored-by: Dheeraj Peri
---
 core/runtime/execute_engine.cpp                |  7 +-
 .../dynamo/conversion/aten_ops_converters.py   |  1 +
 .../dynamo/conversion/impl/full.py             | 31 +++++++--
 .../dynamo/lowering/_decompositions.py         | 14 ++--
 .../lowering/passes/_aten_lowering_pass.py     |  2 +
 .../passes/replace_full_like_with_full.py      | 62 +++++++++++++++++
 .../py/dynamo/lowering/test_decompositions.py  | 68 +++++++++++++++++--
 7 files changed, 166 insertions(+), 19 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py

diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index ef5585e723..1ec234e0bc 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -294,7 +294,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   }
 
-  auto current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
+  auto current_device_id = -1;
+  if (inputs.size() > 0) {
+    current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
+  } else if (outputs.size() > 0) {
+    current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
+  }
 
   compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
   if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 43dd0dbde7..61b1f38242 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3861,4 +3861,5 @@ def aten_ops_full(
         name,
         shape=args[0],
         fill_value=args[1],
+        dtype=kwargs.get("dtype", None),
     )
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/full.py b/py/torch_tensorrt/dynamo/conversion/impl/full.py
index d211cef532..34a2af564f 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/full.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/full.py
@@ -2,7 +2,9 @@
 
 import numpy as np
 import tensorrt as trt
+import torch
 from torch.fx.node import Target
+from torch_tensorrt import _enums
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -20,23 +22,39 @@ def full(
     name: str,
     shape: Union[List[int], TRTTensor],
     fill_value: Union[int, float, bool],
+    dtype: Optional[Union[torch.dtype, trt.DataType]] = None,
 ) -> TRTTensor:
+    fill_value_tensor = torch.tensor(fill_value)
+    if dtype is None:
+        output_dtype = _enums.dtype._from(fill_value_tensor.dtype)
+    else:
+        output_dtype = _enums.dtype._from(dtype)
     # in static shape scenario, shape is a list of int
     if isinstance(shape, List):
         # in static shape scenario, shape is a list of int
         if all(isinstance(dim, int) for dim in shape):
-            return np.full(shape, fill_value)
+            output_np_dtype = output_dtype.try_to(np.dtype, use_default=True)
+            return np.full(shape, fill_value, dtype=output_np_dtype)
         else:
             shape = impl.cat.cat(
                 ctx, target, source_ir, name + "_concat_shape", shape, 0
             )
-    # in dynamic shape scenario, shape is a shap tensor
+    # in dynamic shape scenario, shape is a shape tensor
     # use IFillLayer to fill the shape tensor with LINSPACE value
-    layer = ctx.net.add_fill(shape.shape, trt.FillOperation.LINSPACE, shape.dtype)
+    layer = ctx.net.add_fill(
+        shape.shape, trt.FillOperation.LINSPACE, trt.DataType.INT32
+    )
     layer.set_input(0, shape)
-    layer.set_input(1, get_trt_tensor(ctx, 0, name + "_start", min_rank=0))
-    delta = get_trt_tensor(ctx, 1, name + "_delta")
+    layer.set_input(
+        1, get_trt_tensor(ctx, 0, name + "_start", dtype=trt.DataType.INT32, min_rank=0)
+    )
+    delta = get_trt_tensor(
+        ctx,
+        1,
+        name + "_delta",
+        dtype=trt.DataType.INT32,
+    )
     input = []
     for _ in range(shape.shape[0]):
         input.append(delta)
@@ -46,7 +64,8 @@ def full(
 
     # fill the output tensor with the actual fill_value
     output = impl.elementwise.mul(ctx, target, source_ir, name + "_mul", output, 0)
-    if isinstance(fill_value, (int, float)):
+    # https://stackoverflow.com/questions/37888620/comparing-boolean-and-int-using-isinstance
+    if type(fill_value) in (int, float):
         if isinstance(fill_value, float):
             output = cast_trt_tensor(
                 ctx, output, trt.float32, name + "_casted", target, source_ir
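For reference, the dynamic-shape branch above leans on a simple identity: an IFillLayer in LINSPACE mode materializes a tensor of the runtime shape as a 0, 1, 2, ... ramp, and the converter then multiplies that ramp by zero and adds the real fill value (the corresponding add sits outside the hunk shown). A minimal NumPy sketch of that identity; the shape and value here are illustrative stand-ins, not the converter itself:

    import numpy as np

    # Stand-ins for the runtime shape tensor and the requested fill value.
    runtime_shape = (2, 3)
    fill_value = 7.5

    # LINSPACE fill: start + i * delta, i.e. an arange-like ramp reshaped
    # to the runtime shape (start=0, delta=1, as in the converter above).
    start, delta = 0, 1
    ramp = (start + delta * np.arange(np.prod(runtime_shape))).reshape(runtime_shape)

    # Zero the ramp, then add the fill value: equivalent to np.full.
    result = ramp * 0 + fill_value
    assert np.array_equal(result, np.full(runtime_shape, fill_value))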
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 2729e38ff5..378d407416 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -168,7 +168,7 @@ def var_decomposition(
 @register_torch_trt_decomposition(
     torch.ops.aten.empty_permuted.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
-def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
+def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:  # type: ignore
     empty_size = args[0]
     empty_permute = args[1]
     perm = [0] * len(empty_size)
@@ -188,7 +188,7 @@ def slice_scatter_decomposition(
     start: Optional[int] = None,
     end: Optional[int] = None,
     step: Optional[int] = None,
-):
+) -> torch.Tensor:
     dim_size = input_tensor.shape[dim]
     start = get_positive_dim(start, input_tensor.shape[dim])
     if end is None:
@@ -197,6 +197,11 @@ def slice_scatter_decomposition(
     if step is None:
         step = 1
 
+    # Ensure start, end, and step are all integers
+    assert isinstance(start, int), "start must be an integer"
+    assert isinstance(end, int), "end must be an integer"
+    assert isinstance(step, int), "step must be an integer"
+
     src_dim = src_tensor.shape
     # step == 0 is not a valid torch case
     # also src_dim should be equal to slice dimension
@@ -233,7 +238,7 @@ def select_scatter_decomposition(
 @register_torch_trt_decomposition(
     torch.ops.aten.empty_strided.default, registry=TORCH_TRT_DECOMPOSITIONS
 )
-def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:
+def empty_strided_decomposition(*args, **kwargs) -> torch.Tensor:  # type: ignore
     empty_size = args[0]
     empty_stride = args[1]
     return torch.as_strided(
@@ -256,8 +261,7 @@ def scatter_add_decomposition(
     src_shape = list(src_tensor.shape)
     src_dim = src_shape[dim]
     for i in range(0, src_dim):
-        to_scatter_tensor = torch.zeros_like(input_tensor)
-
+        to_scatter_tensor = torch.zeros(input_tensor.shape, dtype=input_tensor.dtype)
         # index and src slice
         src_slice = torch.select(src_tensor, dim, i)
         index_slice = torch.select(index, dim, i)
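Context for the zeros_like change above: torch.zeros_like traces to aten.full_like, which is exactly the op this patch now rewrites away, so the decomposition builds its per-slice accumulator with torch.zeros instead. A self-contained sketch of the per-slice accumulation idea behind scatter_add_decomposition (plain CPU PyTorch with illustrative values; a sketch of the idea, not the decomposition itself):

    import torch

    input_tensor = torch.ones(3, 4)
    index = torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]])
    src = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
    dim = 0

    result = input_tensor.clone()
    for i in range(src.shape[dim]):
        # Fresh zero accumulator per slice, built with torch.zeros so no
        # aten.full_like node appears when this is traced.
        to_scatter = torch.zeros(input_tensor.shape, dtype=input_tensor.dtype)
        # Within a single slice the scatter targets are all distinct, so a
        # plain scatter is collision-free; summing over slices restores the
        # accumulation that scatter_add performs.
        to_scatter = to_scatter.scatter(
            dim,
            index.select(dim, i).unsqueeze(dim),
            src.select(dim, i).unsqueeze(dim),
        )
        result = result + to_scatter

    assert torch.equal(result, input_tensor.scatter_add(dim, index, src))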
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
index a5a1bc5818..dc76ca8036 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -11,6 +11,7 @@
 from .remove_detach import remove_detach
 from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
 from .repair_input_as_output import repair_input_as_output
+from .replace_full_like_with_full import replace_full_like_with_full
 from .replace_max_pool_with_indices import replace_max_pool_with_indices
 from .view_to_reshape import view_to_reshape
 
@@ -23,6 +24,7 @@
         lower_linear,
         fuse_prims_broadcast,
         replace_max_pool_with_indices,
+        replace_full_like_with_full,
         view_to_reshape,
     ]
 )
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py
new file mode 100644
index 0000000000..5c6f0028c0
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py
@@ -0,0 +1,62 @@
+import logging
+
+import torch
+import torch.fx
+from torch_tensorrt.dynamo._defaults import default_device
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.utils import to_torch_device
+
+logger = logging.getLogger(__name__)
+
+
+def replace_full_like_with_full(
+    gm: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """Replace full_like nodes with equivalent full nodes"""
+    modified_graph = False
+
+    for node in gm.graph.nodes:
+        if node.target == torch.ops.aten.full_like.default:
+            modified_graph = True
+
+            # Extract arguments from full_like
+            input_tensor = node.args[0]
+            fill_value = node.args[1]
+            input_dtype = None
+            input_shape = None
+            input_device = to_torch_device(default_device())
+            if "val" in input_tensor.meta:
+                input_dtype = input_tensor.meta["val"].dtype
+                input_device = input_tensor.meta["val"].device
+                input_shape = list(input_tensor.meta["val"].shape)
+            elif "tensor_meta" in input_tensor.meta:
+                input_dtype = input_tensor.meta["tensor_meta"].dtype
+                input_shape = list(input_tensor.meta["tensor_meta"].shape)
+
+            # torch.full has no memory_format argument, so drop it and set
+            # the device and dtype from the input tensor's metadata instead.
+            new_kwargs = {}
+            for key, val in node.kwargs.items():
+                if key != "memory_format":
+                    new_kwargs[key] = val
+            new_kwargs["device"] = input_device
+            new_kwargs["dtype"] = input_dtype
+
+            # Replace full_like with full, using the shape as a list
+            input_nodes = (input_shape, fill_value)
+            with gm.graph.inserting_after(node):
+                full_node = gm.graph.call_function(
+                    torch.ops.aten.full.default,
+                    args=input_nodes,
+                    kwargs=new_kwargs,
+                )
+                full_node.meta = node.meta
+
+            node.replace_all_uses_with(full_node)
+            gm.graph.erase_node(node)
+
+    if modified_graph:
+        gm = clean_up_graph_after_modifications(gm)
+
+    return gm
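A rough smoke test for the pass above, assuming this patch is applied. make_fx is used with fake tracing so that the "val" metadata the pass reads is populated; this is a sketch, not part of the shipped test suite:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch_tensorrt.dynamo.lowering.passes.replace_full_like_with_full import (
        replace_full_like_with_full,
    )

    def fn(x):
        return torch.full_like(x, 2.0) + x

    x = torch.randn(3, 3)
    gm = make_fx(fn, tracing_mode="fake")(x)
    expected = gm(x)

    # Run the pass, then confirm the node was rewritten and numerics held.
    gm = replace_full_like_with_full(gm)
    targets = [n.target for n in gm.graph.nodes if n.op == "call_function"]
    assert torch.ops.aten.full_like.default not in targets
    assert torch.ops.aten.full.default in targets
    assert torch.equal(gm(x), expected)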
diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py
index a1416c00db..74ac6cde62 100644
--- a/tests/py/dynamo/lowering/test_decompositions.py
+++ b/tests/py/dynamo/lowering/test_decompositions.py
@@ -421,6 +421,66 @@ def forward(self, x):
             f"MaxPool3d TRT outputs don't match with the original model.",
         )
 
+    def test_lowering_full_like_module(self):
+        class FullLike(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x):
+                y = torch.full_like(x, 2.0)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.full.default}
+        unexpected_ops = {torch.ops.aten.full_like.default}
+
+        inputs = [torch.randn(3, 3, dtype=torch.float32).cuda()]
+
+        fx_graph = torch.fx.symbolic_trace(FullLike())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"FullLike TRT outputs don't match with the original model.",
+        )
+
     def test_lowering_empty_like_module(self):
         class emptyLike(torch.nn.Module):
             def __init__(self, *args, **kwargs) -> None:
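The numeric half of the test above reduces to a one-line equivalence that is easy to check eagerly: full_like(x, v) must match full(list(x.shape), v) once dtype and device are carried over (memory_format, which aten.full does not accept, is dropped by the pass). An illustrative CPU check, separate from the CUDA test itself:

    import torch

    x = torch.randn(3, 3, dtype=torch.float32)
    like = torch.ops.aten.full_like.default(x, 2.0)
    full = torch.ops.aten.full.default(
        list(x.shape), 2.0, dtype=x.dtype, device=x.device
    )
    assert torch.equal(like, full)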
@@ -976,7 +1036,7 @@ def forward(self, input):
             0,
             torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
             torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32).cuda(),
-            {torch.ops.aten.add.Tensor, torch.ops.aten.scatter.src},
+            {torch.ops.aten.add.Tensor},
         ),
         (
             "scatter_add_one_dim_indexOne_constant",
             0,
             torch.tensor([[0, 1, 2, 0]]).cuda(),
             torch.tensor([[1, 2, 3, 1]], dtype=torch.int32).cuda(),
             {
                 torch.ops.aten.add.Tensor,
-                torch.ops.aten.scatter.src,
-                torch.ops.aten.full_like.default,
             },
         ),
         (
             "scatter_add_one_dim_indexTwo_constant",
             0,
             torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
             torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32).cuda(),
             {
                 torch.ops.aten.add.Tensor,
-                torch.ops.aten.scatter.src,
-                torch.ops.aten.full_like.default,
             },
         ),
         (
             1,
             torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]).cuda(),
             torch.tensor(
                 [[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32
             ).cuda(),
             {
                 torch.ops.aten.add.Tensor,
-                torch.ops.aten.scatter.src,
-                torch.ops.aten.full_like.default,
             },
         ),
     ]
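The parameterized scatter_add cases above now track a smaller op set after lowering; in particular, with the torch.zeros change in _decompositions.py, aten.full_like no longer shows up in the decomposed graph. For orientation, the eager ground truth for the first case, assuming a hypothetical zero-initialized input of matching shape (the actual test input is defined elsewhere in the file):

    import torch

    # Hypothetical zero input; dim=0, index, and src mirror the first case.
    input_tensor = torch.zeros(3, 4, dtype=torch.int32)
    index = torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]])
    src = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32)

    out = input_tensor.scatter_add(0, index, src)
    # out[index[i][j]][j] += src[i][j] along dim 0, giving:
    # tensor([[1, 0, 0, 4],
    #         [5, 2, 7, 8],
    #         [0, 6, 3, 0]], dtype=torch.int32)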