From c4c7f451c28f1830843db560b9e0be72124d78ff Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 17 Sep 2024 16:00:38 -0700
Subject: [PATCH] feat: cherry pick of Refit fixes (#3166)

---
 py/torch_tensorrt/dynamo/_compiler.py         | 10 ++-
 py/torch_tensorrt/dynamo/_refit.py            |  3 +-
 .../dynamo/conversion/_ConverterRegistry.py   | 40 +++++++----
 .../dynamo/conversion/_TRTInterpreter.py      |  8 ++-
 .../dynamo/conversion/aten_ops_converters.py  | 66 +++++++++++++------
 .../dynamo/conversion/converter_utils.py      |  5 +-
 .../dynamo/conversion/ops_evaluators.py       | 11 ++--
 .../dynamo/conversion/prims_ops_converters.py |  5 +-
 py/torch_tensorrt/dynamo/utils.py             |  4 +-
 tests/py/dynamo/conversion/harness.py         | 14 +++-
 .../py/dynamo/conversion/test_cumsum_aten.py  |  4 ++
 .../conversion/test_embedding_bag_aten.py     |  4 ++
 tests/py/dynamo/models/test_model_refit.py    | 56 +++++++++++++++-
 13 files changed, 174 insertions(+), 56 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 96e5f313ae..97aa2ec443 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -314,11 +314,9 @@ def compile_module(
     dryrun_tracker = DryRunTracker()
     if sample_kwarg_inputs is None:
         sample_kwarg_inputs = {}
 
-    # Assume converters support dynamic shapes and disable validation
-    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
-    # Set torch-executed ops
-    CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
+    # Configure user compilation settings to converters.
+    CONVERTERS.set_compilation_settings(settings)
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
@@ -670,8 +668,8 @@ def convert_exported_program_to_serialized_trt_engine(
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
 
-    # Assume converters support dynamic shapes and disable validation
-    CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
+    # Configure user compilation settings to converters.
+    CONVERTERS.set_compilation_settings(settings)
 
     try:
         interpreter_result = interpret_module_to_result(
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 47faf90031..359dc0b3ff 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -6,6 +6,7 @@
 from typing import Any, List, Optional, Sequence, Tuple
 
 import numpy as np
+import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
 from torch_tensorrt._enums import dtype
@@ -42,8 +43,6 @@
 )
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
-
 logger = logging.getLogger(__name__)
 
 
diff --git a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
index e4ea91c196..4801834e56 100644
--- a/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
+++ b/py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -23,6 +23,7 @@
 from torch import SymBool, SymFloat, SymInt
 from torch._ops import OpOverloadPacket
 from torch.fx.node import Argument, Node, Target, _get_qualified_name
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS
 
@@ -82,7 +83,9 @@ class ConverterSupport:
     """
 
     converter_implementation: ConverterImplSignature
-    capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
+    capability_validator: Callable[[Node, CompilationSettings], bool] = field(
+        default=lambda node, compilation_settings: True
+    )
     supports_dynamic_shapes: bool = False
 
 
@@ -112,10 +115,10 @@ def has_dynamic_shapes_in_args(
 def has_static_shapes_in_args(
     arg_positions_to_check: Optional[List[int]] = None,
-) -> Callable[[torch.fx.Node], bool]:
+) -> Callable[[torch.fx.Node, CompilationSettings], bool]:
     """Returns True if a node has static inputs in node.args at specified positions"""
-    _has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
-        node, arg_positions_to_check
+    _has_static_shapes = lambda node, compilation_settings, arg_positions_to_check: not _has_dynamic_shapes(
+        node, compilation_settings, arg_positions_to_check
     )
     return functools.partial(
         _has_static_shapes, arg_positions_to_check=arg_positions_to_check
     )
@@ -123,7 +126,9 @@ def has_static_shapes_in_args(
 
 
 def _has_dynamic_shapes(
-    node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
+    node: torch.fx.Node,
+    compilation_settings: CompilationSettings = None,
+    arg_positions_to_check: Optional[List[int]] = None,
 ) -> bool:
     # Validate that none of the inputs to the node have Dynamic shapes
     assert isinstance(
@@ -188,7 +193,7 @@ def dynamo_tensorrt_converter(
     key: Target,
     *,
     enabled: bool = True,
-    capability_validator: Optional[Callable[[Node], bool]] = None,
+    capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
     supports_dynamic_shapes: bool = False,
 ) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
@@ -297,7 +302,6 @@ def __init__(
         ],
         registry_names: Optional[Sequence[str]] = None,
         registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
-        assume_dynamic_shape_support: bool = False,
     ):
         # Copy reference to each dictionary object into attribute list
         self.registries = list(registries)
@@ -318,12 +322,16 @@ def __init__(
                 CallingConvention.CTX for _ in range(len(self.registries))
             ]
 
+        self.compilation_settings: CompilationSettings = None
         self.disallowed_targets: Collection[Target] = set()
-        self.assume_dynamic_shape_support = assume_dynamic_shape_support
         self.validate_invariants()
 
-    def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
-        self.assume_dynamic_shape_support = assume_dynamic_shape_support
+    def set_compilation_settings(
+        self, compilation_settings: CompilationSettings
+    ) -> None:
+        self.compilation_settings = compilation_settings
+        # set torch executed ops as disallowed targets
+        self.set_disallowed_targets(compilation_settings.torch_executed_ops)
 
     def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
         self.disallowed_targets = torch_executed_ops
@@ -412,7 +420,11 @@ def __getitem__(
 
         self.validate_invariants()
         key = node.target
-
+        assume_dynamic_shape_support = False
+        if self.compilation_settings:
+            assume_dynamic_shape_support = (
+                self.compilation_settings.assume_dynamic_shape_support
+            )
         if (
             key in self.disallowed_targets
             or self.qualified_name_or_str(key) in self.disallowed_targets
@@ -436,8 +448,10 @@ def __getitem__(
                         # 2) Assume dynamic_shape support is True
                         # 3) Node only has static shaped inputs
                         # 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
-                        if candidate.capability_validator(node) and (
-                            self.assume_dynamic_shape_support
+                        if candidate.capability_validator(
+                            node, self.compilation_settings
+                        ) and (
+                            assume_dynamic_shape_support
                             or not node_has_dynamic_shapes(node)
                             or candidate.supports_dynamic_shapes
                         ):
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index ff35bf39d7..aab4d521f8 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
 )
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -89,6 +89,11 @@ def __init__(
             self.builder.create_network(flag), compilation_settings
         )
 
+        self.compilation_settings = compilation_settings
+        if not CONVERTERS.compilation_settings:
+            # Configure user compilation settings to converters.
+            CONVERTERS.set_compilation_settings(compilation_settings)
+
         assert TRTInterpreter._all_precisions_supported(
             compilation_settings.enabled_precisions
         ), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})"
@@ -117,7 +122,6 @@ def __init__(
         self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
             dict()
         )
-        self.compilation_settings = compilation_settings
 
         # Data types for TRT Module output Tensors
         self.output_dtypes = (
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index a757cf023e..60a48d98e3 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -48,7 +49,7 @@ def get_ir(target: Target) -> SourceIR:
     return SourceIR.UNKNOWN
 
 
-def one_user_validator(node: Node) -> bool:
+def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # Validate only one user, which is a getitem node that accesses the first element in the list
     return (
         len(node.users) == 1
@@ -270,7 +271,11 @@ def aten_ops_embedding(
     )
 
 
-def embedding_bag_validator(node: Node) -> bool:
+def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool:
+    # Embedding bag op is not refitable
+    if settings.make_refittable:
+        return False
+
     if not one_user_validator(node):
         return False
     meta = node.args[1].meta
@@ -416,7 +421,7 @@ def aten_ops_symsize_int(
     return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])
 
 
-def index_dtype_validator(node: Node) -> bool:
+def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool:
     index = node.args[1]
     for ind in index:
         if ind is not None:
@@ -837,7 +842,7 @@ def aten_ops_select(
     )
 
 
-def index_put_validator(node: Node) -> bool:
+def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool:
     if args_bounds_check(node.args, 3, False):  # Check if accumulate is valid
         _LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
         accumulate_valid = False
@@ -924,7 +929,18 @@ def aten_ops_slice(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
+def refit_validator(node: Node, settings: CompilationSettings = None) -> bool:
+    # cumsum op is not refitable
+    if settings and settings.make_refittable:
+        return False
+    return True
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten.cumsum.default,
+    capability_validator=refit_validator,
+    supports_dynamic_shapes=True,
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -970,7 +986,7 @@ def aten_ops_tile(
     )
 
 
-def zero_output_validator(node: Node) -> bool:
+def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool:
     if 0 in node.args[1]:
         _LOGGER.debug(
             f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
@@ -1027,7 +1043,9 @@ def aten_ops_permute(
     )
 
 
-def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
+def to_copy_dtype_validator(
+    placeholder_only: bool, settings: CompilationSettings = None
+) -> Callable[[Node, CompilationSettings], bool]:
     """Return validator for to_copy node with placeholder restrictions"""
 
     def validate_dtype(to_copy_node: Node) -> bool:
@@ -1059,7 +1077,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
             )
             return False
 
-    def validator(to_copy_node: Node) -> bool:
+    def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool:
        """Returns true if the to_copy node can be converted to TRT
        and the placeholder restriction is satisfied
        """
@@ -1074,7 +1092,9 @@ def validator(to_copy_node: Node) -> bool:
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.clone.default,
-    capability_validator=lambda node: not is_only_operator_on_placeholder(node),
+    capability_validator=lambda node, settings: not is_only_operator_on_placeholder(
+        node, settings
+    ),
     supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(
@@ -2128,7 +2148,7 @@ def aten_ops_logical_xor(
     )
 
 
-def bitwise_type_validator(node: Node) -> bool:
+def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool:
     supported_type = [torch.bool, bool]
 
     tensor_targets = [
@@ -2271,7 +2291,9 @@ def aten_ops_bitwise_xor(
     )
 
 
-def bitwise_not_type_validator(node: Node) -> bool:
+def bitwise_not_type_validator(
+    node: Node, settings: CompilationSettings = None
+) -> bool:
     val = node.args[0]
     val_meta = val.meta.get("tensor_meta")
 
@@ -2453,7 +2475,7 @@ def aten_ops_le(
     )
 
 
-def conv_param_validator(conv_node: Node) -> bool:
+def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool:
     return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
 
 
@@ -2549,7 +2571,9 @@ def aten_ops_cdist_forward(
     )
 
 
-def avg_pool_param_validator(pool_node: Node) -> bool:
+def avg_pool_param_validator(
+    pool_node: Node, settings: CompilationSettings = None
+) -> bool:
     ceil_mode = args_bounds_check(pool_node.args, 4, False)
     divisor_override = args_bounds_check(pool_node.args, 6)
 
@@ -2665,12 +2689,12 @@ def aten_ops_adaptive_avg_poolNd(
     )
 
 
-def topk_validator(node: Node) -> bool:
+def topk_validator(node: Node, settings: CompilationSettings = None) -> bool:
     k = node.args[1]
     return topk_sort_validator(k)
 
 
-def sort_validator(node: Node) -> bool:
+def sort_validator(node: Node, settings: CompilationSettings = None) -> bool:
     meta_data = node.args[0].meta.get("tensor_meta")
     if meta_data is None:
         return False
@@ -2692,7 +2716,9 @@ def topk_sort_validator(k: int) -> bool:
     return True
 
 
-def max_pool_param_validator(pool_node: Node) -> bool:
+def max_pool_param_validator(
+    pool_node: Node, settings: CompilationSettings = None
+) -> bool:
     dilation = args_bounds_check(pool_node.args, 4, 1)
     ceil_mode = args_bounds_check(pool_node.args, 5, False)
 
@@ -2746,7 +2772,7 @@ def aten_ops_max_pool(
     )
 
 
-def attention_validator(node: Node) -> bool:
+def attention_validator(node: Node, settings: CompilationSettings = None) -> bool:
     # Currently, `attn_mask` is not supported
     return args_bounds_check(node.args, 3) is None
 
@@ -3637,7 +3663,7 @@ def aten_ops_flip(
     )
 
 
-def zero_diag_size_validator(node: Node) -> bool:
+def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool:
     meta = node.args[0].meta.get("tensor_meta")
     if meta:
         input_shape = meta.shape
@@ -3765,7 +3791,9 @@ def aten_ops_index_select(
     )
 
 
-def dropout_inference_validator(node: Node) -> bool:
+def dropout_inference_validator(
+    node: Node, settings: CompilationSettings = None
+) -> bool:
     train_mode = args_bounds_check(node.args, 2, None)
     if train_mode is False:
         return True
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 70135f86d3..39ace2a873 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -10,6 +10,7 @@
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -111,7 +112,9 @@ def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
     return metadata_string
 
 
-def is_only_operator_on_placeholder(node: torch.fx.Node) -> bool:
+def is_only_operator_on_placeholder(
+    node: torch.fx.Node, settings: CompilationSettings = None
+) -> bool:
     """Detects whether a call_function node is the only operator on a placeholder"""
     # Returns true if the node operates on a placeholder and is a direct output
     return (
diff --git a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
index 5eeb2db661..f320505c94 100644
--- a/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
+++ b/py/torch_tensorrt/dynamo/conversion/ops_evaluators.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     ConverterRegistry,
@@ -18,7 +19,7 @@
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-def getitem_validator(getitem_node: Node) -> bool:
+def getitem_validator(getitem_node: Node, settings: CompilationSettings = None) -> bool:
     from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS
 
     # Getitem nodes can only be converted if their parent node also can
@@ -45,7 +46,7 @@ def generic_evaluator(
     return target(*args)
 
 
-def rand_validator(rand_node: Node) -> bool:
+def rand_validator(rand_node: Node, settings: CompilationSettings = None) -> bool:
     dtype = rand_node.kwargs.get("dtype", None)
     layout = rand_node.kwargs.get("layout", None)
     if dtype is not None:
@@ -85,7 +86,9 @@ def aten_ops_randn(
     return np.random.randn(*args[0])
 
 
-def randperm_validator(randperm_node: Node) -> bool:
+def randperm_validator(
+    randperm_node: Node, settings: CompilationSettings = None
+) -> bool:
     dtype = randperm_node.kwargs.get("dtype", None)
     layout = randperm_node.kwargs.get("layout", None)
     input = randperm_node.args[0]
@@ -116,7 +119,7 @@ def aten_ops_randperm(
     return np.random.permutation(args[0])
 
 
-def empty_validator(empty_node: Node) -> bool:
+def empty_validator(empty_node: Node, settings: CompilationSettings = None) -> bool:
     device = empty_node.kwargs.get("device", None)
     if device is not None:
         _LOGGER.debug(f"Currently we don't support specifying device, got {device}.")
diff --git a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
index 9548dc287a..923ca9be6c 100644
--- a/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/prims_ops_converters.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -15,7 +16,9 @@
 
 
 # TODO: expand the scope of this converter with aten.expand implementation
-def broadcast_checker(broadcast_node: torch.fx.Node) -> bool:
+def broadcast_checker(
+    broadcast_node: torch.fx.Node, settings: CompilationSettings = None
+) -> bool:
     # The current implementation of broadcast_in_dim can only handle unsqueeze
     return all(
         broadcast_node.args[1][i] == 1
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 16e22aface..ee11e597a1 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -195,9 +195,7 @@ def get_model_device(module: torch.fx.GraphModule) -> torch.device:
 
     if device is None:
         device = to_torch_device(default_device())
-        logger.warning(
-            "Could not detect the device on which the model exists. Assuming the model is on CPU"
-        )
+
     return device
 
 
diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py
index f53bdf5d59..632b73e2f3 100644
--- a/tests/py/dynamo/conversion/harness.py
+++ b/tests/py/dynamo/conversion/harness.py
@@ -23,7 +23,7 @@
     pre_export_lowering,
 )
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
-from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_torch_inputs
+from torch_tensorrt.dynamo.utils import ATOL, RTOL, get_model_device, get_torch_inputs
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -225,7 +225,6 @@ def generate_graph(
         propagate_shapes: bool = False,
     ):
         mod = mod.eval()
-        torch_inputs = get_torch_inputs(original_inputs, _defaults.DEVICE)
         if use_dynamo_tracer:
             exported_program = torch_tensorrt.dynamo.trace(mod, tuple(original_inputs))
             exported_program = pre_export_lowering(exported_program)
@@ -242,6 +241,8 @@ def generate_graph(
         if propagate_shapes:
             # TODO: This is currently being used to test embedding_bag_aten due to https://github.com/pytorch/TensorRT/issues/2843
             try:
+                device = get_model_device(fx_module)
+                torch_inputs = get_torch_inputs(original_inputs, device)
                 ShapeProp(fx_module).propagate(*torch_inputs)
             except (RuntimeError, AssertionError):
                 _LOGGER.warning(
@@ -262,6 +263,7 @@ def run_test(
         enable_passes=False,
         propagate_shapes=False,
         int32_reqd=False,
+        make_refittable=False,
     ):
         mod = self.generate_graph(
             mod,
@@ -277,6 +279,7 @@ def run_test(
             enabled_precisions={dtype._from(precision)},
             truncate_double=True,
             debug=True,
+            make_refittable=make_refittable,
         )
 
         num_inputs = len(inputs)
@@ -345,6 +348,7 @@ def run_test_compare_tensor_attributes_only(
         output_dtypes=None,
         use_dynamo_tracer=False,
        enable_passes=False,
+        make_refittable=False,
     ):
         mod = self.generate_graph(
             mod,
@@ -358,6 +362,7 @@ def run_test_compare_tensor_attributes_only(
             enabled_precisions={dtype._from(precision)},
             truncate_double=True,
             debug=True,
+            make_refittable=make_refittable,
         )
 
         interp = TRTInterpreter(
@@ -383,6 +388,7 @@ def run_test_with_dynamic_shape(
         pyt_inputs=None,
         propagate_shapes=False,
         check_dtype=True,
+        make_refittable=False,
     ):
         mod = self.generate_graph(
             mod,
@@ -394,7 +400,9 @@ def run_test_with_dynamic_shape(
 
         # Previous instance of the interpreter auto-casted 64-bit inputs
         # We replicate this behavior here
-        compilation_settings = CompilationSettings(truncate_double=True)
+        compilation_settings = CompilationSettings(
+            truncate_double=True, make_refittable=make_refittable
+        )
 
         if check_dtype:
             output_dtypes = infer_module_output_dtypes(
diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py
index 4143401bd4..1c32be6dd6 100644
--- a/tests/py/dynamo/conversion/test_cumsum_aten.py
+++ b/tests/py/dynamo/conversion/test_cumsum_aten.py
@@ -24,6 +24,7 @@ def forward(self, x):
         self.run_test(
             Cumsum(),
             inputs,
+            make_refittable=False,
         )
 
     @parameterized.expand(
@@ -43,6 +44,7 @@ def forward(self, x):
         self.run_test(
             Cumsum(),
             inputs,
+            make_refittable=False,
         )
 
     @parameterized.expand(
@@ -63,6 +65,7 @@ def forward(self, x):
         self.run_test(
             Cumsum(),
             inputs,
+            make_refittable=False,
         )
 
     @parameterized.expand(
@@ -92,6 +95,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             Cumsum(),
             inputs,
+            make_refittable=False,
         )
 
 
diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py
index 3fef3d70cf..6543ac2306 100644
--- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py
+++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py
@@ -148,6 +148,7 @@ def forward(self, weight, indices):
             precision=weight.dtype,
             enable_passes=True,
             propagate_shapes=True,
+            make_refittable=False,
         )
 
     @parameterized.expand(
@@ -345,6 +346,7 @@ def forward(self, weight, indices, offsets):
             precision=weight.dtype,
             enable_passes=True,
             propagate_shapes=True,
+            make_refittable=False,
         )
 
     @parameterized.expand(
@@ -409,6 +411,7 @@ def forward(self, weight, indices, offsets):
             precision=weight.dtype,
             enable_passes=True,
             propagate_shapes=True,
+            make_refittable=False,
         )
 
     @parameterized.expand(
@@ -490,6 +493,7 @@ def forward(self, weights, indices, offsets, per_sample_weights=None):
             min_block_size=1,
             cache_built_engines=False,
             reuse_cached_engines=False,
+            make_refittable=False,
         )
         # use the inputs with different shape to inference:
         if per_sample_weights is None:
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 77a0f2809f..0f6fb05914 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 import pytest
+import tensorrt as trt
 import torch
 import torch.nn.functional as F
 import torch_tensorrt as torchtrt
@@ -24,8 +25,6 @@
 from torch_tensorrt.logging import TRT_LOGGER
 from transformers import BertModel
 
-import tensorrt as trt
-
 assertions = unittest.TestCase()
 
 
@@ -730,3 +729,56 @@ def forward(self, x):
 
     # Clean up model env
     torch._dynamo.reset()
+
+
+@pytest.mark.unit
+def test_refit_cumsum_fallback():
+
+    class net(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = nn.Conv2d(3, 12, 3, padding=1)
+            self.fc1 = nn.Linear(12 * 16 * 16, 10)
+
+        def forward(self, x):
+            x = self.conv1(x)
+            x = F.relu(x)
+            x = torch.flatten(x, 1)
+            x = torch.cumsum(self.fc1(x), 1)
+            x = x**2
+            return x
+
+    model = net().eval().to("cuda")
+    inputs = [torch.randn((1, 3, 16, 16)).to("cuda")]
+    model(*inputs)
+    exp_program = torch.export.export(model, tuple(inputs))
+    with torchtrt.logging.debug():
+        trt_gm = torchtrt.dynamo.compile(
+            exp_program,
+            tuple(inputs),
+            enabled_precisions={torch.float},
+            debug=True,
+            min_block_size=1,
+            make_refittable=True,
+        )
+
+    num_pyt_segments = len(
+        [1 for submod in list(trt_gm.named_children()) if "_run_on_gpu" in submod[0]]
+    )
+
+    # Number of pyt segments should be 1 (because of cumsum being non-refitable)
+    assertions.assertTrue(
+        num_pyt_segments == 1,
+        f"test_refit_cumsum_fallback test found {num_pyt_segments} pytorch segments but expected 1",
+    )
+
+    # Check the output
+    pyt_outputs, trt_outputs = exp_program.module()(*inputs), trt_gm(*inputs)
+    for pyt_output, trt_output in zip(pyt_outputs, trt_outputs):
+        assertions.assertTrue(
+            torch.allclose(pyt_output, trt_output, 1e-2, 1e-2),
+            "Refit Result is not correct. Refit failed",
+        )
+    # Clean up model env
+
+    torch._dynamo.reset()
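
For reference, a minimal sketch (not part of the patch) of the validator shape that the hunks above migrate every converter to: capability validators now receive the user CompilationSettings as a second argument, so a converter can decline TRT conversion and fall back to PyTorch when make_refittable is requested. The validator name below is illustrative; the imports mirror the paths used in the diff.

    from torch.fx.node import Node
    from torch_tensorrt.dynamo._settings import CompilationSettings


    def my_refit_aware_validator(node: Node, settings: CompilationSettings = None) -> bool:
        # Decline conversion when the engine must stay refittable, so the op is
        # left in a PyTorch segment -- the same pattern as refit_validator for
        # torch.ops.aten.cumsum.default in aten_ops_converters.py above.
        if settings and settings.make_refittable:
            return False
        return True

Such a function would be passed as capability_validator= to @dynamo_tensorrt_converter(...), exactly as the cumsum registration above does.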