fix: Add generic evaluator function
gs-olive committed Aug 17, 2023
1 parent 26d1051 commit fbc3a7e
Showing 10 changed files with 89 additions and 100 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,5 @@
 from ._TRTInterpreter import * # noqa: F403
 from .aten_ops_converters import * # noqa: F403
 from .conversion import * # noqa: F403
+from .op_evaluators import * # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs
21 changes: 1 addition & 20 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,5 +1,4 @@
 import logging
-import operator
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import tensorrt as trt
@@ -406,24 +405,6 @@ def aten_ops_to_copy_dtype(
     )


-@dynamo_tensorrt_converter(operator.getitem)
-def operator_getitem(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.evaluators.getitem(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.clone.default)
 def aten_ops_clone(
     network: TRTNetwork,
@@ -432,7 +413,7 @@ def aten_ops_clone(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.evaluators.clone(
+    return impl.cast.clone(
         network,
         target,
         SourceIR.ATEN,
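Note: the hunk above is truncated before the end of the clone call. A sketch of how the full converter presumably reads after this change, inferred from the clone signature added in impl/cast.py below; the trailing arguments (name, args[0]) and the surrounding imports are assumptions, not part of the shown diff:

@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    # Clone is now routed through impl.cast rather than the removed impl.evaluators
    return impl.cast.clone(
        network,
        target,
        SourceIR.ATEN,
        name,
        args[0],
    )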
12 changes: 5 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -4,14 +4,15 @@

 import tensorrt as trt
 import torch
-from torch.fx.node import Target, _get_qualified_name
+from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

 from .._SourceIR import SourceIR
+from .converter_registry import ConverterRegistry

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -94,15 +95,12 @@ def cast_trt_tensor(

     if input_val.dtype != trt_dtype:
         source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
-        target_name = (
-            f"{source_ir}_ops{'.' + target if target else ''}"
-            if (isinstance(target, str))
-            else f"{source_ir}_ops.{_get_qualified_name(target)}"
-        )
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{'.' + target_str if target_str else ''}"

         identity_layer = network.add_identity(input_val)
         identity_layer.set_output_type(0, trt_dtype)
-        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]"
+        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
         return identity_layer.get_output(0)
     else:
         return input_val
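The change above centralizes layer naming through ConverterRegistry.qualified_name_or_str, which replaces the separate string/callable branches. A small, self-contained illustration of the resulting name pattern; the source_ir and target_str values are hypothetical stand-ins, and the exact string returned for callable targets is an assumption:

source_ir = "aten"
for target_str in ("_to_copy.default", ""):
    # Mirrors the f-string used in cast_trt_tensor above
    target_name = f"{source_ir}_ops{'.' + target_str if target_str else ''}"
    print(target_name)
# Prints "aten_ops._to_copy.default", then "aten_ops" for an empty target string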
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -6,7 +6,6 @@
     condition,
     elementwise,
     embedding,
-    evaluators,
     matmul,
     normalization,
     permutation,
20 changes: 20 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -1,10 +1,13 @@
+import logging
 from typing import Optional

 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

+LOGGER: logging.Logger = logging.getLogger(__name__)
+

 def to_copy(
     network: TRTNetwork,
@@ -21,3 +24,20 @@ def to_copy(

     casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
     return casted_tensor
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
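The new clone validates its input and hands the same ITensor back; no TensorRT layers are added. A minimal sketch of that behavior, assuming a local TensorRT install and this commit's module layout:

import tensorrt as trt
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.cast import clone

builder = trt.Builder(trt.Logger(trt.Logger.WARNING))
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 3, 10))

# clone is a pass-through at the TRT graph level: the same ITensor comes back
out = clone(network, "aten.clone.default", SourceIR.ATEN, "clone_0", x)
assert out is x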
40 changes: 0 additions & 40 deletions py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

This file was deleted.

32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/op_evaluators.py
@@ -0,0 +1,32 @@
+import logging
+import operator
+from typing import Dict, Sequence, Tuple, Union
+
+from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def getitem_validator(getitem_node: Node) -> bool:
+    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+
+    # Getitem nodes can only be converted if their parent node also can
+    return getitem_node.args[0] in DYNAMO_CONVERTERS
+
+
+# TODO: Subsequent evaluators should be registered here with their own validators
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+def generic_evaluator(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    _LOGGER.debug(
+        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
+    )
+    return target(*args)
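The evaluator performs no TensorRT work: getitem_validator only admits getitem nodes whose parent is itself convertible, so by conversion time args already holds concrete Python values and target(*args) simply performs the indexing. A self-contained illustration with hypothetical stand-in values:

import operator

# Stand-ins for the already-converted outputs of a parent node
parent_outputs = ("trt_tensor_0", "trt_tensor_1")
args = (parent_outputs, 1)

# This is all generic_evaluator does: call target(*args)
result = operator.getitem(*args)
assert result == "trt_tensor_1"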
30 changes: 30 additions & 0 deletions tests/py/dynamo/converters/test_casts.py
@@ -5,6 +5,36 @@
 from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException


+class TestCloneConverter(DispatchTestCase):
+    def test_clone_contiguous(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x, memory_format=torch.contiguous_format)
+                return y + 1
+
+        inputs = [torch.randn((1, 3, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+    def test_clone_regular(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x)
+                return y + 1
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
         class ToCopyHalf(nn.Module):
30 changes: 0 additions & 30 deletions tests/py/dynamo/converters/test_evaluators.py
@@ -7,36 +7,6 @@
 from torch.testing._internal.common_utils import run_tests


-class TestCloneConverter(DispatchTestCase):
-    def test_clone_contiguous(self):
-        class Clone(nn.Module):
-            def forward(self, x):
-                y = torch.clone(x, memory_format=torch.contiguous_format)
-                return y + 1
-
-        inputs = [torch.randn((1, 3, 10))]
-        self.run_test(
-            Clone(),
-            inputs,
-            expected_ops={torch.ops.aten.clone.default},
-            disable_passes=True,
-        )
-
-    def test_clone_regular(self):
-        class Clone(nn.Module):
-            def forward(self, x):
-                y = torch.clone(x)
-                return y + 1
-
-        inputs = [torch.randn((8, 2, 10))]
-        self.run_test(
-            Clone(),
-            inputs,
-            expected_ops={torch.ops.aten.clone.default},
-            disable_passes=True,
-        )
-
-
 # TODO: Switch this test back to self.run_test once an implementation exists
 # for a converter that returns a list, such as aten.split
 @unittest.skip("Pending aten.split converter. Currently tested by E2E")
2 changes: 0 additions & 2 deletions tests/py/dynamo/models/test_models.py
@@ -27,7 +27,6 @@ def test_resnet18(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }

@@ -176,7 +175,6 @@ def test_resnet18_half(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }

