Skip to content

Commit

Permalink
feat: Add _to_copy, operator.get and clone
Browse files Browse the repository at this point in the history
- Add ATen converters for key operators in the pipeline of multiple
models
- Add robust testing and patch issues in interpreter
- Add evaluator and casting utilities to the converter utils
  • Loading branch information
gs-olive committed Aug 2, 2023
1 parent ce06f6e commit e331363
Show file tree
Hide file tree
Showing 10 changed files with 335 additions and 14 deletions.
85 changes: 81 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import operator
from typing import Dict, Sequence, Tuple, Union
import torch
import tensorrt as trt
Expand All @@ -9,8 +10,10 @@
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor
from torch_tensorrt.dynamo.conversion.converter_utils import (
cast_trt_tensor,
cast_int_int_div_trt_tensor,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,13 +73,13 @@ def aten_ops_div(
kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32
):
kwargs_new["input"] = cast_trt_tensor(
network, kwargs_new["input"], trt.float32, name
network, kwargs_new["input"], trt.float32, name, target
)
elif isinstance(args[1], TRTTensor) and (
kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32
):
kwargs_new["other"] = cast_trt_tensor(
network, kwargs_new["other"], trt.float32, name
network, kwargs_new["other"], trt.float32, name, target
)
rounding_mode = kwargs.get("rounding_mode")
if rounding_mode is None:
Expand Down Expand Up @@ -377,3 +380,77 @@ def aten_ops_permute(
args[0],
args[1],
)


def to_copy_dtype_validator(to_copy_node: Node):
allowed_casts = {torch.float, torch.int32, torch.bool, torch.int8, torch.float16}

# Validate input node has convertible kwargs
if "dtype" in to_copy_node.kwargs:
if to_copy_node.kwargs["dtype"] in allowed_casts:
return True
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with dtype {to_copy_node.kwargs['dtype']}"
)
return False
else:
_LOGGER.debug(
f"_to_copy converter rejected node {to_copy_node} with kwargs {to_copy_node.kwargs}"
)
return False


@dynamo_tensorrt_converter(
torch.ops.aten._to_copy.default, capability_validator=to_copy_dtype_validator
)
def aten_ops_to_copy_dtype(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.cast.to_copy(
network,
target,
SourceIR.ATEN,
name,
args[0],
kwargs["dtype"],
)


@dynamo_tensorrt_converter(operator.getitem)
def operator_getitem(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.evaluators.getitem(
network,
target,
SourceIR.ATEN,
name,
args[0],
args[1],
)


@dynamo_tensorrt_converter(torch.ops.aten.clone.default)
def aten_ops_clone(
network: TRTNetwork,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.evaluators.clone(
network,
target,
SourceIR.ATEN,
name,
args[0],
)
21 changes: 18 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch

from torch.fx.node import _get_qualified_name, Target

from torch_tensorrt.fx.types import (
TRTDataType,
TRTNetwork,
Expand All @@ -12,7 +14,9 @@
)

import tensorrt as trt
from typing import List
from typing import List, Optional

from .._SourceIR import SourceIR


def dynamic_unsupported(node: torch.fx.Node) -> bool:
Expand Down Expand Up @@ -49,24 +53,35 @@ def cast_trt_tensor(
input_val: TRTTensor,
dtype: TRTDataType,
name: str,
target: Target = "",
source_ir: Optional[SourceIR] = None,
) -> TRTTensor:
"""
Given a TRT Tensor, convert that Tensor to the specified dtype
Adds an Identity layer to the network which performs the conversion
Args:
network (TRTNetwork): A TensorRT network
input_val (TRTTensor): A TRT Tensor to cast to a new data type
dtype (TRTDataType): The TRTDataType to cast the input Tensor to
dtype (TRTDataType, torch.dtype, np.dtype): The data type to cast the input Tensor to
name (str): Name of the calling layer
target (Target): Target of calling node
source_ir (SourceIR): SourceIR of calling converter
Returns:
A TensorRT ITensor which has been casted to the specified dtype
"""
trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT)

if input_val.dtype != trt_dtype:
source_ir = source_ir if source_ir is not None else SourceIR.UNKNOWN
target_name = (
f"{source_ir}_ops{'.' + target if target else ''}"
if (isinstance(target, str))
else f"{source_ir}_ops.{_get_qualified_name(target)}"
)

identity_layer = network.add_identity(input_val)
identity_layer.set_output_type(0, trt_dtype)
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}"
identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} -{name}-[{target_name}]-[{name}]"
return identity_layer.get_output(0)
else:
return input_val
Expand Down
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@
from . import squeeze
from . import unsqueeze
from . import permutation
from . import cast
from . import evaluators
28 changes: 28 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/cast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional
from torch.fx.node import Target

