Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add _to_copy, operator.get and clone ATen converters #2161

Merged
merged 4 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")


class UnsupportedOperatorException(RuntimeError):
    """Raised when no converter is available for a module, function, or method.

    Subclasses RuntimeError, so callers that catch RuntimeError still catch it.
    """


class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
Expand Down Expand Up @@ -301,7 +305,7 @@ def call_module(
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of module of type {submod_type} not currently supported!"
)

Expand All @@ -312,7 +316,7 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
# TODO: Why is this stateful? We should be able to take in the inputs
converter = CONVERTERS.get(self._cur_node)
if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of function {torch.typename(target)} not currently supported!"
)

Expand All @@ -324,7 +328,7 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of method {target} not currently supported!"
)

Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._TRTInterpreter import * # noqa: F403
from .aten_ops_converters import * # noqa: F403
from .conversion import * # noqa: F403
from .op_evaluators import * # noqa: F403
from .truncate_long_and_double import repair_long_or_double_inputs
65 changes: 60 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import tensorrt as trt
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
Expand All @@ -12,8 +13,6 @@
from torch_tensorrt.fx.converters import acc_ops_converters
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

import tensorrt as trt

from .converter_registry import dynamo_tensorrt_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,13 +75,13 @@ def aten_ops_div(
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
):
kwargs_new["input"] = cast_trt_tensor(
network, kwargs_new["input"], trt.float32, name
network, kwargs_new["input"], trt.float32, name, target
)
elif isinstance(args[1], TRTTensor) and (
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
):
kwargs_new["other"] = cast_trt_tensor(
network, kwargs_new["other"], trt.float32, name
network, kwargs_new["other"], trt.float32, name, target
)
rounding_mode = kwargs.get("rounding_mode")
if rounding_mode is None:
Expand All @@ -101,7 +100,7 @@ def aten_ops_div(
)


def embedding_param_validator(embedding_node: Node):
def embedding_param_validator(embedding_node: Node) -> bool:
scale_grad_by_freq = args_bounds_check(embedding_node.args, 3)
sparse = args_bounds_check(embedding_node.args, 4)

Expand Down Expand Up @@ -365,3 +364,59 @@ def aten_ops_permute(
args[0],
args[1],
)

gs-olive marked this conversation as resolved.
Show resolved Hide resolved

def to_copy_dtype_validator(to_copy_node: Node) -> bool:
    """Decide whether a ``_to_copy`` node can be handled by the dtype-cast converter.

    Args:
        to_copy_node: The FX node to inspect; only its ``kwargs`` are read.

    Returns:
        True when the node carries a ``dtype`` kwarg whose value is one of the
        allowed cast targets, False otherwise (the rejection reason is logged
        at debug level).
    """
    allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}

    # Without an explicit dtype kwarg there is nothing for the converter to cast to
    if "dtype" not in to_copy_node.kwargs:
        _LOGGER.debug(
            f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
        )
        return False

    requested_dtype = to_copy_node.kwargs["dtype"]
    if requested_dtype not in allowed_casts:
        _LOGGER.debug(
            f"_to_copy converter rejected node {to_copy_node} with dtype {requested_dtype}"
        )
        return False

    return True


