feat: cherry pick of Refit fixes #3166

Merged: 1 commit, Sep 17, 2024
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -314,11 +314,9 @@ def compile_module(
dryrun_tracker = DryRunTracker()
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}
# Assume converters support dynamic shapes and disable validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)

# Set torch-executed ops
CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
@@ -670,8 +668,8 @@ def convert_exported_program_to_serialized_trt_engine(
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

# Assume converters support dynamic shapes and disable validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

try:
interpreter_result = interpret_module_to_result(
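The net effect of this file's change is that callers now hand the whole CompilationSettings object to the converter registry in one call, and the registry derives the disallowed (torch-executed) targets from it. A minimal sketch of the before/after call pattern, assuming the registry import path matches the existing usage in _compiler.py; the example torch_executed_ops value is hypothetical and not part of this diff.

# Sketch only: configure the global converter registry from user settings.
# The DYNAMO_CONVERTERS import path is an assumption based on _compiler.py's existing import.
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
)

settings = CompilationSettings(
    assume_dynamic_shape_support=False,
    torch_executed_ops={"torch.ops.aten.embedding_bag.default"},  # hypothetical value
)

# Before this PR (two separate registry calls):
#   CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
#   CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
# After this PR (one call, which also applies torch_executed_ops as disallowed targets):
CONVERTERS.set_compilation_settings(settings)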
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -6,6 +6,7 @@
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import tensorrt as trt
import torch
from torch.export import ExportedProgram
from torch_tensorrt._enums import dtype
@@ -42,8 +43,6 @@
)
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt

logger = logging.getLogger(__name__)


40 changes: 27 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -23,6 +23,7 @@
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

@@ -82,7 +83,9 @@ class ConverterSupport:
"""

converter_implementation: ConverterImplSignature
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
capability_validator: Callable[[Node, CompilationSettings], bool] = field(
default=lambda node, compilation_settings: True
)
supports_dynamic_shapes: bool = False


@@ -112,18 +115,20 @@ def has_dynamic_shapes_in_args(

def has_static_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
) -> Callable[[torch.fx.Node, CompilationSettings], bool]:
"""Returns True if a node has static inputs in node.args at specified positions"""
_has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
node, arg_positions_to_check
_has_static_shapes = lambda node, compilation_settings, arg_positions_to_check: not _has_dynamic_shapes(
node, compilation_settings, arg_positions_to_check
)
return functools.partial(
_has_static_shapes, arg_positions_to_check=arg_positions_to_check
)


def _has_dynamic_shapes(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
node: torch.fx.Node,
compilation_settings: CompilationSettings = None,
arg_positions_to_check: Optional[List[int]] = None,
) -> bool:
# Validate that none of the inputs to the node have Dynamic shapes
assert isinstance(
@@ -188,7 +193,7 @@ def dynamo_tensorrt_converter(
key: Target,
*,
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
@@ -297,7 +302,6 @@ def __init__(
],
registry_names: Optional[Sequence[str]] = None,
registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
assume_dynamic_shape_support: bool = False,
):
# Copy reference to each dictionary object into attribute list
self.registries = list(registries)
@@ -318,12 +322,16 @@ def __init__(
CallingConvention.CTX for _ in range(len(self.registries))
]

self.compilation_settings: CompilationSettings = None
self.disallowed_targets: Collection[Target] = set()
self.assume_dynamic_shape_support = assume_dynamic_shape_support
self.validate_invariants()

def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
self.assume_dynamic_shape_support = assume_dynamic_shape_support
def set_compilation_settings(
self, compilation_settings: CompilationSettings
) -> None:
self.compilation_settings = compilation_settings
# set torch executed ops as disallowed targets
self.set_disallowed_targets(compilation_settings.torch_executed_ops)

def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
self.disallowed_targets = torch_executed_ops
@@ -412,7 +420,11 @@ def __getitem__(

self.validate_invariants()
key = node.target

assume_dynamic_shape_support = False
if self.compilation_settings:
assume_dynamic_shape_support = (
self.compilation_settings.assume_dynamic_shape_support
)
if (
key in self.disallowed_targets
or self.qualified_name_or_str(key) in self.disallowed_targets
@@ -436,8 +448,10 @@ def __getitem__(
# 2) Assume dynamic_shape support is True
# 3) Node only has static shaped inputs
# 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
if candidate.capability_validator(node) and (
self.assume_dynamic_shape_support
if candidate.capability_validator(
node, self.compilation_settings
) and (
assume_dynamic_shape_support
or not node_has_dynamic_shapes(node)
or candidate.supports_dynamic_shapes
):
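To make the updated lookup in __getitem__ easier to follow, here is a condensed restatement (not part of the diff) of the eligibility test a candidate converter must now pass; the standalone helper is introduced purely for illustration, and the imports assume the helpers are module-level in _ConverterRegistry.py as used above.

# Illustrative restatement of the candidate check in ConverterRegistry.__getitem__:
# the validator is now settings-aware, and the dynamic-shape assumption is read
# from the stored CompilationSettings (which may be None) instead of a separate flag.
import torch

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    ConverterSupport,
    node_has_dynamic_shapes,
)


def candidate_is_eligible(
    candidate: ConverterSupport,
    node: torch.fx.Node,
    compilation_settings: CompilationSettings = None,
) -> bool:
    # Mirror of the logic above: read the global dynamic-shape assumption from settings.
    assume_dynamic_shape_support = bool(
        compilation_settings and compilation_settings.assume_dynamic_shape_support
    )
    return candidate.capability_validator(node, compilation_settings) and (
        assume_dynamic_shape_support
        or not node_has_dynamic_shapes(node)
        or candidate.supports_dynamic_shapes
    )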
8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
)

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -89,6 +89,11 @@ def __init__(
self.builder.create_network(flag), compilation_settings
)

self.compilation_settings = compilation_settings
if not CONVERTERS.compilation_settings:
# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(compilation_settings)

assert TRTInterpreter._all_precisions_supported(
compilation_settings.enabled_precisions
), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})"
@@ -117,7 +122,6 @@ def __init__(
self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
dict()
)
self.compilation_settings = compilation_settings

# Data types for TRT Module output Tensors
self.output_dtypes = (
66 changes: 47 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -48,7 +49,7 @@ def get_ir(target: Target) -> SourceIR:
return SourceIR.UNKNOWN


def one_user_validator(node: Node) -> bool:
def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Validate only one user, which is a getitem node that accesses the first element in the list
return (
len(node.users) == 1
@@ -270,7 +271,11 @@ def aten_ops_embedding(
)


def embedding_bag_validator(node: Node) -> bool:
def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Embedding bag op is not refitable
if settings.make_refittable:
return False

if not one_user_validator(node):
return False
meta = node.args[1].meta
@@ -416,7 +421,7 @@ def aten_ops_symsize_int(
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool:
index = node.args[1]
for ind in index:
if ind is not None:
@@ -837,7 +842,7 @@ def aten_ops_select(
)


def index_put_validator(node: Node) -> bool:
def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool:
if args_bounds_check(node.args, 3, False): # Check if accumulate is valid
_LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
accumulate_valid = False
@@ -924,7 +929,18 @@ def aten_ops_slice(
)


@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
def refit_validator(node: Node, settings: CompilationSettings = None) -> bool:
# cumsum op is not refitable
if settings and settings.make_refittable:
return False
return True


@dynamo_tensorrt_converter(
torch.ops.aten.cumsum.default,
capability_validator=refit_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -970,7 +986,7 @@ def aten_ops_tile(
)


def zero_output_validator(node: Node) -> bool:
def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool:
if 0 in node.args[1]:
_LOGGER.debug(
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
@@ -1027,7 +1043,9 @@ def aten_ops_permute(
)


def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
def to_copy_dtype_validator(
placeholder_only: bool, settings: CompilationSettings = None
) -> Callable[[Node, CompilationSettings], bool]:
"""Return validator for to_copy node with placeholder restrictions"""

def validate_dtype(to_copy_node: Node) -> bool:
@@ -1059,7 +1077,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
)
return False

def validator(to_copy_node: Node) -> bool:
def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool:
"""Returns true if the to_copy node can be converted to TRT
and the placeholder restriction is satisfied
"""
@@ -1074,7 +1092,9 @@ def validator(to_copy_node: Node) -> bool:

@dynamo_tensorrt_converter(
torch.ops.aten.clone.default,
capability_validator=lambda node: not is_only_operator_on_placeholder(node),
capability_validator=lambda node, settings: not is_only_operator_on_placeholder(
node, settings
),
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
@@ -2128,7 +2148,7 @@ def aten_ops_logical_xor(
)


def bitwise_type_validator(node: Node) -> bool:
def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool:
supported_type = [torch.bool, bool]

tensor_targets = [
@@ -2271,7 +2291,9 @@ def aten_ops_bitwise_xor(
)


def bitwise_not_type_validator(node: Node) -> bool:
def bitwise_not_type_validator(
node: Node, settings: CompilationSettings = None
) -> bool:
val = node.args[0]
val_meta = val.meta.get("tensor_meta")

@@ -2453,7 +2475,7 @@ def aten_ops_le(
)


def conv_param_validator(conv_node: Node) -> bool:
def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool:
return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])


@@ -2549,7 +2571,9 @@ def aten_ops_cdist_forward(
)


def avg_pool_param_validator(pool_node: Node) -> bool:
def avg_pool_param_validator(
pool_node: Node, settings: CompilationSettings = None
) -> bool:
ceil_mode = args_bounds_check(pool_node.args, 4, False)
divisor_override = args_bounds_check(pool_node.args, 6)

@@ -2665,12 +2689,12 @@ def aten_ops_adaptive_avg_poolNd(
)


def topk_validator(node: Node) -> bool:
def topk_validator(node: Node, settings: CompilationSettings = None) -> bool:
k = node.args[1]
return topk_sort_validator(k)


def sort_validator(node: Node) -> bool:
def sort_validator(node: Node, settings: CompilationSettings = None) -> bool:
meta_data = node.args[0].meta.get("tensor_meta")
if meta_data is None:
return False
@@ -2692,7 +2716,9 @@ def topk_sort_validator(k: int) -> bool:
return True


def max_pool_param_validator(pool_node: Node) -> bool:
def max_pool_param_validator(
pool_node: Node, settings: CompilationSettings = None
) -> bool:
dilation = args_bounds_check(pool_node.args, 4, 1)
ceil_mode = args_bounds_check(pool_node.args, 5, False)

@@ -2746,7 +2772,7 @@ def aten_ops_max_pool(
)


def attention_validator(node: Node) -> bool:
def attention_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None

@@ -3637,7 +3663,7 @@ def aten_ops_flip(
)


def zero_diag_size_validator(node: Node) -> bool:
def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool:
meta = node.args[0].meta.get("tensor_meta")
if meta:
input_shape = meta.shape
@@ -3765,7 +3791,9 @@ def aten_ops_index_select(
)


def dropout_inference_validator(node: Node) -> bool:
def dropout_inference_validator(
node: Node, settings: CompilationSettings = None
) -> bool:
train_mode = args_bounds_check(node.args, 2, None)
if train_mode is False:
return True
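Taken together, the validator changes in this file define a new contract: every capability validator accepts the active CompilationSettings (possibly None) alongside the node. A hedged sketch of a downstream, out-of-tree converter written against that contract follows; the target op, validator logic, and elided converter body are illustrative assumptions only, not code from this PR.

# Sketch: a settings-aware validator plugged into the converter decorator.
from typing import Any, Dict, Sequence, Tuple, Union

import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    dynamo_tensorrt_converter,
)


def my_refit_aware_validator(node: Node, settings: CompilationSettings = None) -> bool:
    # Same pattern as refit_validator above: decline the node when a refittable
    # engine was requested, otherwise accept it.
    return not (settings and settings.make_refittable)


@dynamo_tensorrt_converter(
    torch.ops.aten.relu.default,  # hypothetical target, chosen only for illustration
    capability_validator=my_refit_aware_validator,
    supports_dynamic_shapes=True,
)
def my_relu_converter(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[Any, Sequence[Any]]:
    ...  # converter body elided in this sketch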