from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor

from torch_tensorrt.fx.types import (
TRTNetwork,
TRTTensor,
TRTDataType,
)


def to_copy(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
dtype: TRTDataType,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"to_copy received input {input} that is not a TensorRT ITensor"
)

casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
return casted_tensor
8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,13 @@ def convert_binary_elementwise(
trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT)

if trt_promoted_type != lhs_val.dtype:
lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name)
lhs_val = cast_trt_tensor(
network, lhs_val, trt_promoted_type, name, target, source_ir
)
if trt_promoted_type != rhs_val.dtype:
rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name)
rhs_val = cast_trt_tensor(
network, rhs_val, trt_promoted_type, name, target, source_ir
)

# Check the limitation in the doc string.
if network.has_implicit_batch_dimension:
Expand Down
45 changes: 45 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/evaluators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import operator
import logging
from typing import Optional, Sequence
from torch.fx.node import Target

from torch_tensorrt.dynamo._SourceIR import SourceIR

from torch_tensorrt.fx.types import (
TRTNetwork,
TRTTensor,
)


LOGGER: logging.Logger = logging.getLogger(__name__)


def getitem(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: Sequence[TRTTensor],
index: int,
) -> TRTTensor:
LOGGER.debug(f"Evaluating getitem on object with name: {name}")

# Directly index the input sequence and return the value
return operator.getitem(input, index)


def clone(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
) -> TRTTensor:
if not isinstance(input, TRTTensor):
raise RuntimeError(
f"clone received input {input} that is not a TensorRT ITensor"
)

LOGGER.debug(f"Evaluating clone on object with name: {name}")

return input
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/trt_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")


class UnsupportedOperatorException(RuntimeError):
pass


class TRTInterpreterResult(NamedTuple):
engine: Any
input_names: Sequence[str]
Expand Down Expand Up @@ -288,7 +292,7 @@ def call_module(self, target, args, kwargs):
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of module of type {submod_type} not currently supported!"
)

Expand All @@ -298,7 +302,7 @@ def call_module(self, target, args, kwargs):
def call_function(self, target, args, kwargs):
converter = CONVERTERS.get(self._cur_node)
if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of function {torch.typename(target)} not currently supported!"
)

Expand All @@ -310,7 +314,7 @@ def call_method(self, target, args, kwargs):
converter = CONVERTERS.get(self._cur_node)

if not converter:
raise RuntimeError(
raise UnsupportedOperatorException(
f"Conversion of method {target} not currently supported!"
)

Expand Down
25 changes: 23 additions & 2 deletions py/torch_tensorrt/dynamo/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def generate_graph(
expected_ops: Set[Callable],
unexpected_ops: Optional[Set[Callable]] = None,
customized_passes: List[Callable] = None,
disable_passes: bool = False,
):
# Torchdynamo+aot proxytensor tracer
# Below are common passes
Expand All @@ -234,6 +235,10 @@ def generate_graph(
# Combine with customized passes specific to any model
if customized_passes:
passes_list.extend(customized_passes)

if disable_passes:
passes_list = []

fx_module, _ = aten_tracer.trace(mod, original_inputs)
for passes in passes_list:
pr: PassResult = passes(fx_module)
Expand Down Expand Up @@ -261,9 +266,17 @@ def run_test(
atol=1e-03,
precision=torch.float,
check_dtype=True,
disable_passes=False,
):
mod.eval()
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
mod = self.generate_graph(
mod,
inputs,
expected_ops,
unexpected_ops,
None,
disable_passes=disable_passes,
)

if apply_passes is not None:
pass_tracer = chain_passes(*apply_passes)
Expand Down Expand Up @@ -293,10 +306,18 @@ def run_test_with_dynamic_shape(
unexpected_ops=None,
rtol=1e-03,
atol=1e-03,
disable_passes=False,
):
mod.eval()
inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
mod = self.generate_graph(
mod,
inputs,
expected_ops,
unexpected_ops,
None,
disable_passes=disable_passes,
)

interp = TRTInterpreter(
mod,
Expand Down
Loading

0 comments on commit e331363

Please sign in to comment.