diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 6cb3a30abb..e559fbe461 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -70,13 +70,13 @@ def aten_ops_div( kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32 ): kwargs_new["input"] = cast_trt_tensor( - network, kwargs_new["input"], trt.float32, name + network, kwargs_new["input"], trt.float32, name, target ) elif isinstance(args[1], TRTTensor) and ( kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32 ): kwargs_new["other"] = cast_trt_tensor( - network, kwargs_new["other"], trt.float32, name + network, kwargs_new["other"], trt.float32, name, target ) rounding_mode = kwargs.get("rounding_mode") if rounding_mode is None: @@ -377,3 +377,77 @@ def aten_ops_permute( args[0], args[1], ) + + +def to_copy_dtype_validator(to_copy_node: Node): + allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16} + + # Validate input node has convertible kwargs + if "dtype" in to_copy_node.kwargs: + if to_copy_node.kwargs["dtype"] in allowed_casts: + return True + else: + _LOGGER.debug( + f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}" + ) + return False + else: + _LOGGER.debug( + f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}" + ) + return False + + +@dynamo_tensorrt_converter( + torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator +) +def aten_ops_to_copy_dtype( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.cast.to_copy( + network, + target, + SourceIR.ATEN, + name, + args[0], + kwargs["dtype"], + ) + + +@dynamo_tensorrt_converter(operator.getitem) +def operator_getitem( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.evaluators.getitem( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.clone.default) +def aten_ops_clone( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.evaluators.clone( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 584e15b263..5e39eb5cf7 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,5 +1,7 @@ import torch +from torch.fx.node import _get_qualified_name + from torch_tensorrt.fx.types import ( TRTDataType, TRTNetwork, @@ -12,7 +14,9 @@ ) import tensorrt as trt -from typing import List +from typing import List, Optional, Union + +from .._SourceIR import SourceIR def dynamic_unsupported(node: torch.fx.Node) -> bool: @@ -49,6 +53,8 @@ def cast_trt_tensor( input_val: TRTTensor, dtype: TRTDataType, name: str, + target: Union[torch.fx.Target, str] = "", + source_ir: Optional[SourceIR] = None, ) -> TRTTensor: """ Given a TRT Tensor, convert that Tensor to the specified dtype @@ -56,7 +62,7 @@ def cast_trt_tensor( Args: network (TRTNetwork): A TensorRT network input_val (TRTTensor): A TRT Tensor to cast to a new data type - dtype (TRTDataType): The TRTDataType to cast the input Tensor to + dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to name (str): Name of the calling layer Returns: A TensorRT ITensor which has been casted to the specified dtype @@ -64,9 +70,16 @@ def cast_trt_tensor( trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) if input_val.dtype != trt_dtype: + source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN + target_name = ( + f"{source_ir}_ops{'.' + target}" + if (isinstance(target, str) and target) + else f"{source_ir}_ops.{_get_qualified_name(target)}" + ) + identity_layer = network.add_identity(input_val) identity_layer.set_output_type(0, trt_dtype) - identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}" + identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]" return identity_layer.get_output(0) else: return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index db6e405978..470cecfeba 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -12,3 +12,5 @@ from . import squeeze from . import unsqueeze from . import permutation +from . import cast +from . import evaluators diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py new file mode 100644 index 0000000000..6f9d0dbefa --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py @@ -0,0 +1,28 @@ +from typing import Optional +from torch.fx.node import Target + +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, + TRTDataType, +) + + +def to_copy( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dtype: TRTDataType, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"to_copy received input {input} that is not a TensorRT ITensor" + ) + + casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir) + return casted_tensor diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py index a8e4067493..0d6683c28a 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -137,9 +137,13 @@ def convert_binary_elementwise( trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT) if trt_promoted_type != lhs_val.dtype: - lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name) + lhs_val = cast_trt_tensor( + network, lhs_val, trt_promoted_type, name, target, source_ir + ) if trt_promoted_type != rhs_val.dtype: - rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name) + rhs_val = cast_trt_tensor( + network, rhs_val, trt_promoted_type, name, target, source_ir + ) # Check the limitation in the doc string. if network.has_implicit_batch_dimension: diff --git a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py new file mode 100644 index 0000000000..3dc30356ea --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py @@ -0,0 +1,45 @@ +import operator +import logging +from typing import Optional, Sequence +from torch.fx.node import Target + +from torch_tensorrt.dynamo.conversion import SourceIR + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + + +LOGGER: logging.Logger = logging.getLogger(__name__) + + +def getitem( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: Sequence[TRTTensor], + index: int, +) -> TRTTensor: + LOGGER.debug(f"Evaluating getitem on object with name: {name}") + + # Directly index the input sequence and return the value + return operator.getitem(input, index) + + +def clone( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"clone received input {input} that is not a TensorRT ITensor" + ) + + LOGGER.debug(f"Evaluating clone on object with name: {name}") + + return input diff --git a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py index 4293fb65eb..730690ea9d 100644 --- a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py @@ -29,6 +29,10 @@ ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER") +class UnsupportedOperatorException(RuntimeError): + pass + + class TRTInterpreterResult(NamedTuple): engine: Any input_names: Sequence[str] @@ -288,7 +292,7 @@ def call_module(self, target, args, kwargs): converter = CONVERTERS.get(self._cur_node) if not converter: - raise RuntimeError( + raise UnsupportedOperatorException( f"Conversion of module of type {submod_type} not currently supported!" ) @@ -298,7 +302,7 @@ def call_module(self, target, args, kwargs): def call_function(self, target, args, kwargs): converter = CONVERTERS.get(self._cur_node) if not converter: - raise RuntimeError( + raise UnsupportedOperatorException( f"Conversion of function {torch.typename(target)} not currently supported!" ) @@ -310,7 +314,7 @@ def call_method(self, target, args, kwargs): converter = CONVERTERS.get(self._cur_node) if not converter: - raise RuntimeError( + raise UnsupportedOperatorException( f"Conversion of method {target} not currently supported!" ) diff --git a/py/torch_tensorrt/dynamo/test_utils.py b/py/torch_tensorrt/dynamo/test_utils.py index a3d742c70a..67be779135 100644 --- a/py/torch_tensorrt/dynamo/test_utils.py +++ b/py/torch_tensorrt/dynamo/test_utils.py @@ -217,6 +217,7 @@ def generate_graph( expected_ops: Set[Callable], unexpected_ops: Optional[Set[Callable]] = None, customized_passes: List[Callable] = None, + disable_passes: bool = False, ): # Torchdynamo+aot proxytensor tracer # Below are common passes @@ -234,6 +235,10 @@ def generate_graph( # Combine with customized passes specific to any model if customized_passes: passes_list.extend(customized_passes) + + if disable_passes: + passes_list = [] + fx_module, _ = aten_tracer.trace(mod, original_inputs) for passes in passes_list: pr: PassResult = passes(fx_module) @@ -261,9 +266,17 @@ def run_test( atol=1e-03, precision=torch.float, check_dtype=True, + disable_passes=False, ): mod.eval() - mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + mod = self.generate_graph( + mod, + inputs, + expected_ops, + unexpected_ops, + None, + disable_passes=disable_passes, + ) if apply_passes is not None: pass_tracer = chain_passes(*apply_passes) @@ -293,10 +306,18 @@ def run_test_with_dynamic_shape( unexpected_ops=None, rtol=1e-03, atol=1e-03, + disable_passes=False, ): mod.eval() inputs = [spec.example_tensor("opt_shape") for spec in input_specs] - mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + mod = self.generate_graph( + mod, + inputs, + expected_ops, + unexpected_ops, + None, + disable_passes=disable_passes, + ) interp = TRTInterpreter( mod, diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py new file mode 100644 index 0000000000..b36ae3fbfb --- /dev/null +++ b/tests/py/dynamo/converters/test_casts.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt.dynamo.conversion.trt_interpreter import ( + UnsupportedOperatorException, +) + + +class TestToCopyConverter(DispatchTestCase): + def test_to_copy_half(self): + class ToCopyHalf(nn.Module): + def forward(self, x): + y = x.to(dtype=torch.half) + return y + + inputs = [torch.rand((1, 3, 10))] + self.run_test( + ToCopyHalf(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + precision=torch.half, + disable_passes=True, + ) + + def test_to_copy_float(self): + class ToCopyFloat(nn.Module): + def forward(self, x): + y = x.to(dtype=torch.float) + return y + + inputs = [torch.rand((1, 3, 10)).half()] + self.run_test( + ToCopyFloat(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + precision=torch.float, + disable_passes=True, + ) + + def test_to_copy_unsupported(self): + class ToCopy64Bit(nn.Module): + def forward(self, x): + y = x.to(dtype=torch.int64) + return y + + inputs = [torch.randn((1, 3, 10)).int()] + + with self.assertRaises(UnsupportedOperatorException): + self.run_test( + ToCopy64Bit(), + inputs, + expected_ops={torch.ops.aten._to_copy.default}, + disable_passes=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_evaluators.py b/tests/py/dynamo/converters/test_evaluators.py new file mode 100644 index 0000000000..6cbf986e8c --- /dev/null +++ b/tests/py/dynamo/converters/test_evaluators.py @@ -0,0 +1,66 @@ +import operator +import unittest +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestCloneConverter(DispatchTestCase): + def test_clone_contiguous(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x, memory_format=torch.contiguous_format) + return y + 1 + + inputs = [torch.randn((1, 3, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + def test_clone_regular(self): + class Clone(nn.Module): + def forward(self, x): + y = torch.clone(x) + return y + 1 + + inputs = [torch.randn((8, 2, 10))] + self.run_test( + Clone(), + inputs, + expected_ops={torch.ops.aten.clone.default}, + disable_passes=True, + ) + + +# TODO: Switch this test back to self.run_test once an implementation exists +# for a converter that returns a list, such as aten.split +@unittest.skip("Pending aten.split converter. Currently tested by E2E") +class TestGetItemConverter(DispatchTestCase): + def test_getitem(self): + class GetItem(nn.Module): + def forward(self, x): + lis = torch.split(x, 5) + b = operator.getitem(lis, 0) + c = operator.getitem(lis, 1) + d = b + c + return d + + inputs = [ + torch.randn((3, 3, 10)), + torch.randn((3, 3, 10)), + torch.randn((3, 3, 10)), + ] + self.run_test( + GetItem(), + inputs, + expected_ops={operator.getitem}, + disable_passes=True, + ) + + +if __name__ == "__main__": + run_tests()