feat: Add _to_copy, operator.get and clone ATen converters (#2161)

gs-olive authored Aug 17, 2023
1 parent 06e544e commit 6a69c6a

Showing 14 changed files with 352 additions and 40 deletions.
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -27,6 +27,10 @@
 ] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")


+class UnsupportedOperatorException(RuntimeError):
+    pass
+
+
 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
@@ -301,7 +305,7 @@ def call_module(
         converter = CONVERTERS.get(self._cur_node)

         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of module of type {submod_type} not currently supported!"
             )
@@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         # TODO: Why is this stateful? We should be able to take in the inputs
         converter = CONVERTERS.get(self._cur_node)
         if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of function {torch.typename(target)} not currently supported!"
             )
@@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         converter = CONVERTERS.get(self._cur_node)

        if not converter:
-            raise RuntimeError(
+            raise UnsupportedOperatorException(
                 f"Conversion of method {target} not currently supported!"
             )
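The dedicated exception type lets callers separate "no converter registered" failures from other runtime errors. A minimal sketch of a fallback wrapper, assuming the exception is re-exported by the conversion package (via the wildcard import below) and that a TRTInterpreter instance is already built:

from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException

def convert_or_fallback(interpreter):
    # Hypothetical helper: `interpreter` is an already-constructed TRTInterpreter.
    # Missing-converter errors are now distinguishable from generic RuntimeErrors,
    # so callers can repartition or fall back to eager execution.
    try:
        return interpreter.run()
    except UnsupportedOperatorException as exc:
        print(f"Falling back to PyTorch execution: {exc}")
        return None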
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,5 @@
 from ._TRTInterpreter import *  # noqa: F403
 from .aten_ops_converters import *  # noqa: F403
 from .conversion import *  # noqa: F403
+from .op_evaluators import *  # noqa: F403
 from .truncate_long_and_double import repair_long_or_double_inputs
65 changes: 60 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -1,6 +1,7 @@
 import logging
 from typing import Any, Dict, Optional, Sequence, Tuple, Union

+import tensorrt as trt
 import torch
 from torch.fx.node import Argument, Node, Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -12,8 +13,6 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

-import tensorrt as trt
-
 from .converter_registry import dynamo_tensorrt_converter

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -76,13 +75,13 @@ def aten_ops_div(
         kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
     ):
         kwargs_new["input"] = cast_trt_tensor(
-            network, kwargs_new["input"], trt.float32, name
+            network, kwargs_new["input"], trt.float32, name, target
         )
     elif isinstance(args[1], TRTTensor) and (
         kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
     ):
         kwargs_new["other"] = cast_trt_tensor(
-            network, kwargs_new["other"], trt.float32, name
+            network, kwargs_new["other"], trt.float32, name, target
         )
     rounding_mode = kwargs.get("rounding_mode")
     if rounding_mode is None:
@@ -101,7 +100,7 @@
     )


-def embedding_param_validator(embedding_node: Node):
+def embedding_param_validator(embedding_node: Node) -> bool:
     scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
     sparse = args_bounds_check(embedding_node.args, 4)
@@ -365,3 +364,59 @@ def aten_ops_permute(
         args[0],
         args[1],
     )
+
+
+def to_copy_dtype_validator(to_copy_node: Node) -> bool:
+    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}
+
+    # Validate input node has convertible kwargs
+    if "dtype" in to_copy_node.kwargs:
+        if to_copy_node.kwargs["dtype"] in allowed_casts:
+            return True
+        else:
+            _LOGGER.debug(
+                f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
+            )
+            return False
+    else:
+        _LOGGER.debug(
+            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
+        )
+        return False
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
+)
+def aten_ops_to_copy_dtype(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.to_copy(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs["dtype"],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
+def aten_ops_clone(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.cast.clone(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
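As a quick sketch of how the capability validator gates conversion: when to_copy_dtype_validator returns False, the registry reports no converter for the node and it is left to the PyTorch fallback path. The mocked node below is illustrative only, assuming the validator is importable from the module shown above:

import torch
from unittest.mock import MagicMock
from torch_tensorrt.dynamo.conversion.aten_ops_converters import to_copy_dtype_validator

node = MagicMock()                        # stand-in for a traced FX node
node.kwargs = {"dtype": torch.float16}
assert to_copy_dtype_validator(node)      # float16 is in allowed_casts

node.kwargs = {"dtype": torch.float64}
assert not to_copy_dtype_validator(node)  # float64 is rejected -> node falls back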
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -66,7 +66,7 @@ def dynamo_tensorrt_converter(
     enabled: bool = True,
     capability_validator: Optional[Callable[[Node], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
-) -> Callable[[Any], Any]:
+) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
     """Decorator for Dynamo TensorRT Converter
     Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
18 changes: 15 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -1,15 +1,19 @@
 import logging
 import re
-from typing import List
+from typing import List, Optional

 import tensorrt as trt
 import torch
+from torch.fx.node import Target
 from torch_tensorrt.fx.converters.converter_utils import (
     Frameworks,
     unified_dtype_converter,
 )
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

+from .._SourceIR import SourceIR
+from .converter_registry import ConverterRegistry
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -71,24 +75,32 @@ def cast_trt_tensor(
     input_val: TRTTensor,
     dtype: TRTDataType,
     name: str,
+    target: Target = "",
+    source_ir: Optional[SourceIR] = None,
 ) -> TRTTensor:
     """
     Given a TRT Tensor, convert that Tensor to the specified dtype
     Adds an Identity layer to the network which performs the conversion
     Args:
         network (TRTNetwork): A TensorRT network
         input_val (TRTTensor): A TRT Tensor to cast to a new data type
-        dtype (TRTDataType): The TRTDataType to cast the input Tensor to
+        dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
         name (str): Name of the calling layer
+        target (Target): Target of calling node
+        source_ir (SourceIR): SourceIR of calling converter
     Returns:
         A TensorRT ITensor which has been casted to the specified dtype
     """
     trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

     if input_val.dtype != trt_dtype:
+        source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
+        target_str = ConverterRegistry.qualified_name_or_str(target)
+        target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"
+
         identity_layer = network.add_identity(input_val)
         identity_layer.set_output_type(0, trt_dtype)
-        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
+        identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
         return identity_layer.get_output(0)
     else:
         return input_val
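With the new target and source_ir parameters, cast layers in the built engine carry names traceable back to the originating ATen op. A hedged call-site sketch (`network` and `input_val` would come from the interpreter; the layer-name format follows the f-string above, shown approximately in the comment):

import tensorrt as trt
import torch
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor

def cast_to_float(network, input_val):
    # Produces an Identity layer named along the lines of:
    # "Cast ITensor x from DataType.INT32 to DataType.FLOAT - [aten_ops.<target>]-[_to_copy_1]"
    return cast_trt_tensor(
        network,
        input_val,
        trt.float32,
        "_to_copy_1",
        target=torch.ops.aten._to_copy.default,
        source_ir=SourceIR.ATEN,
    )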
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -2,6 +2,7 @@

 from . import (
     activation,
+    cast,
     condition,
     elementwise,
     embedding,
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -0,0 +1,43 @@
+import logging
+from typing import Optional
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
+from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
+
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def to_copy(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dtype: TRTDataType,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"to_copy received input {input} that is not a TensorRT ITensor"
+        )
+
+    casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
+    return casted_tensor
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
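Note that clone intentionally returns its input unchanged: TensorRT network construction is functional, so an explicit copy layer would be redundant. A hypothetical sanity check using spec'd mocks in place of real interpreter-supplied objects:

import torch
from unittest.mock import MagicMock
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.impl.cast import clone
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

network = MagicMock(spec=TRTNetwork)    # stand-ins; real objects come from the interpreter
input_val = MagicMock(spec=TRTTensor)
out = clone(network, torch.ops.aten.clone.default, SourceIR.ATEN, "clone_0", input_val)
assert out is input_val                 # pass-through: no TRT layer is added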
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -2,6 +2,7 @@
 import warnings
 from typing import Any, Callable, Optional, Union

+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -15,8 +16,6 @@
 from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
 from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

-import tensorrt as trt
-

 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -132,9 +131,13 @@ def convert_binary_elementwise(
     trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

     if trt_promoted_type != lhs_val.dtype:
-        lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
+        lhs_val = cast_trt_tensor(
+            network, lhs_val, trt_promoted_type, name, target, source_ir
+        )
     if trt_promoted_type != rhs_val.dtype:
-        rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
+        rhs_val = cast_trt_tensor(
+            network, rhs_val, trt_promoted_type, name, target, source_ir
+        )

     # Check the limitation in the doc string.
     if network.has_implicit_batch_dimension:
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/op_evaluators.py
@@ -0,0 +1,32 @@
+import logging
+import operator
+from typing import Dict, Sequence, Tuple, Union
+
+from torch.fx.node import Argument, Node, Target
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+
+from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def getitem_validator(getitem_node: Node) -> bool:
+    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+
+    # Getitem nodes can only be converted if their parent node also can
+    return getitem_node.args[0] in DYNAMO_CONVERTERS
+
+
+# TODO: Subsequent evaluators should be registered here with their own validators
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+def generic_evaluator(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    _LOGGER.debug(
+        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
+    )
+    return target(*args)
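This evaluator covers the operator.getitem nodes that FX inserts to unpack multi-output ops (for example, the values/indices pair from aten.max.dim); it simply executes the indexing at conversion time via target(*args), while the validator ensures the producing node is itself convertible. A toy illustration of the call it performs:

import operator

parent_output = ("values_placeholder", "indices_placeholder")       # stand-in tuple output
assert operator.getitem(parent_output, 0) == "values_placeholder"   # what target(*args) computes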
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -5,9 +5,11 @@
 from typing import Any, Callable, Dict, Optional, Sequence

 import torch
+import torch_tensorrt
 from torch_tensorrt._Device import Device
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo._defaults import PRECISION

 from packaging import version
@@ -161,6 +163,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
     if settings.debug:
         logger.setLevel(logging.DEBUG)

+    # TODO: Remove once Dynamo precisions refactoring is complete
+    if "enabled_precisions" in kwargs:
+        enabled_precisions = kwargs["enabled_precisions"]
+
+        if (
+            torch.float16 in enabled_precisions
+            or torch_tensorrt.dtype.half in enabled_precisions
+        ):
+            settings.precision = torch.float16
+        elif (
+            torch.float32 in enabled_precisions
+            or torch_tensorrt.dtype.float in enabled_precisions
+        ):
+            settings.precision = torch.float32
+        elif len(enabled_precisions) == 0:
+            logger.info(f"No precision specified, defaulting to {PRECISION}")
+            settings.precision = PRECISION
+        else:
+            raise ValueError(
+                f"Precision {enabled_precisions} not supported in the Dynamo Path"
+            )
+
     # Parse input runtime specification
     settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
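From the caller's side, the interim precision parsing behaves as sketched below. The backend name and option plumbing are assumptions here (they may differ by release); float16 wins when both precisions are listed because its branch is checked first:

import torch

model = torch.nn.Linear(4, 4).eval().cuda()

# Hypothetical invocation of the Dynamo path:
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"enabled_precisions": {torch.float16, torch.float32}},
)
# parse_dynamo_kwargs would resolve settings.precision to torch.float16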