@dynamo_tensorrt_converter(
    torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
)
def aten_ops_to_copy_dtype(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter for ``aten._to_copy.default`` restricted to dtype casts.

    Forwards the first positional argument and the ``dtype`` kwarg to
    ``impl.cast.to_copy``, tagging the call with ``SourceIR.ATEN``. The
    registered capability validator is what guards the ``dtype`` kwarg lookup.
    """
    source_tensor = args[0]
    requested_dtype = kwargs["dtype"]
    return impl.cast.to_copy(
        network, target, SourceIR.ATEN, name, source_tensor, requested_dtype
    )


@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Converter for ``aten.clone.default``.

    Hands the first positional argument to ``impl.cast.clone`` with
    ``SourceIR.ATEN`` as the source IR; ``kwargs`` are not consulted.
    """
    source_tensor = args[0]
    return impl.cast.clone(network, target, SourceIR.ATEN, name, source_tensor)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/converter_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def dynamo_tensorrt_converter(
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
) -> Callable[[Any], Any]:
) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
"""Decorator for Dynamo TensorRT Converter

Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
Expand Down
18 changes: 15 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import logging
import re
from typing import List
from typing import List, Optional

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.fx.converters.converter_utils import (
Frameworks,
unified_dtype_converter,
)
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

from .._SourceIR import SourceIR
from .converter_registry import ConverterRegistry

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -71,24 +75,32 @@ def cast_trt_tensor(
input_val: TRTTensor,
dtype: TRTDataType,
name: str,
target: Target = "",
source_ir: Optional[SourceIR] = None,
) -> TRTTensor:
"""
Given a TRT Tensor, convert that Tensor to the specified dtype
Adds an Identity layer to the network which performs the conversion
Args:
network (TRTNetwork): A TensorRT network
input_val (TRTTensor): A TRT Tensor to cast to a new data type
dtype (TRTDataType): The TRTDataType to cast the input Tensor to
dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
name (str): Name of the calling layer
target (Target): Target of calling node
source_ir (SourceIR): SourceIR of calling converter
Returns:
A TensorRT ITensor which has been casted to the specified dtype
"""
trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

if input_val.dtype != trt_dtype:
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
target_str = ConverterRegistry.qualified_name_or_str(target)
target_name = f"{source_ir}_ops{('.' + target_str) if target_str else ''}"

identity_layer = network.add_identity(input_val)
identity_layer.set_output_type(0, trt_dtype)
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]"
return identity_layer.get_output(0)
else:
return input_val
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from . import (
activation,
cast,
condition,
elementwise,
embedding,
Expand Down
43 changes: 43 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging
from typing import Optional

from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor

LOGGER: logging.Logger = logging.getLogger(__name__)


def to_copy(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    dtype: TRTDataType,
) -> TRTTensor:
    """Cast a TensorRT tensor to ``dtype`` via ``cast_trt_tensor``.

    Args:
        network: Network passed through to ``cast_trt_tensor``.
        target: Target of the calling node, forwarded for layer naming.
        source_ir: Source IR of the calling converter, forwarded as-is.
        name: Name of the calling layer.
        input: Tensor to cast; must be a ``TRTTensor``.
        dtype: Data type to cast ``input`` to.

    Returns:
        Whatever ``cast_trt_tensor`` returns for the given input and dtype.

    Raises:
        RuntimeError: If ``input`` is not a ``TRTTensor``.
    """
    if isinstance(input, TRTTensor):
        return cast_trt_tensor(network, input, dtype, name, target, source_ir)

    raise RuntimeError(
        f"to_copy received input {input} that is not a TensorRT ITensor"
    )


def clone(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
) -> TRTTensor:
    """Handle a clone by returning the input tensor itself.

    No layer is added to ``network``; the call is only logged at debug level.

    Args:
        network: Unused here; kept for the common converter-impl signature.
        target: Unused here; target of the calling node.
        source_ir: Unused here; source IR of the calling converter.
        name: Name of the calling layer, used in the debug message.
        input: Tensor to "clone"; must be a ``TRTTensor``.

    Returns:
        The same ``input`` object that was passed in.

    Raises:
        RuntimeError: If ``input`` is not a ``TRTTensor``.
    """
    if isinstance(input, TRTTensor):
        LOGGER.debug(f"Evaluating clone on object with name: {name}")
        return input

    raise RuntimeError(
        f"clone received input {input} that is not a TensorRT ITensor"
    )
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from typing import Any, Callable, Optional, Union

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
Expand All @@ -15,8 +16,6 @@
from torch_tensorrt.fx.types import TRTElementWiseOp, TRTNetwork, TRTTensor
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

import tensorrt as trt


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
Expand Down Expand Up @@ -132,9 +131,13 @@ def convert_binary_elementwise(
trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

if trt_promoted_type != lhs_val.dtype:
lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
lhs_val = cast_trt_tensor(
network, lhs_val, trt_promoted_type, name, target, source_ir
)
if trt_promoted_type != rhs_val.dtype:
rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
rhs_val = cast_trt_tensor(
network, rhs_val, trt_promoted_type, name, target, source_ir
)

# Check the limitation in the doc string.
if network.has_implicit_batch_dimension:
Expand Down
32 changes: 32 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/op_evaluators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
import operator
from typing import Dict, Sequence, Tuple, Union

from torch.fx.node import Argument, Node, Target
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter

_LOGGER: logging.Logger = logging.getLogger(__name__)


def getitem_validator(getitem_node: Node) -> bool:
    """Report whether the first argument of a getitem node is in DYNAMO_CONVERTERS.

    The registry is imported inside the function rather than at module level.
    """
    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS

    # Getitem nodes can only be converted if their parent node also can
    parent_node = getitem_node.args[0]
    return parent_node in DYNAMO_CONVERTERS


# TODO: Subsequent evaluators should be registered here with their own validators
@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
def generic_evaluator(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Evaluate ``target`` directly on ``args`` and return the result.

    Nothing is added to ``network``, and ``kwargs`` are not passed to
    ``target``; the call is logged at debug level before it is made.
    """
    target_label = ConverterRegistry.qualified_name_or_str(target)
    _LOGGER.debug(f"Evaluating {target_label} on object with name: {name}")
    return target(*args)
24 changes: 24 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
from typing import Any, Callable, Dict, Optional, Sequence

import torch
import torch_tensorrt
from torch_tensorrt._Device import Device
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._defaults import PRECISION

from packaging import version

Expand Down Expand Up @@ -161,6 +163,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
if settings.debug:
logger.setLevel(logging.DEBUG)

# TODO: Remove once Dynamo precisions refactoring is complete
if "enabled_precisions" in kwargs:
enabled_precisions = kwargs["enabled_precisions"]

if (
torch.float16 in enabled_precisions
or torch_tensorrt.dtype.half in enabled_precisions
):
settings.precision = torch.float16
elif (
torch.float32 in enabled_precisions
or torch_tensorrt.dtype.float in enabled_precisions
):
settings.precision = torch.float32
elif len(enabled_precisions) == 0:
logger.info(f"No precision specified, defaulting to {PRECISION}")
settings.precision = PRECISION
else:
raise ValueError(
f"Precision {enabled_precisions} not supported in the Dynamo Path"
)

# Parse input runtime specification
settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)

Expand Down
Loading
Loading