Skip to content

Commit

Permalink
fix: Add generic evaluator function
Browse files Browse the repository at this point in the history
  • Loading branch information
gs-olive committed Aug 14, 2023
1 parent b7d9d5a commit 4283f43
Show file tree
Hide file tree
Showing 10 changed files with 86 additions and 95 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._TRTInterpreter import * # noqa: F403
from .aten_ops_converters import * # noqa: F403
from .conversion import * # noqa: F403
from .op_evaluators import * # noqa: F403
from .truncate_long_and_double import repair_long_or_double_inputs
21 changes: 1 addition & 20 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import operator
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import tensorrt as trt
Expand Down Expand Up @@ -421,24 +420,6 @@ def aten_ops_to_copy_dtype(
)


@dynamo_tensorrt_converter(operator.getitem)
def operator_getitem(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.evaluators.getitem(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
network: TRTNetwork,
Expand All @@ -447,7 +428,7 @@ def aten_ops_clone(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.evaluators.clone(
return impl.cast.clone(
network,
target,
SourceIR.ATEN,
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]:
"""Returns the set of unique converter targets stored across all registries"""
return set.union(*[set(registry.keys()) for registry in self.registries])

# TODO: Make this a static method since it does not need state
def qualified_name_or_str(self, target: Target) -> str:
@staticmethod
def qualified_name_or_str(target: Target) -> str:
"""Returns string representation of an FX Node target"""
if isinstance(target, str):
return target
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
condition,
elementwise,
embedding,
evaluators,
matmul,
normalization,
permutation,
Expand Down
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

LOGGER: logging.Logger = logging.getLogger(__name__)


def to_copy(
network: TRTNetwork,
Expand All @@ -21,3 +24,20 @@ def to_copy(

casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
return casted_tensor


def clone(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"clone received input {input} that is not a TensorRT ITensor"
)

LOGGER.debug(f"Evaluating clone on object with name: {name}")

return input
40 changes: 0 additions & 40 deletions py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

This file was deleted.

32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/op_evaluators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
import operator
from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)


def getitem_validator(getitem_node: Node) -> bool:
from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS

# Getitem nodes can only be converted if their parent node also can
return getitem_node.args[0] in DYNAMO_CONVERTERS


# TODO: Subsequent evaluators should be registered here with their own validators
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
def generic_evaluator(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
_LOGGER.debug(
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
)
return target(*args, **kwargs)
30 changes: 30 additions & 0 deletions tests/py/dynamo/converters/test_casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,36 @@
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException


class TestCloneConverter(DispatchTestCase):
def test_clone_contiguous(self):
class Clone(nn.Module):
def forward(self, x):
y = torch.clone(x, memory_format=torch.contiguous_format)
return y + 1

inputs = [torch.randn((1, 3, 10))]
self.run_test(
Clone(),
inputs,
expected_ops={torch.ops.aten.clone.default},
disable_passes=True,
)

def test_clone_regular(self):
class Clone(nn.Module):
def forward(self, x):
y = torch.clone(x)
return y + 1

inputs = [torch.randn((8, 2, 10))]
self.run_test(
Clone(),
inputs,
expected_ops={torch.ops.aten.clone.default},
disable_passes=True,
)


class TestToCopyConverter(DispatchTestCase):
def test_to_copy_half(self):
class ToCopyHalf(nn.Module):
Expand Down
30 changes: 0 additions & 30 deletions tests/py/dynamo/converters/test_evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,6 @@
from torch.testing._internal.common_utils import run_tests


class TestCloneConverter(DispatchTestCase):
def test_clone_contiguous(self):
class Clone(nn.Module):
def forward(self, x):
y = torch.clone(x, memory_format=torch.contiguous_format)
return y + 1

inputs = [torch.randn((1, 3, 10))]
self.run_test(
Clone(),
inputs,
expected_ops={torch.ops.aten.clone.default},
disable_passes=True,
)

def test_clone_regular(self):
class Clone(nn.Module):
def forward(self, x):
y = torch.clone(x)
return y + 1

inputs = [torch.randn((8, 2, 10))]
self.run_test(
Clone(),
inputs,
expected_ops={torch.ops.aten.clone.default},
disable_passes=True,
)


# TODO: Switch this test back to self.run_test once an implementation exists
# for a converter that returns a list, such as aten.split
@unittest.skip("Pending aten.split converter. Currently tested by E2E")
Expand Down
2 changes: 0 additions & 2 deletions tests/py/dynamo/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_resnet18(ir):
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"min_block_size": 10,
"ir": "torch_compile",
}

Expand Down Expand Up @@ -176,7 +175,6 @@ def test_resnet18_half(ir):
"ir": ir,
"pass_through_build_failures": True,
"optimization_level": 1,
"min_block_size": 10,
"ir": "torch_compile",
}

Expand Down

0 comments on commit 4283f43

Please sign in to comment.