diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py new file mode 100644 index 0000000000..46d99ffe31 --- /dev/null +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -0,0 +1,292 @@ +import logging +import math +import operator +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry +from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name + +logger = logging.getLogger(__name__) + + +@dataclass +class PerSubgraphData: + """Class to track data on a per-subgraph level + + Args: + subgraph_name (str): Name of the subgraph in the GraphModule + subgraph_op_count (int): Number of operations in the subgraph + subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph + subgraph_input_dtypes (Any): Input data types of the subgraph + subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph + subgraph_output_dtypes (Any): Output data types of the subgraph + """ + + subgraph_name: str = "" + subgraph_op_count: int = 0 + subgraph_input_shapes: Any = field(default_factory=list) + subgraph_input_dtypes: Any = field(default_factory=list) + subgraph_output_shapes: Any = field(default_factory=list) + subgraph_output_dtypes: Any = field(default_factory=list) + + +@dataclass +class DryRunTracker: + """Class to track data on a graph-wide level + + Args: + total_ops_in_graph (int): Total number of operators in graph + supported_ops_in_graph (int): Number of supported operators in graph + graph_input_shapes (Any): Shapes of input Tensors of the graph + graph_input_dtypes (Any): Input data types of the graph + graph_output_shapes (Any): Shapes of output Tensors of the graph + graph_output_dtypes (Any): Output data types of the graph + per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class + tensorrt_graph_count (int): Number of TensorRT engines to be generated + compilation_settings (CompilationSettings): User Compilation Settings + unsupported_ops (Dict[str, int]): Set of operators not supported in TRT + to_run_in_torch (List[str]): Set of nodes to run in Torch + """ + + total_ops_in_graph: int = 0 + supported_ops_in_graph: int = 0 + graph_input_shapes: Any = field(default_factory=list) + graph_input_dtypes: Any = field(default_factory=list) + graph_output_shapes: Any = field(default_factory=list) + graph_output_dtypes: Any = field(default_factory=list) + per_subgraph_data: List[PerSubgraphData] = field(default_factory=list) + tensorrt_graph_count: int = 0 + compilation_settings: CompilationSettings = field( + default_factory=CompilationSettings + ) + unsupported_ops: Dict[str, int] = field(default_factory=dict) + to_run_in_torch: List[str] = field(default_factory=list) + + +def dryrun_stats_display( + dryrun_tracker: DryRunTracker, dryrun_enabled: Union[bool, str] +) -> None: + """Displays statistics about the dryrun either to debug logs or stdout""" + formatted_stats = "\n" + + # Print overall stats about the graph, operator counts, etc. + formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n" + formatted_stats += ( + f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, " + f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " + f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n" + ) + if dryrun_tracker.unsupported_ops: + parsed_ops = "\n".join( + [f"{str(k)}: {str(v)}" for k, v in dryrun_tracker.unsupported_ops.items()] + ) + formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n {parsed_ops}\n\n" + + if dryrun_tracker.to_run_in_torch: + formatted_nodes = "\n".join(dryrun_tracker.to_run_in_torch) + formatted_stats += ( + f"The following nodes are currently set to run in Torch:\n{formatted_nodes}\n" + "Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\n" + ) + + formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n" + + assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count + + # Print schematic of the graph structure, as in: + # + # Inputs: List[Tensor: (1, 3, 224, 224)@float32] + # ... + # TRT Engine #1 - Submodule name: _run_on_acc_0 + # Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32] + # Number of Operators in Engine: 1 + # Engine Outputs: Tensor: (1, 64, 112, 112)@float32 + # ... + # Outputs: List[Tensor: (1, 1000)@float32] + # + formatted_stats += " " * 2 + "Graph Structure:\n\n" + formatted_stats += ( + " " * 3 + + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n" + ) + + for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data): + formatted_stats += " " * 4 + "...\n" + formatted_stats += ( + " " * 4 + + f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n" + ) + formatted_stats += ( + " " * 5 + + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n" + ) + formatted_stats += ( + " " * 5 + + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n" + ) + formatted_stats += ( + " " * 5 + + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n" + ) + + formatted_stats += " " * 4 + "...\n" + formatted_stats += ( + " " * 3 + + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n" + ) + + # Print aggregate statistics about the graph structure, including recommended "min_block_size" options + if dryrun_tracker.tensorrt_graph_count > 0: + min_ops_in_an_engine = min( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + avg_ops_per_engine = ( + sum( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + / dryrun_tracker.tensorrt_graph_count + ) + avg_ops_per_engine = round(avg_ops_per_engine, 2) + most_ops_in_an_engine = max( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + + formatted_stats += "\n" + " " * 2 + "-" * 25 + " Aggregate Stats " + "-" * 25 + formatted_stats += ( + "\n\n" + + " " * 3 + + "Average Number of Operators per TRT Engine: " + + f"{avg_ops_per_engine}" + ) + + formatted_stats += ( + "\n" + + " " * 3 + + "Most Operators in a TRT Engine: " + + f"{most_ops_in_an_engine}" + ) + + formatted_stats += "\n\n" + " " * 2 + "*" * 10 + " Recommendations " + "*" * 10 + formatted_stats += ( + "\n\n" + + " " * 3 + + "- For minimal graph segmentation, select min_block_size=" + + f"{most_ops_in_an_engine} which would generate " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)" + ) + if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine: + formatted_stats += ( + "\n" + + " " * 3 + + "- For moderate graph segmentation, select min_block_size=" + + f"{math.ceil(avg_ops_per_engine)} which would generate " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)" + ) + + formatted_stats += ( + "\n" + + " " * 3 + + "- The current level of graph segmentation is equivalent to selecting min_block_size=" + + f"{min_ops_in_an_engine} which generates " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)" + ) + else: + formatted_stats += ( + "\n" + + " " * 2 + + "Aggregate stats not available since no TRT Engines were generated." + ) + + # If user specified "dryrun=True", print to stdout, else debug + # If user specified a filepath, save the output to the path as well + if dryrun_enabled: + print(formatted_stats) + if isinstance(dryrun_enabled, str): + if os.path.exists(dryrun_enabled): + logger.warning( + f"File already exists at path {dryrun_enabled}, not saving dryrun output" + ) + else: + with open(dryrun_enabled, "w+") as f: + f.write(formatted_stats) + else: + logger.debug(formatted_stats) + + +def input_formatter(shapes: Any, dtypes: Any) -> str: + """Format shapes and dtypes of input Tensors into a readable string""" + + def input_formatter_helper(shapes: Any, dtypes: Any) -> str: + """Helper for input formatter""" + # Base case - single shape, single dtype + if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes): + return f"Tensor: {shapes}@{str(dtypes)[6:]}, " + + # Base case - dynamic shape, single dtype + elif ( + isinstance(shapes, dict) + and len(shapes) == 3 + and all( + ( + isinstance(shape, tuple) + and all(isinstance(elt, int) for elt in shape) + and k in ("min_shape", "opt_shape", "max_shape") + ) + for k, shape in shapes.items() + ) + ): + return f"Tensor: {shapes}@{str(dtypes)[6:]}, " + + # Shapes is a sequence + elif isinstance(shapes, (list, tuple)): + formatted_str = "List[" if isinstance(shapes, list) else "Tuple(" + for shape, dtype in zip(shapes, dtypes): + formatted_str += input_formatter_helper(shape, dtype) + formatted_str = formatted_str[:-2] + ( + "], " if isinstance(shapes, list) else "), " + ) + return formatted_str + + # Shapes is a dictionary + elif isinstance(shapes, dict): + formatted_str = "Dict{" + + for key, shape in shapes.items(): + formatted_str += input_formatter_helper(shape, dtypes[key]) + + formatted_str = formatted_str[:-2] + "}, " + return formatted_str + + else: + raise ValueError( + f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing." + ) + + return input_formatter_helper(shapes, dtypes)[:-2] + + +def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]: + """Parses call_function and call_method nodes from a GraphModule + Excludes getitem nodes + + Returns a string representation of the nodes + """ + to_run_in_torch = [] + for node in graph_module.graph.nodes: + # getitem nodes are excluded since they are a Tensor-collection op + if ( + node.op in ("call_function", "call_method") + and node.target != operator.getitem + ): + to_run_in_torch.append( + f"Node: {ConverterRegistry.qualified_name_or_str(node.target)}, " + f"with layer location: {get_node_name(node)}" + ) + return to_run_in_torch diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 796e0690f3..ac7a323545 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,7 +5,6 @@ from typing import Any, List, Optional, Sequence, Set, Tuple, Union import torch -import torch_tensorrt from torch.export import ExportedProgram from torch_tensorrt._Device import Device from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum @@ -20,6 +19,7 @@ DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, + DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENGINE_CAPABILITY, MAX_AUX_STREAMS, @@ -37,6 +37,12 @@ VERSION_COMPATIBLE, WORKSPACE_SIZE, ) +from torch_tensorrt.dynamo._DryRunTracker import ( + DryRunTracker, + PerSubgraphData, + dryrun_stats_display, + parse_non_trt_nodes, +) from torch_tensorrt.dynamo.conversion import ( CompilationSettings, convert_module, @@ -45,12 +51,15 @@ from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, + parse_complex_tensor_structs, prepare_inputs, set_log_level, to_torch_device, to_torch_tensorrt_device, ) +import torch_tensorrt + logger = logging.getLogger(__name__) @@ -84,6 +93,7 @@ def compile( use_python_runtime: bool = USE_PYTHON_RUNTIME, use_fast_partitioner: bool = USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + dryrun: bool = DRYRUN, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -140,6 +150,7 @@ def compile( use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optiminal. Use the global paritioner (``False``) if looking for best performance enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT. + dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -215,6 +226,7 @@ def compile( "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, "dla_global_dram_size": dla_global_dram_size, + "dryrun": dryrun, } settings = CompilationSettings(**compilation_options) @@ -238,15 +250,34 @@ def compile_module( Returns: Compiled FX GraphModule """ + dryrun_tracker = DryRunTracker() # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( gm, settings.debug, settings.torch_executed_ops ) + dryrun_tracker.total_ops_in_graph = total_ops + dryrun_tracker.supported_ops_in_graph = num_supported_ops + dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs( + sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) + ) + dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( + sample_inputs, "torch_dtype" + ) + dryrun_tracker.compilation_settings = settings + + if settings.dryrun and settings.min_block_size > 1: + logger.info( + "It is recommended to run `dryrun` mode with `min_block_size=1`, " + "for the most thorough analysis" + ) + # If the number of supported operations is 0 or less than the block size, skip the subgraph # TODO: Add condition to second expression below when require_full_compilation is added - if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size): + if num_supported_ops == 0 or ( + num_supported_ops < settings.min_block_size and not settings.dryrun + ): logger.warning( f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. " f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" @@ -263,7 +294,7 @@ def compile_module( # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: - partitioned_module = partitioning.fast_partition( + partitioned_module, supported_ops = partitioning.fast_partition( gm, verbose=settings.debug, min_block_size=settings.min_block_size, @@ -280,13 +311,19 @@ def compile_module( settings.use_fast_partitioner = False if not settings.use_fast_partitioner: - partitioned_module = partitioning.global_partition( + partitioned_module, supported_ops = partitioning.global_partition( gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + + # The global partitioner leaves non-TRT nodes as-is + if not settings.use_fast_partitioner: + dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module)) + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -295,8 +332,19 @@ def compile_module( submodule = getattr(partitioned_module, name) # Criteria for a module to be convertible to TRT if settings.use_fast_partitioner and "_run_on_acc" not in name: + dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule)) continue + subgraph_data = PerSubgraphData() + subgraph_data.subgraph_name = name + subgraph_data.subgraph_op_count = len( + [ + node + for node in submodule.graph.nodes + if node.op in ("call_function", "call_method", "call_module") + ] + ) + # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.get_submod_inputs( partitioned_module, @@ -323,15 +371,55 @@ def compile_module( name, ) - # Create TRT engines from submodule - trt_module = convert_module( - submodule, + subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs( submodule_inputs, - settings=settings, - name=name, + "shape", + lambda x: dict(x) if isinstance(x, dict) else tuple(x), + ) + subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( + submodule_inputs, "torch_dtype" ) - trt_modules[name] = trt_module + submodule_outputs = submodule( + *get_torch_inputs(submodule_inputs, to_torch_device(settings.device)) + ) + + subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs( + submodule_outputs, + "shape", + lambda x: dict(x) if isinstance(x, dict) else tuple(x), + ) + subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs( + submodule_outputs, "dtype" + ) + + dryrun_tracker.tensorrt_graph_count += 1 + dryrun_tracker.per_subgraph_data.append(subgraph_data) + + # Create TRT engines from submodule + if not settings.dryrun: + trt_module = convert_module( + submodule, + submodule_inputs, + settings=settings, + name=name, + ) + + trt_modules[name] = trt_module + + sample_outputs = gm( + *get_torch_inputs(sample_inputs, to_torch_device(settings.device)) + ) + + if not isinstance(sample_outputs, (list, tuple)): + sample_outputs = [sample_outputs] + + dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs( + sample_outputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) + ) + dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs( + sample_outputs, "dtype" + ) # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): @@ -341,4 +429,6 @@ def compile_module( if fast_partitioner_failed: settings.use_fast_partitioner = True + dryrun_stats_display(dryrun_tracker, settings.dryrun) + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 4ec872fb1b..4afabe60eb 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -24,6 +24,7 @@ ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False REFIT = False REQUIRE_FULL_COMPILATION = False +DRYRUN = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index cd58c9547f..60990bda99 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Set +from typing import Optional, Set, Union import torch from tensorrt import EngineCapability @@ -10,6 +10,7 @@ DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, + DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENGINE_CAPABILITY, MAX_AUX_STREAMS, @@ -63,6 +64,9 @@ class CompilationSettings: dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution + dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to + TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the + ouptut to a file if a string path is specified """ precision: torch.dtype = PRECISION @@ -88,3 +92,4 @@ class CompilationSettings: dla_sram_size: int = DLA_SRAM_SIZE dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE + dryrun: Union[bool, str] = DRYRUN diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 5bdbb8919b..5ec5293474 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -248,7 +248,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, -) -> torch.fx.GraphModule: +) -> Tuple[torch.fx.GraphModule, OpSupportTester]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -259,7 +259,7 @@ def partition( torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Require that all computational operators be run in TRT Returns: - torch.fx.GraphModule + torch.fx.GraphModule, OpSupportTester """ # Ensure graph is clean prior to partitioning gm.graph.eliminate_dead_code() @@ -280,4 +280,4 @@ def partition( if verbose: supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs) - return partitioned_graph + return partitioned_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 092bdabfd0..4c8efb234e 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,5 +1,5 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Tuple import torch from torch.fx.graph_module import GraphModule @@ -203,7 +203,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Optional[Set[str]] = None, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, -) -> torch.fx.GraphModule: +) -> Tuple[torch.fx.GraphModule, TorchTensorRTOperatorSupport]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -214,7 +214,7 @@ def partition( torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage require_full_compilation: Whether to require that all operators be run in TRT Returns: - torch.fx.GraphModule + torch.fx.GraphModule, TorchTensorRTOperatorSupport """ supported_ops = TorchTensorRTOperatorSupport( torch_executed_ops=torch_executed_ops @@ -236,4 +236,4 @@ def partition( if verbose: supported_ops.print_support_overview(len(partitions)) - return fused_graph + return fused_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 26de1fcb27..22590fe73d 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,12 +5,12 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union import torch -import torch_tensorrt from torch_tensorrt._Device import Device from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import PRECISION from torch_tensorrt.dynamo._settings import CompilationSettings +import torch_tensorrt from packaging import version logger = logging.getLogger(__name__) @@ -114,7 +114,7 @@ def prepare_inputs( inputs, disable_memory_format_check=disable_memory_format_check ) - elif isinstance(inputs, list): + elif isinstance(inputs, (list, tuple)): torchtrt_input_list = [] for input_obj in inputs: torchtrt_input = prepare_inputs( @@ -122,24 +122,62 @@ def prepare_inputs( ) torchtrt_input_list.append(torchtrt_input) - return torchtrt_input_list + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) + ) - elif isinstance(inputs, tuple): - torchtrt_inputs_tup = [] - for input_obj in inputs: + elif isinstance(inputs, dict): + torchtrt_inputs_dict: Dict[Any, Any] = dict() + + for key, input_obj in inputs.items(): torchtrt_input = prepare_inputs( input_obj, disable_memory_format_check=disable_memory_format_check ) - torchtrt_inputs_tup.append(torchtrt_input) + torchtrt_inputs_dict[key] = torchtrt_input + + return torchtrt_inputs_dict - return tuple(torchtrt_inputs_tup) + else: + raise ValueError( + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" + ) + + +def parse_complex_tensor_structs( + inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], + attribute_to_extract: str, + apply_fn: Callable[[Any], Any] = lambda x: x, +) -> Any: + """Parses complex structures of Tensors and returns a mirrored structure + Extracts key attributes of each singular element, while reconstructing the struct + Optionally applies a function to each attribute before returning + """ + if isinstance(inputs, (torch.Tensor, Input)): + return apply_fn(getattr(inputs, attribute_to_extract, None)) + + elif isinstance(inputs, (list, tuple)): + torchtrt_input_list = [] + for input_obj in inputs: + torchtrt_input = parse_complex_tensor_structs( + input_obj, attribute_to_extract, apply_fn + ) + torchtrt_input_list.append(torchtrt_input) + + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) + ) elif isinstance(inputs, dict): torchtrt_inputs_dict: Dict[Any, Any] = dict() for key, input_obj in inputs.items(): - torchtrt_input = prepare_inputs( - input_obj, disable_memory_format_check=disable_memory_format_check + torchtrt_input = parse_complex_tensor_structs( + input_obj, attribute_to_extract, apply_fn ) torchtrt_inputs_dict[key] = torchtrt_input @@ -147,7 +185,7 @@ def prepare_inputs( else: raise ValueError( - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + f"Invalid input type {type(inputs)} encountered in parse_complex_tensor_structs parsing. " + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 0038412c30..a958d03120 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,10 +1,11 @@ from copy import deepcopy import torch -import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition +import torch_tensorrt + from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -19,7 +20,7 @@ def forward(self, x, y): return torch.mean(out, dim=1) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph, _ = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len( @@ -198,7 +199,7 @@ def forward(self, x, y): ) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph, _ = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len(list(partitioned_graph.named_children())), diff --git a/tests/py/dynamo/partitioning/test_fast_partitioning.py b/tests/py/dynamo/partitioning/test_fast_partitioning.py index a271df1e72..2ff5433b22 100644 --- a/tests/py/dynamo/partitioning/test_fast_partitioning.py +++ b/tests/py/dynamo/partitioning/test_fast_partitioning.py @@ -18,7 +18,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph)) + partitioned_graph, _ = partitioning.fast_partition(deepcopy(fx_graph)) self.assertEquals( len( [ @@ -40,7 +40,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), require_full_compilation=True ) self.assertEquals( @@ -68,7 +68,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( @@ -97,7 +97,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py index 6701b33ccf..51d93a40f2 100644 --- a/tests/py/dynamo/partitioning/test_global_partitioning.py +++ b/tests/py/dynamo/partitioning/test_global_partitioning.py @@ -18,7 +18,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.global_partition(deepcopy(fx_graph)) + partitioned_graph, _ = partitioning.global_partition(deepcopy(fx_graph)) self.assertEquals( len(list(partitioned_graph.named_children())), 0, @@ -34,7 +34,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), require_full_compilation=True ) self.assertEquals( @@ -56,7 +56,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( @@ -79,7 +79,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 9ec0fcf58e..742b9fc1a3 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -70,13 +70,13 @@ def compile_module_testing( ) -> torch.fx.GraphModule: """Helper compiler exclusively for testing""" if use_fast_partitioner: - partitioned_module = partitioning.fast_partition( + partitioned_module, _ = partitioning.fast_partition( gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, ) else: - partitioned_module = partitioning.global_partition( + partitioned_module, _ = partitioning.global_partition( gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops,