From 6a69c6ad2abbd4e3149fcdaaff83e7bc2abf2fd9 Mon Sep 17 00:00:00 2001
From: George S <113141689+gs-olive@users.noreply.github.com>
Date: Thu, 17 Aug 2023 12:06:10 -0700
Subject: [PATCH] feat: Add `_to_copy`, `operator.get` and `clone` ATen converters (#2161)

---
 .../dynamo/conversion/_TRTInterpreter.py      | 10 ++-
 .../dynamo/conversion/__init__.py             |  1 +
 .../dynamo/conversion/aten_ops_converters.py  | 65 ++++++++++++--
 .../dynamo/conversion/converter_registry.py   |  2 +-
 .../dynamo/conversion/converter_utils.py      | 18 +++-
 .../dynamo/conversion/impl/__init__.py        |  1 +
 .../dynamo/conversion/impl/cast.py            | 43 +++++++++
 .../conversion/impl/elementwise/base.py       | 11 ++-
 .../dynamo/conversion/op_evaluators.py        | 32 +++++++
 py/torch_tensorrt/dynamo/utils.py             | 24 +++++
 tests/py/dynamo/converters/harness.py         | 44 +++++++---
 tests/py/dynamo/converters/test_casts.py      | 87 +++++++++++++++++++
 tests/py/dynamo/converters/test_evaluators.py | 37 ++++++++
 tests/py/dynamo/models/test_models.py         | 17 ++--
 14 files changed, 352 insertions(+), 40 deletions(-)
 create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/cast.py
 create mode 100644 py/torch_tensorrt/dynamo/conversion/op_evaluators.py
 create mode 100644 tests/py/dynamo/converters/test_casts.py
 create mode 100644 tests/py/dynamo/converters/test_evaluators.py

diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 5338f36876..29485a919b 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -27,6 +27,10 @@
 ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")


+class UnsupportedOperatorException(RuntimeError):
+    pass
+
+
 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
@@ -301,7 +305,7 @@ def call_module(
         converter = CONVERTERS.get(self._cur_node)

         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )

@@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
         converter = CONVERTERS.get(self._cur_node)
         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )

@@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         converter = CONVERTERS.get(self._cur_node)

         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )

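Note: the hunks above only change the exception type that is raised, but because `UnsupportedOperatorException` is re-exported from `torch_tensorrt.dynamo.conversion`, callers can now distinguish "no converter registered" from other interpreter failures. A minimal sketch of that pattern, assuming an already-constructed `TRTInterpreter` named `interpreter` (the surrounding setup is hypothetical and not part of this patch):

    from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException

    try:
        result = interpreter.run()
    except UnsupportedOperatorException as exc:
        # A node in the graph has no registered converter; fall back to eager PyTorch
        print(f"Falling back to PyTorch execution: {exc}")
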
diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py
index 4536ff0e7b..9cbfff950e 100644
--- a/py/torch_tensorrt/dynamo/conversion/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,5 @@
 from ._TRTInterpreter import *  # noqa: F403
 from .aten_ops_converters import *  # noqa: F403
 from .conversion import *  # noqa: F403
+from .op_evaluators import *  # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 0ef7266624..75a7782354 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

+import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -12,8 +13,6 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

-import tensorrt as trt
-
 from .converter_registry import dynamo_tensorrt_converter

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -76,13 +75,13 @@ def aten_ops_div(
         kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
     ):
         kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name
+            network, kwargs_new["input"], trt.float32, name, target
         )
     elif isinstance(args[1], TRTTensor) and (
         kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
     ):
         kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name
+            network, kwargs_new["other"], trt.float32, name, target
         )
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
@@ -101,7 +100,7 @@
        )


-def embedding_param_validator(embedding_node: Node):
+def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)

@@ -365,3 +364,59 @@ def aten_ops_permute(
         args[0],
         args[1],
     )
+
+
+def to_copy_dtype_validator(to_copy_node: Node) -> bool:
+    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
+
+    # Validate input node has convertible kwargs
+    if "dtype" in to_copy_node.kwargs:
+        if to_copy_node.kwargs["dtype"] in allowed_casts:
+            return True
+        else:
+            _LOGGER.debug(
+                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+            )
+            return False
+    else:
+        _LOGGER.debug(
+            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+        )
+        return False
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+)
+def aten_ops_to_copy_dtype(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.to_copy(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs["dtype"],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
+def aten_ops_clone(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.clone(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
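Note: the two converters above use the registry's validator hook — `to_copy_dtype_validator` runs before the partitioner commits to a node, so `_to_copy` calls with unsupported dtypes are simply left to PyTorch instead of failing at conversion time. For context, a hypothetical registration following the same pattern would look like this (the op, validator, and body are illustrative only and are not part of this patch):

    import torch
    from torch.fx.node import Node
    from torch_tensorrt.dynamo.conversion.converter_registry import dynamo_tensorrt_converter


    def leaky_relu_validator(node: Node) -> bool:
        # Only claim nodes whose negative_slope is the default; others fall back to PyTorch
        return len(node.args) < 2 or node.args[1] == 0.01


    @dynamo_tensorrt_converter(
        torch.ops.aten.leaky_relu.default, capability_validator=leaky_relu_validator
    )
    def aten_ops_leaky_relu(network, target, args, kwargs, name):
        ...  # would dispatch to an impl.* helper, as the converters above do
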
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py
index 493773bbde..b09bf61418 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -66,7 +66,7 @@ def dynamo_tensorrt_converter(
     enabled: bool = True,
     capability_validator: Optional[Callable[[Node], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
-) -> Callable[[Any], Any]:
+) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
     """Decorator for Dynamo TensorRT Converter

     Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
index 3d32b25f63..44bc8b9445 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,15 +1,19 @@
 import logging
 import re
-from typing import List
+from typing import List, Optional

 import tensorrt as trt
 import torch
+from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

+from .._SourceIR import SourceIR
+from .converter_registry import ConverterRegistry
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -71,6 +75,8 @@ def cast_trt_tensor(
     input_val: TRTTensor,
     dtype: TRTDataType,
     name: str,
+    target: Target = "",
+    source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
     """
     Given a TRT Tensor, convert that Tensor to the specified dtype
@@ -78,17 +84,23 @@
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
-        dtype (TRTDataType): The TRTDataType to cast the input Tensor to
+        dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
         name (str): Name of the calling layer
+        target (Target): Target of calling node
+        source_ir (SourceIR): SourceIR of calling converter
     Returns:
         A TensorRT ITensor which has been casted to the specified dtype
     """
     trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

     if input_val.dtype != trt_dtype:
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
         identity_layer = network.add_identity(input_val)
         identity_layer.set_output_type(0, trt_dtype)
-        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
+        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
         return identity_layer.get_output(0)
     else:
         return input_val
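Note: with the extended signature above, a converter can thread its target and source IR through an intermediate cast so the inserted identity layer is self-describing in the engine. A sketch, assuming `network`, `input_val`, `target`, and `name` are the usual converter arguments already in scope:

    import tensorrt as trt
    from torch_tensorrt.dynamo._SourceIR import SourceIR
    from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor

    fp32_val = cast_trt_tensor(network, input_val, trt.float32, name, target, SourceIR.ATEN)
    # The identity layer name now embeds the source IR and target, roughly:
    # "Cast ITensor x from DataType.INT32 to DataType.FLOAT - [aten_ops.<target>]-[<name>]"
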
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index b402240b84..8f7ab1badc 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@

 from . import (
     activation,
+    cast,
     condition,
     elementwise,
     embedding,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
new file mode 100644
index 0000000000..0c55731169
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -0,0 +1,43 @@
+import logging
+from typing import Optional
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def to_copy(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dtype: TRTDataType,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"to_copy received input {input} that is not a TensorRT ITensor"
+        )
+
+    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+    return casted_tensor
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index c4cc744aa9..9ae7859fdc 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union

+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -15,8 +16,6 @@
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-import tensorrt as trt
-

 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -132,9 +131,13 @@ def convert_binary_elementwise(
     trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

     if trt_promoted_type != lhs_val.dtype:
-        lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
+        lhs_val = cast_trt_tensor(
+            network, lhs_val, trt_promoted_type, name, target, source_ir
+        )
     if trt_promoted_type != rhs_val.dtype:
-        rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
+        rhs_val = cast_trt_tensor(
+            network, rhs_val, trt_promoted_type, name, target, source_ir
+        )

     # Check the limitation in the doc string.
     if network.has_implicit_batch_dimension:
diff --git a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py
new file mode 100644
index 0000000000..a546e34305
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py
@@ -0,0 +1,32 @@
+import logging
+import operator
+from typing import Dict, Sequence, Tuple, Union
+
+from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def getitem_validator(getitem_node: Node) -> bool:
+    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+
+    # Getitem nodes can only be converted if their parent node also can
+    return getitem_node.args[0] in DYNAMO_CONVERTERS
+
+
+# TODO: Subsequent evaluators should be registered here with their own validators
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+def generic_evaluator(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    _LOGGER.debug(
+        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
+    )
+    return target(*args)
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index cec328e84f..398af78788 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -5,9 +5,11 @@
 from typing import Any, Callable, Dict, Optional, Sequence

 import torch
+import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo._defaults import PRECISION

 from packaging import version

@@ -161,6 +163,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     if settings.debug:
         logger.setLevel(logging.DEBUG)

+    # TODO: Remove once Dynamo precisions refactoring is complete
+    if "enabled_precisions" in kwargs:
+        enabled_precisions = kwargs["enabled_precisions"]
+
+        if (
+            torch.float16 in enabled_precisions
+            or torch_tensorrt.dtype.half in enabled_precisions
+        ):
+            settings.precision = torch.float16
+        elif (
+            torch.float32 in enabled_precisions
+            or torch_tensorrt.dtype.float in enabled_precisions
+        ):
+            settings.precision = torch.float32
+        elif len(enabled_precisions) == 0:
+            logger.info(f"No precision specified, defaulting to {PRECISION}")
+            settings.precision = PRECISION
+        else:
+            raise ValueError(
+                f"Precision {enabled_precisions} not supported in the Dynamo Path"
+            )
+
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
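Note: the block added to `parse_dynamo_kwargs` is a stopgap that maps the long-standing `enabled_precisions` argument onto the Dynamo `precision` setting, which is what lets the half-precision model test later in this patch keep using the familiar API. A usage sketch (the module and input shape are placeholders, not part of the patch):

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()                      # placeholder module
    inputs = [torch.randn((1, 3, 224, 224)).half().cuda()]

    # parse_dynamo_kwargs translates enabled_precisions={torch.half} into settings.precision
    trt_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        enabled_precisions={torch.half},
    )
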
diff --git a/tests/py/dynamo/converters/harness.py b/tests/py/dynamo/converters/harness.py
index 5634e37a30..f6ff25fb77 100644
--- a/tests/py/dynamo/converters/harness.py
+++ b/tests/py/dynamo/converters/harness.py
@@ -1,11 +1,17 @@
+import logging
 import time
 import unittest
-import torch
-import logging
 from typing import Callable, List, Optional, Set, Tuple

-from torch.testing._internal.common_utils import TestCase
+import torch
 import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
+from torch.fx.passes.infra.pass_base import PassResult
+from torch.testing._internal.common_utils import TestCase
+from torch_tensorrt import Input
+
+# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
+from torch_tensorrt.dynamo.conversion import TRTInterpreter
+from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
@@ -18,15 +24,8 @@
     replace_transpose_mm_op_with_linear,
     run_const_fold,
 )
-from torch.fx.passes.infra.pass_base import PassResult
 from torch_tensorrt.fx.passes.pass_utils import chain_passes

-# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
-from torch_tensorrt.dynamo.conversion import TRTInterpreter
-from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
-from torch_tensorrt import Input
-
-
 _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -217,6 +216,7 @@ def generate_graph(
         expected_ops: Set[Callable],
         unexpected_ops: Optional[Set[Callable]] = None,
         customized_passes: List[Callable] = None,
+        disable_passes: bool = False,
     ):
         # Torchdynamo+aot proxytensor tracer
         # Below are common passes
@@ -234,6 +234,10 @@ def generate_graph(
         # Combine with customized passes specific to any model
         if customized_passes:
             passes_list.extend(customized_passes)
+
+        if disable_passes:
+            passes_list = []
+
         fx_module, _ = aten_tracer.trace(mod, original_inputs)
         for passes in passes_list:
             pr: PassResult = passes(fx_module)
@@ -261,9 +265,17 @@ def run_test(
         atol=1e-03,
         precision=torch.float,
         check_dtype=True,
+        disable_passes=False,
     ):
         mod.eval()
-        mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
+        mod = self.generate_graph(
+            mod,
+            inputs,
+            expected_ops,
+            unexpected_ops,
+            None,
+            disable_passes=disable_passes,
+        )

         if apply_passes is not None:
             pass_tracer = chain_passes(*apply_passes)
@@ -293,10 +305,18 @@ def run_test_with_dynamic_shape(
         unexpected_ops=None,
         rtol=1e-03,
         atol=1e-03,
+        disable_passes=False,
     ):
         mod.eval()
         inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
-        mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
+        mod = self.generate_graph(
+            mod,
+            inputs,
+            expected_ops,
+            unexpected_ops,
+            None,
+            disable_passes=disable_passes,
+        )

         interp = TRTInterpreter(
             mod,
diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py
new file mode 100644
index 0000000000..3a4fd65610
--- /dev/null
+++ b/tests/py/dynamo/converters/test_casts.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+from harness import DispatchTestCase
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
+
+
+class TestCloneConverter(DispatchTestCase):
+    def test_clone_contiguous(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x, memory_format=torch.contiguous_format)
+                return y + 1
+
+        inputs = [torch.randn((1, 3, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+    def test_clone_regular(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x)
+                return y + 1
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+
+class TestToCopyConverter(DispatchTestCase):
+    def test_to_copy_half(self):
+        class ToCopyHalf(nn.Module):
+            def forward(self, x):
+                y = x.to(dtype=torch.half)
+                return y
+
+        inputs = [torch.rand((1, 3, 10))]
+        self.run_test(
+            ToCopyHalf(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.half,
+            disable_passes=True,
+        )
+
+    def test_to_copy_float(self):
+        class ToCopyFloat(nn.Module):
+            def forward(self, x):
+                y = x.to(dtype=torch.float)
+                return y
+
+        inputs = [torch.rand((1, 3, 10)).half()]
+        self.run_test(
+            ToCopyFloat(),
+            inputs,
+            expected_ops={torch.ops.aten._to_copy.default},
+            precision=torch.float,
+            disable_passes=True,
+        )
+
+    def test_to_copy_unsupported(self):
+        class ToCopy64Bit(nn.Module):
+            def forward(self, x):
+                y = x.to(dtype=torch.int64)
+                return y
+
+        inputs = [torch.randn((1, 3, 10)).int()]
+
+        with self.assertRaises(UnsupportedOperatorException):
+            self.run_test(
+                ToCopy64Bit(),
+                inputs,
+                expected_ops={torch.ops.aten._to_copy.default},
+                disable_passes=True,
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/converters/test_evaluators.py b/tests/py/dynamo/converters/test_evaluators.py
new file mode 100644
index 0000000000..64dd303727
--- /dev/null
+++ b/tests/py/dynamo/converters/test_evaluators.py
@@ -0,0 +1,37 @@
+import operator
+import unittest
+
+import torch
+import torch.nn as nn
+from harness import DispatchTestCase
+from torch.testing._internal.common_utils import run_tests
+
+
+# TODO: Switch this test back to self.run_test once an implementation exists
+# for a converter that returns a list, such as aten.split
+@unittest.skip("Pending aten.split converter. Currently tested by E2E")
+class TestGetItemConverter(DispatchTestCase):
+    def test_getitem(self):
+        class GetItem(nn.Module):
+            def forward(self, x):
+                lis = torch.split(x, 5)
+                b = operator.getitem(lis, 0)
+                c = operator.getitem(lis, 1)
+                d = b + c
+                return d
+
+        inputs = [
+            torch.randn((3, 3, 10)),
+            torch.randn((3, 3, 10)),
+            torch.randn((3, 3, 10)),
+        ]
+        self.run_test(
+            GetItem(),
+            inputs,
+            expected_ops={operator.getitem},
+            disable_passes=True,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index 0fdfcb3fd0..50d7fcbbd9 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -1,18 +1,13 @@
-import torch
-import timm
-import pytest
 import unittest

+import pytest
+import timm
+import torch
 import torch_tensorrt as torchtrt
 import torchvision.models as models
-
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 from transformers import BertModel

-from torch_tensorrt.dynamo.utils import (
-    COSINE_THRESHOLD,
-    cosine_similarity,
-)
-
 assertions = unittest.TestCase()
@@ -32,7 +27,6 @@ def test_resnet18(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }
@@ -143,7 +137,7 @@ def test_bert_base_uncased(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
+        "min_block_size": 15,
         "ir": "torch_compile",
     }
     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -181,7 +175,6 @@ def test_resnet18_half(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }
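Note on the `operator.getitem` evaluator exercised by test_evaluators.py above: multi-output ops such as `torch.split` return a tuple, and indexing that tuple in a traced graph becomes `operator.getitem` call_function nodes, which is exactly what `generic_evaluator` (gated by its parent-node validator) handles. A small FX sketch, independent of TensorRT, that makes the node structure visible:

    import operator

    import torch
    import torch.fx


    def fn(x):
        parts = torch.split(x, 5)      # one node producing a tuple of tensors
        return parts[0] + parts[1]     # each index becomes an operator.getitem node


    graph_module = torch.fx.symbolic_trace(fn)
    getitem_nodes = [n for n in graph_module.graph.nodes if n.target is operator.getitem]
    print(len(getitem_nodes))  # 2 -- the nodes the new evaluator would handle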