From af7719c3b14eb315ac3cc4c971e8c59d5a1467e8 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 11 Aug 2023 16:48:00 -0700 Subject: [PATCH] fix: Add generic evaluator function --- .../dynamo/conversion/__init__.py | 1 + .../dynamo/conversion/aten_ops_converters.py | 21 +--------- .../dynamo/conversion/converter_registry.py | 4 +- .../dynamo/conversion/impl/__init__.py | 1 - .../dynamo/conversion/impl/cast.py | 20 ++++++++++ .../dynamo/conversion/impl/evaluators.py | 40 ------------------- .../dynamo/conversion/op_evaluators.py | 32 +++++++++++++++ tests/py/dynamo/converters/test_casts.py | 30 ++++++++++++++ tests/py/dynamo/converters/test_evaluators.py | 30 -------------- tests/py/dynamo/models/test_models.py | 2 - 10 files changed, 86 insertions(+), 95 deletions(-) delete mode 100644 py/torch_tensorrt/dynamo/conversion/impl/evaluators.py create mode 100644 py/torch_tensorrt/dynamo/conversion/op_evaluators.py diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 4536ff0e7b..9cbfff950e 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,4 +1,5 @@ from ._TRTInterpreter import * # noqa: F403 from .aten_ops_converters import * # noqa: F403 from .conversion import * # noqa: F403 +from .op_evaluators import * # noqa: F403 from .truncate_long_and_double import repair_long_or_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index f6b8dbcc78..07ef31c7de 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1,5 +1,4 @@ import logging -import operator from typing import Any, Dict, Optional, Sequence, Tuple, Union import tensorrt as trt @@ -421,24 +420,6 @@ def aten_ops_to_copy_dtype( ) -@dynamo_tensorrt_converter(operator.getitem) -def operator_getitem( - network: TRTNetwork, - target: Target, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - name: str, -) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.evaluators.getitem( - network, - target, - SourceIR.ATEN, - name, - args[0], - args[1], - ) - - @dynamo_tensorrt_converter(torch.ops.aten.clone.default) def aten_ops_clone( network: TRTNetwork, @@ -447,7 +428,7 @@ def aten_ops_clone( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return impl.evaluators.clone( + return impl.cast.clone( network, target, SourceIR.ATEN, diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py index db41420367..c9f279da29 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py @@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]: """Returns the set of unique converter targets stored across all registries""" return set.union(*[set(registry.keys()) for registry in self.registries]) - # TODO: Make this a static method since it does not need state - def qualified_name_or_str(self, target: Target) -> str: + @staticmethod + def qualified_name_or_str(target: Target) -> str: """Returns string representation of an FX Node target""" if isinstance(target, str): return target diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index 611dc630fa..8f7ab1badc 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -6,7 +6,6 @@ condition, elementwise, embedding, - evaluators, matmul, normalization, permutation, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py index 68899de766..0c55731169 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from torch.fx.node import Target @@ -5,6 +6,8 @@ from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor +LOGGER: logging.Logger = logging.getLogger(__name__) + def to_copy( network: TRTNetwork, @@ -21,3 +24,20 @@ def to_copy( casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) return casted_tensor + + +def clone( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"clone received input {input} that is not a TensorRT ITensor" + ) + + LOGGER.debug(f"Evaluating clone on object with name: {name}") + + return input diff --git a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py deleted file mode 100644 index cb61fb6158..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py +++ /dev/null @@ -1,40 +0,0 @@ -import logging -import operator -from typing import Optional, Sequence - -from torch.fx.node import Target -from torch_tensorrt.dynamo._SourceIR import SourceIR -from torch_tensorrt.fx.types import TRTNetwork, TRTTensor - -LOGGER: logging.Logger = logging.getLogger(__name__) - - -def getitem( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: Sequence[TRTTensor], - index: int, -) -> TRTTensor: - LOGGER.debug(f"Evaluating getitem on object with name: {name}") - - # Directly index the input sequence and return the value - return operator.getitem(input, index) - - -def clone( - network: TRTNetwork, - target: Target, - source_ir: Optional[SourceIR], - name: str, - input: TRTTensor, -) -> TRTTensor: - if not isinstance(input, TRTTensor): - raise RuntimeError( - f"clone received input {input} that is not a TensorRT ITensor" - ) - - LOGGER.debug(f"Evaluating clone on object with name: {name}") - - return input diff --git a/py/torch_tensorrt/dynamo/conversion/op_evaluators.py b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py new file mode 100644 index 0000000000..a546e34305 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/op_evaluators.py @@ -0,0 +1,32 @@ +import logging +import operator +from typing import Dict, Sequence, Tuple, Union + +from torch.fx.node import Argument, Node, Target +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def getitem_validator(getitem_node: Node) -> bool: + from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS + + # Getitem nodes can only be converted if their parent node also can + return getitem_node.args[0] in DYNAMO_CONVERTERS + + +# TODO: Subsequent evaluators should be registered here with their own validators +@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator) +def generic_evaluator( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + _LOGGER.debug( + f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}" + ) + return target(*args) diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py index 4bb05ef463..3a4fd65610 100644 --- a/tests/py/dynamo/converters/test_casts.py +++ b/tests/py/dynamo/converters/test_casts.py @@ -5,6 +5,36 @@ from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException +class TestCloneConverter(DispatchTestCase): + def test_clone_contiguous(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x, memory_format=torch.contiguous_format) + return y + 1 + + inputs = [torch.randn((1, 3, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + def test_clone_regular(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x) + return y + 1 + + inputs = [torch.randn((8, 2, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + class TestToCopyConverter(DispatchTestCase): def test_to_copy_half(self): class ToCopyHalf(nn.Module): diff --git a/tests/py/dynamo/converters/test_evaluators.py b/tests/py/dynamo/converters/test_evaluators.py index cf42009495..64dd303727 100644 --- a/tests/py/dynamo/converters/test_evaluators.py +++ b/tests/py/dynamo/converters/test_evaluators.py @@ -7,36 +7,6 @@ from torch.testing._internal.common_utils import run_tests -class TestCloneConverter(DispatchTestCase): - def test_clone_contiguous(self): - class Clone(nn.Module): - def forward(self, x): - y = torch.clone(x, memory_format=torch.contiguous_format) - return y + 1 - - inputs = [torch.randn((1, 3, 10))] - self.run_test( - Clone(), - inputs, - expected_ops={torch.ops.aten.clone.default}, - disable_passes=True, - ) - - def test_clone_regular(self): - class Clone(nn.Module): - def forward(self, x): - y = torch.clone(x) - return y + 1 - - inputs = [torch.randn((8, 2, 10))] - self.run_test( - Clone(), - inputs, - expected_ops={torch.ops.aten.clone.default}, - disable_passes=True, - ) - - # TODO: Switch this test back to self.run_test once an implementation exists # for a converter that returns a list, such as aten.split @unittest.skip("Pending aten.split converter. Currently tested by E2E") diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index c8f730e2e6..50d7fcbbd9 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -27,7 +27,6 @@ def test_resnet18(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 10, "ir": "torch_compile", } @@ -176,7 +175,6 @@ def test_resnet18_half(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 10, "ir": "torch_compile", }