From 53378e954394f495d343bf41868274b064d41796 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Fri, 30 Jun 2023 16:43:56 -0700 Subject: [PATCH] feat: Data Structure update for Dynamo Registry - Add custom class overriding default Dictionary class to access converters from various registries - Add new dictionary type `Dict[Target, Sequence[ConverterSupport]]` as well as ConverterSupport class which stores a converter and its validation implementation - Add unified `DYNAMO_CONVERTERS` dictionary which coalesces both the FX and Dynamo converter dictionaries and acts as a single unified dictionary - Streamline dictionary accesses via get/contains accessors - Add priority converter decorator enum to prioritize user-provided converters and name argument checking "capability validation" to clarify utility - Add boilerplate `no_dynamic` converter capability validator for easy use in specifying converters as not-able to handle dynamic shapes --- .../dynamo/backend/lowering/_partition.py | 11 +- .../dynamo/common_utils/converter_utils.py | 30 ++ .../dynamo/converter_registry.py | 312 +++++++++++++++++- .../dynamo/fx_ts_compat/fx2trt.py | 15 +- 4 files changed, 352 insertions(+), 16 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/common_utils/converter_utils.py diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py index e3308cbe03..63d39568be 100644 --- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py +++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Sequence, Set +from typing import Callable, Dict, List, Optional, Sequence, Set import torch @@ -55,6 +55,10 @@ def __init__( ) self.min_block_size = min_block_size + logger.debug( + "Initialized Capability-Based Partitioner with available Converters:\n" + + f"{CONVERTERS.display_all_available_converters()}" + ) def propose_partitions(self) -> List[Partition]: # Propose partitions using the default, then refine the results @@ -123,10 +127,7 @@ def is_node_supported( else node.target ) - if ( - node.target in CONVERTERS.keys() - and node_name not in self.torch_executed_ops - ): + if node in CONVERTERS and node_name not in self.torch_executed_ops: # If node is a proper, supported computational node, store the operator if not node.is_impure(): self.supported_operators.add(node_name) diff --git a/py/torch_tensorrt/dynamo/common_utils/converter_utils.py b/py/torch_tensorrt/dynamo/common_utils/converter_utils.py new file mode 100644 index 0000000000..fc816e90fa --- /dev/null +++ b/py/torch_tensorrt/dynamo/common_utils/converter_utils.py @@ -0,0 +1,30 @@ +import torch + + +def dynamic_unsupported(node: torch.fx.Node) -> bool: + # Validate that none of the inputs to the node have Dynamic shapes + assert isinstance( + node, torch.fx.Node + ), "Inputs to validator functions must be FX Nodes" + + # Check node value itself + if node.meta["val"]._has_symbolic_sizes_strides: + return False + + # Check node arguments individually + if any( + arg.meta["val"]._has_symbolic_sizes_strides + for arg in node.args + if isinstance(arg, torch.fx.Node) + ): + return False + + # Check node keyword arguments individually + if any( + kwarg.meta["val"]._has_symbolic_sizes_strides + for kwarg in node.kwargs.values() + if isinstance(kwarg, torch.fx.Node) + ): + return False + + return True diff --git a/py/torch_tensorrt/dynamo/converter_registry.py b/py/torch_tensorrt/dynamo/converter_registry.py index 1a9ee03970..bc84361abb 100644 --- a/py/torch_tensorrt/dynamo/converter_registry.py +++ b/py/torch_tensorrt/dynamo/converter_registry.py @@ -1,23 +1,327 @@ -from typing import Any, Callable, Dict +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Optional, Sequence, Union +from enum import Enum, auto -from torch.fx.node import Target +from torch.fx.node import Target, Node, _get_qualified_name from torch_tensorrt.fx.converter_registry import CONVERTERS -DYNAMO_CONVERTERS: Dict[Target, Any] = dict(CONVERTERS) + +class ConverterPriority(Enum): + """Enum to set a converter's priority in the registry""" + + STANDARD = auto() + HIGH = auto() + + +@dataclass(frozen=True) +class ConverterSupport: + """Class representing a converter implementation and support function + + Args: + converter_implementation: Function which converts said node to a TRT equivalent + capability_validator: Function which takes in a Node and returns a bool indicating + whether that node can be supported by its companion converter. Note that + this function must not modify the node or its graph + """ + + converter_implementation: Callable + capability_validator: Callable[[Node], bool] = field(default=lambda node: True) + + +# Dictionary representing Dynamo aten-only converters +# Each converter maps to a sequence of at least one ConverterSupport object(s) +DYNAMO_ATEN_CONVERTERS: Dict[Target, Sequence[ConverterSupport]] = {} def dynamo_tensorrt_converter( key: Target, enabled: bool = True, + capability_validator: Optional[Callable[[Node], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, ) -> Callable[[Any], Any]: + """Decorator for Dynamo TensorRT Converter + + Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry + + Args: + key: Node target for which the converter is implemented for + (for example, torch.ops.add.Tensor) + enabled: Whether the converter should be enabled/cached or not + capability_validator: Function which evaluates whether a node is valid for conversion + by the decorated converter. See ConverterSupport for more details. + Defaults to None, implying the capability_validator function is always true - + this means all nodes of "key" kind can be supported by this converter + priority: Converter's level of priority relative to other converters with the + same target + Returns: + The converter being decorated + """ + def register_converter(converter): - DYNAMO_CONVERTERS[key] = converter + """Helper function to register the converter, then return it""" + assert callable(converter), "Converter function must be callable" + + # If no capability_validator function is specified, use the default function - always return true + if capability_validator is None: + converter_support = ConverterSupport(converter_implementation=converter) + else: + assert callable( + capability_validator + ), "Argument checking function must be callable" + converter_support = ConverterSupport( + converter_implementation=converter, + capability_validator=capability_validator, + ) + + # If a converter for this operator already exists, append the new converter to the list + # Otherwise, start a new list + if key in DYNAMO_ATEN_CONVERTERS: + # High priority converters are inserted at the front of the list, + # so they can be checked first by the registry + if priority is ConverterPriority.HIGH: + DYNAMO_ATEN_CONVERTERS[key].insert(0, converter_support) + else: + DYNAMO_ATEN_CONVERTERS[key].append(converter_support) + else: + DYNAMO_ATEN_CONVERTERS[key] = [converter_support] + return converter def disable_converter(converter): return converter + # Select whether to cache/enable the converter if enabled: return register_converter else: return disable_converter + + +class ConverterRegistry: + """Registry for storing multiple converter dictionaries + + Capable of storing dictionaries with the following signature: + Dict[Target, Union[Callable, Sequence[ConverterSupport]]] + + Also able to validate converter implementations against user-provided + argument-checking functions + + Args: + registries: List of dictionaries representing converter registries. + The order of the provided dictionaries is the order in which they + will be traversed. This is only significant when using non-validated + methods. + """ + + def __init__( + self, + registries: Sequence[Dict[Target, Union[Callable, Sequence[ConverterSupport]]]], + registry_names: Optional[Sequence[str]] = None, + ): + # Copy reference to each dictionary object into attribute list + self.registries = [registry for registry in registries] + + if registry_names is not None: + assert len(self.registries) == len(registry_names) + self.registry_names = [name for name in registry_names] + else: + self.registry_names = [ + f"Registry {i + 1}" for i in range(len(self.registries)) + ] + + self.validate_invariants() + + def validate_invariants(self): + """Validates the invariants required of the dictionaries in the registries + + Raises AssertionError if any invariants have been violated + """ + # All registries must be dictionaries + assert all(isinstance(elt, dict) for elt in self.registries) + + # Every dictionary in the registry must have one of two signatures: + # Dict[Target, Callable] or Dict[Target, Sequence[ConverterSupport]] + # Where, for the latter, the sequence must be non-empty + for registry in self.registries: + for converters in registry.values(): + if isinstance(converters, (list, tuple)): + assert ( + all(isinstance(c, ConverterSupport) for c in converters) + and len(converters) > 0 + ) + else: + assert callable(converters), "Converter function must be callable" + + def __getitem_without_validation__(self, key: Target): + """Get the first-found converter in any registry + + Searches all registries in order and returns the first converter encountered + """ + if isinstance(key, Node): + raise KeyError( + "Unvalidated accesses to the Converter registry can only be " + + "made with node targets. Try accessing the registry with node.target" + ) + + self.validate_invariants() + + # Iterate over all registries and return the first converter found + for registry in self.registries: + if key in registry: + converters = registry[key] + + if isinstance(converters, (list, tuple)): + return converters[0].converter_implementation + else: + return converters + + raise KeyError(f"None of the converter registries have an entry for {key}") + + def __getitem__(self, node: Node): + """Get the first-found validated converter in any registry + + Searches all registries in order and returns the first converter + which passes validation on the input node + """ + if not isinstance(node, Node): + raise KeyError( + "Validated accesses to the Converter registry can only be " + + "made with node inputs. Try accessing the registry with a node " + + "or use get_unvalidated to access without node validation." + ) + + self.validate_invariants() + key = node.target + + # Iterate over all registries, validating the converter on the input node + # If no capability_validator function is found, assume full coverage + for registry in self.registries: + if key in registry: + converters = registry[key] + + if isinstance(converters, (list, tuple)): + for candidate in converters: + if candidate.capability_validator(node): + return candidate.converter_implementation + else: + return converters + + raise KeyError( + f"None of the converter registries have a validated entry for {key}, with node {node}" + ) + + def keys(self): + """Get all unique targets across all dictionaries""" + return self.unique_targets() + + def get_unvalidated(self, key: Target, value=None): + """Get unvalidated converter for input target with a default return""" + try: + return self.__getitem_without_validation__(key) + except KeyError: + return value + + def get(self, node: Node, value=None): + """Get validated converter for input node with a default return""" + try: + return self.__getitem__(node) + except KeyError: + return value + + def __contains__(self, key: Union[Target, Node]): + """Check whether a converter for an input node or target exists""" + try: + # Attempt to access the item in the registry + if isinstance(key, Node): + self.__getitem__(key) + else: + self.__getitem_without_validation__(key) + + return True + except KeyError: + return False + + def get_all_converters_with_target( + self, key: Target, return_registry_info: bool = False + ): + """Get all converters across all registries for the target + + Returns a list of all converterts having the specified target + """ + self.validate_invariants() + converters_with_target = [] + + # Store count of number of registered converters per registry + if return_registry_info: + registry_data = {name: 0 for name in self.registry_names} + + for index, registry in enumerate(self.registries): + if key in registry: + converters = registry[key] + + if isinstance(converters, (list, tuple)): + converters_with_target.extend( + [c.converter_implementation for c in converters] + ) + # Add converter count to registry name storage + if return_registry_info: + registry_data[self.registry_names[index]] += len(converters) + else: + converters_with_target.append(converters) + # Add converter count to registry name storage + if return_registry_info: + registry_data[self.registry_names[index]] += 1 + + if return_registry_info: + return converters_with_target, registry_data + else: + return converters_with_target + + def __setitem__(self, key, value): + raise AssertionError( + f"Do not set registry members directly through the ConverterRegistry object. " + + f"Attempted to set {key}: {value} via direct assignment to ConverterRegistry." + ) + + def __delitem__(self, key): + raise AssertionError( + f"Do not delete registry members directly through the ConverterRegistry object. " + + f"Attempted to delete {key} via direct del on ConverterRegistry." + ) + + def __len__(self): + """Returns the sum of lengths of all registries stored""" + return sum(len(registry) for registry in self.registries) + + def unique_targets(self): + """Returns the set of unique converter targets stored across all registries""" + return set.union(*[set(registry.keys()) for registry in self.registries]) + + def qualified_name_or_str(self, target: Target) -> str: + """Returns string representation of an FX Node target""" + if isinstance(target, str): + return target + else: + return _get_qualified_name(target) + + def display_all_available_converters(self) -> str: + """Returns a string with all converters and their source, separated by newlines""" + available_converters = "Available converters in ATen registries with counts:\n" + + for target in sorted( + self.unique_targets(), key=lambda target: self.qualified_name_or_str(target) + ): + _, registry_data = self.get_all_converters_with_target( + target, return_registry_info=True + ) + available_converters += f"Node: {self.qualified_name_or_str(target)} - Registry Presence Counts: {registry_data}\n" + + return available_converters + + +# Initialize dynamo converter registry with the FX and Dynamo aten registries +# Note the Dynamo registry is listed first, for precedence +DYNAMO_CONVERTERS: ConverterRegistry = ConverterRegistry( + [DYNAMO_ATEN_CONVERTERS, CONVERTERS], + ["Dynamo ATen Converters Registry", "FX ATen Converters Registry"], +) diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index 03499aaf85..76dbd08531 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -10,7 +10,6 @@ import tensorrt as trt import torch import torch.fx -from torch._ops import OpOverload from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata @@ -71,6 +70,7 @@ def __init__( self.input_specs_iter = 0 self.validate_input_specs() self._cur_node_name: Optional[str] = None + self._cur_node: Optional[torch.fx.Node] = None self._input_names: List[str] = [] self._output_names: List[str] = [] self._itensor_to_tensor_meta: Dict[ @@ -141,14 +141,14 @@ def validate_conversion(self): missing_converter = set() for node in self.module.graph.nodes: - if node.op == "call_function" and not CONVERTERS.get(node.target): + if node.op == "call_function" and not CONVERTERS.get(node): missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}") - elif node.op == "call_method" and not CONVERTERS.get(node.target): + elif node.op == "call_method" and not CONVERTERS.get(node): missing_converter.add(f"{node.op} torch.Tensor.{node.target}") elif node.op == "call_module": submod = self.fetch_attr(node.target) submod_type = getattr(submod, "_base_class_origin", type(submod)) - if not CONVERTERS.get(submod_type): + if not CONVERTERS.get(node): missing_converter.add(f"{node.op} {torch.typename(submod_type)}") return missing_converter @@ -293,6 +293,7 @@ def run( def run_node(self, n): self._cur_node_name = str(n) + self._cur_node = n # add "_itensor_to_tensor_meta" kwargs = dict(n.kwargs) kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta @@ -336,7 +337,7 @@ def call_module(self, target, args, kwargs): assert isinstance(target, str) submod = self.fetch_attr(target) submod_type = getattr(submod, "_base_class_origin", type(submod)) - converter = CONVERTERS.get(submod_type) + converter = CONVERTERS.get(self._cur_node) if not converter: raise RuntimeError( @@ -347,7 +348,7 @@ def call_module(self, target, args, kwargs): return converter(self.network, submod, args, kwargs, self._cur_node_name) def call_function(self, target, args, kwargs): - converter = CONVERTERS.get(target) + converter = CONVERTERS.get(self._cur_node) if not converter: raise RuntimeError( f"Conversion of function {torch.typename(target)} not currently supported!" @@ -358,7 +359,7 @@ def call_function(self, target, args, kwargs): def call_method(self, target, args, kwargs): assert isinstance(target, str) - converter = CONVERTERS.get(target) + converter = CONVERTERS.get(self._cur_node) if not converter: raise RuntimeError(