From 635e26a88f5dbbcc7edba2673f78e3f96d5c6137 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 9 Nov 2023 10:57:08 -0800 Subject: [PATCH 1/4] feat: Add dryrun feature to Dynamo paths - Enables building of TRT engines with "dryrun" capabilities, meaning all of the phases except conversion are run and verbose logs of the graph structure and composition are printed for the user - Improves general-purpose debug logging by printing dryrun stats to the debug logs regardless of option specification - Provides intuitive schematic of the graph engines, inputs, and code path through the course of the graph --- py/torch_tensorrt/dynamo/_DryRunTracker.py | 205 +++++++++++++++++++++ py/torch_tensorrt/dynamo/_compiler.py | 93 +++++++++- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 4 + 4 files changed, 294 insertions(+), 9 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/_DryRunTracker.py diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py new file mode 100644 index 0000000000..625db8753d --- /dev/null +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -0,0 +1,205 @@ +import logging +import math +from dataclasses import dataclass, field +from typing import List, Tuple + +import torch + +logger = logging.getLogger(__name__) + + +@dataclass +class PerSubgraphData: + """Class to track data on a per-subgraph level + + Args: + subgraph_name (str): Name of the subgraph in the GraphModule + subgraph_op_count (int): Number of operations in the subgraph + subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph + subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph + subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph + subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph + """ + + subgraph_name: str = "" + subgraph_op_count: int = 0 + subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list) + subgraph_input_dtypes: List[torch.device] = field(default_factory=list) + subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list) + subgraph_output_dtypes: List[torch.device] = field(default_factory=list) + + +@dataclass +class DryRunTracker: + """Class to track data on a graph-wide level + + Args: + total_ops_in_graph (int): Total number of operators in graph + supported_ops_in_graph (int): Number of supported operators in graph + graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph + graph_input_dtypes (List[torch.device]): Input data types of the graph + graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph + graph_output_dtypes (List[torch.device]): Output data types of the graph + per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class + tensorrt_graph_count (int): Number of TensorRT engines to be generated + truncated_long_and_double (bool): Whether truncate_long_and_double was enabled + """ + + total_ops_in_graph: int = 0 + supported_ops_in_graph: int = 0 + graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list) + graph_input_dtypes: List[torch.device] = field(default_factory=list) + graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list) + graph_output_dtypes: List[torch.device] = field(default_factory=list) + per_subgraph_data: List[PerSubgraphData] = field(default_factory=list) + tensorrt_graph_count: int = 0 
+ truncated_long_and_double: bool = False + + +def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None: + """Displays statistics about the dryrun either to debug logs or info logs""" + # If user specified "dryrun=True", print to info logs, else debug + if dryrun_enabled: + dryrun_logger = logger.info + else: + dryrun_logger = logger.debug + + formatted_stats = "\n" + + # Print overall stats about the graph, operator counts, etc. + formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n" + formatted_stats += ( + f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, " + f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " + f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n" + ) + formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n" + formatted_stats += ( + f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n" + ) + + assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count + + # Print schematic of the graph structure, as in: + # + # Inputs: [Tensor: (1, 3, 224, 224)@float32] + # ... + # TRT Engine #1: _run_on_acc_0 + # Engine Inputs: [Tensor: (1, 3, 224, 224)@float32] + # Number of Operators in Engine: 1 + # Engine Outputs: [Tensor: (1, 64, 112, 112)@float32] + # ... + # Outputs: [Tensor: (1, 1000)@float32] + # + formatted_stats += " " * 2 + "Graph Structure:\n\n" + formatted_stats += ( + " " * 3 + + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n" + ) + + for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data): + assert len(trt_subgraph_data.subgraph_input_dtypes) == len( + trt_subgraph_data.subgraph_input_shapes + ) + assert len(trt_subgraph_data.subgraph_output_dtypes) == len( + trt_subgraph_data.subgraph_output_shapes + ) + formatted_stats += " " * 4 + "...\n" + formatted_stats += ( + " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n" + ) + formatted_stats += ( + " " * 5 + + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n" + ) + formatted_stats += ( + " " * 5 + + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n" + ) + formatted_stats += ( + " " * 5 + + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n" + ) + + formatted_stats += " " * 4 + "...\n" + formatted_stats += ( + " " * 3 + + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n" + ) + + # Print aggregate statistics about the graph structure, including recommended "min_block_size" options + if dryrun_tracker.tensorrt_graph_count > 0: + min_ops_in_an_engine = min( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + avg_ops_per_engine = ( + sum( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + / dryrun_tracker.tensorrt_graph_count + ) + avg_ops_per_engine = round(avg_ops_per_engine, 2) + most_ops_in_an_engine = max( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + + formatted_stats += "\n" + " " * 2 + "-" * 25 + " Aggregate Stats " + "-" * 25 + formatted_stats += ( + 
"\n\n" + + " " * 3 + + "Average Number of Operators per TRT Engine: " + + f"{avg_ops_per_engine}" + ) + + formatted_stats += ( + "\n" + + " " * 3 + + "Most Operators in a TRT Engine: " + + f"{most_ops_in_an_engine}" + ) + + formatted_stats += "\n\n" + " " * 2 + "*" * 10 + " Recommendations " + "*" * 10 + formatted_stats += ( + "\n\n" + + " " * 3 + + "- For minimal graph segmentation, select min_block_size=" + + f"{most_ops_in_an_engine} which would generate " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines" + ) + if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine: + formatted_stats += ( + "\n" + + " " * 3 + + "- For moderate graph segmentation, select min_block_size=" + + f"{math.ceil(avg_ops_per_engine)} which would generate " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines" + ) + + formatted_stats += ( + "\n" + + " " * 3 + + "- The current level of graph segmentation is equivalent to selecting min_block_size=" + + f"{min_ops_in_an_engine} which generates " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines" + ) + else: + formatted_stats += ( + "\n" + + " " * 2 + + "Aggregate stats not available since no TRT Engines were generated." + ) + + dryrun_logger(formatted_stats) + + +def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str: + """Format shapes and dtypes of input Tensors into a readable string""" + formatted_str = ", " + + for shape, dtype in zip(shapes, dtypes): + formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, " + + return formatted_str[2:-2] diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 796e0690f3..78942086a0 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,7 +5,6 @@ from typing import Any, List, Optional, Sequence, Set, Tuple, Union import torch -import torch_tensorrt from torch.export import ExportedProgram from torch_tensorrt._Device import Device from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum @@ -20,6 +19,7 @@ DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, + DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENGINE_CAPABILITY, MAX_AUX_STREAMS, @@ -37,6 +37,11 @@ VERSION_COMPATIBLE, WORKSPACE_SIZE, ) +from torch_tensorrt.dynamo._DryRunTracker import ( + DryRunTracker, + PerSubgraphData, + dryrun_stats_display, +) from torch_tensorrt.dynamo.conversion import ( CompilationSettings, convert_module, @@ -51,6 +56,8 @@ to_torch_tensorrt_device, ) + +import torch_tensorrt + logger = logging.getLogger(__name__) @@ -84,6 +91,7 @@ def compile( use_python_runtime: bool = USE_PYTHON_RUNTIME, use_fast_partitioner: bool = USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + dryrun: bool = DRYRUN, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -140,6 +148,7 @@ def compile( use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. 
Use the global partitioner (``False``) if looking for best performance enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the number of graphs run in TensorRT. + dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -215,6 +224,7 @@ def compile( "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, "dla_global_dram_size": dla_global_dram_size, + "dryrun": dryrun, } settings = CompilationSettings(**compilation_options) @@ -238,15 +248,32 @@ def compile_module( Returns: Compiled FX GraphModule """ + dryrun_tracker = DryRunTracker() # Check the number of supported operations in the graph num_supported_ops, total_ops = partitioning.get_graph_converter_support( gm, settings.debug, settings.torch_executed_ops ) + dryrun_tracker.total_ops_in_graph = total_ops + dryrun_tracker.supported_ops_in_graph = num_supported_ops + dryrun_tracker.graph_input_shapes = [ + tuple(input_.shape) for input_ in sample_inputs + ] + dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs] + dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double + + if settings.dryrun and settings.min_block_size > 1: + logger.info( + "It is recommended to run `dryrun` mode with `min_block_size=1`, " + "for the most thorough analysis" + ) + # If the number of supported operations is 0 or less than the block size, skip the subgraph # TODO: Add condition to second expression below when require_full_compilation is added - if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size): + if num_supported_ops == 0 or ( + num_supported_ops < settings.min_block_size and not settings.dryrun + ): logger.warning( f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. 
" f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" @@ -297,6 +324,16 @@ def compile_module( if settings.use_fast_partitioner and "_run_on_acc" not in name: continue + subgraph_data = PerSubgraphData() + subgraph_data.subgraph_name = name + subgraph_data.subgraph_op_count = len( + [ + node + for node in submodule.graph.nodes + if node.op in ("call_function", "call_method", "call_module") + ] + ) + # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.get_submod_inputs( partitioned_module, @@ -323,15 +360,51 @@ def compile_module( name, ) - # Create TRT engines from submodule - trt_module = convert_module( - submodule, - submodule_inputs, - settings=settings, - name=name, + subgraph_data.subgraph_input_dtypes = [ + submodule_input.torch_dtype for submodule_input in submodule_inputs + ] + subgraph_data.subgraph_input_shapes = [ + tuple(submodule_input.shape) for submodule_input in submodule_inputs + ] + + submodule_outputs = submodule( + *get_torch_inputs(submodule_inputs, to_torch_device(settings.device)) ) + if not isinstance(submodule_outputs, (list, tuple)): + submodule_outputs = [submodule_outputs] - trt_modules[name] = trt_module + subgraph_data.subgraph_output_dtypes = [ + submodule_output.dtype for submodule_output in submodule_outputs + ] + subgraph_data.subgraph_output_shapes = [ + tuple(submodule_output.shape) for submodule_output in submodule_outputs + ] + + dryrun_tracker.tensorrt_graph_count += 1 + dryrun_tracker.per_subgraph_data.append(subgraph_data) + + # Create TRT engines from submodule + if not settings.dryrun: + trt_module = convert_module( + submodule, + submodule_inputs, + settings=settings, + name=name, + ) + + trt_modules[name] = trt_module + + sample_outputs = gm( + *get_torch_inputs(sample_inputs, to_torch_device(settings.device)) + ) + + if not isinstance(sample_outputs, (list, tuple)): + sample_outputs = [sample_outputs] + + dryrun_tracker.graph_output_shapes = [ + tuple(output_.shape) for output_ in sample_outputs + ] + dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs] # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): @@ -341,4 +414,6 @@ def compile_module( if fast_partitioner_failed: settings.use_fast_partitioner = True + dryrun_stats_display(dryrun_tracker, settings.dryrun) + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 4ec872fb1b..4afabe60eb 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -24,6 +24,7 @@ ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False REFIT = False REQUIRE_FULL_COMPILATION = False +DRYRUN = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index cd58c9547f..00c3d95a0e 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -10,6 +10,7 @@ DLA_GLOBAL_DRAM_SIZE, DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, + DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENGINE_CAPABILITY, MAX_AUX_STREAMS, @@ -63,6 +64,8 @@ class CompilationSettings: dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. 
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution + dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to + TRT Engines. Prints detailed logs of the graph structure and nature of partitioning """ precision: torch.dtype = PRECISION @@ -88,3 +91,4 @@ class CompilationSettings: dla_sram_size: int = DLA_SRAM_SIZE dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE + dryrun: bool = DRYRUN From d1215453eee67d470854fabb0b557cde3fc2bdfc Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 9 Nov 2023 21:37:10 -0800 Subject: [PATCH 2/4] fix: Address PR comments and update logging scheme - Fix test case failures --- py/torch_tensorrt/dynamo/_DryRunTracker.py | 140 ++++++++++-------- py/torch_tensorrt/dynamo/_compiler.py | 55 +++---- .../partitioning/_adjacency_partitioner.py | 6 +- .../partitioning/_global_partitioner.py | 8 +- py/torch_tensorrt/dynamo/utils.py | 60 ++++++-- .../dynamo/backend/test_backend_compiler.py | 7 +- .../partitioning/test_fast_partitioning.py | 8 +- .../partitioning/test_global_partitioning.py | 8 +- tests/py/dynamo/testing_utilities.py | 4 +- 9 files changed, 181 insertions(+), 115 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py index 625db8753d..85cc23165a 100644 --- a/py/torch_tensorrt/dynamo/_DryRunTracker.py +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -1,9 +1,9 @@ import logging import math from dataclasses import dataclass, field -from typing import List, Tuple +from typing import Any, Dict, List -import torch +from torch_tensorrt.dynamo._settings import CompilationSettings logger = logging.getLogger(__name__) @@ -15,18 +15,18 @@ class PerSubgraphData: Args: subgraph_name (str): Name of the subgraph in the GraphModule subgraph_op_count (int): Number of operations in the subgraph - subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph - subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph - subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph - subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph + subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph + subgraph_input_dtypes (Any): Input data types of the subgraph + subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph + subgraph_output_dtypes (Any): Output data types of the subgraph """ subgraph_name: str = "" subgraph_op_count: int = 0 - subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list) - subgraph_input_dtypes: List[torch.device] = field(default_factory=list) - subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list) - subgraph_output_dtypes: List[torch.device] = field(default_factory=list) + subgraph_input_shapes: Any = field(default_factory=list) + subgraph_input_dtypes: Any = field(default_factory=list) + subgraph_output_shapes: Any = field(default_factory=list) + subgraph_output_dtypes: Any = field(default_factory=list) @dataclass @@ -36,81 +36,72 @@ class DryRunTracker: Args: total_ops_in_graph (int): Total number of operators in graph supported_ops_in_graph (int): Number of supported operators in graph - graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of 
the graph - graph_input_dtypes (List[torch.device]): Input data types of the graph - graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph - graph_output_dtypes (List[torch.device]): Output data types of the graph + graph_input_shapes (Any): Shapes of input Tensors of the graph + graph_input_dtypes (Any): Input data types of the graph + graph_output_shapes (Any): Shapes of output Tensors of the graph + graph_output_dtypes (Any): Output data types of the graph per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class tensorrt_graph_count (int): Number of TensorRT engines to be generated - truncated_long_and_double (bool): Whether truncate_long_and_double was enabled + compilation_settings (CompilationSettings): User Compilation Settings + unsupported_ops (Dict[str, int]): Set of operators not supported in TRT """ total_ops_in_graph: int = 0 supported_ops_in_graph: int = 0 - graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list) - graph_input_dtypes: List[torch.device] = field(default_factory=list) - graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list) - graph_output_dtypes: List[torch.device] = field(default_factory=list) + graph_input_shapes: Any = field(default_factory=list) + graph_input_dtypes: Any = field(default_factory=list) + graph_output_shapes: Any = field(default_factory=list) + graph_output_dtypes: Any = field(default_factory=list) per_subgraph_data: List[PerSubgraphData] = field(default_factory=list) tensorrt_graph_count: int = 0 - truncated_long_and_double: bool = False + compilation_settings: CompilationSettings = field( + default_factory=CompilationSettings + ) + unsupported_ops: Dict[str, int] = field(default_factory=dict) def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None: - """Displays statistics about the dryrun either to debug logs or info logs""" - # If user specified "dryrun=True", print to info logs, else debug - if dryrun_enabled: - dryrun_logger = logger.info - else: - dryrun_logger = logger.debug - + """Displays statistics about the dryrun either to debug logs or stdout""" formatted_stats = "\n" # Print overall stats about the graph, operator counts, etc. - formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n" + formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n" formatted_stats += ( f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, " f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " - f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n" - ) - formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n" - formatted_stats += ( - f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n" + f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n" ) + formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n" + formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n" assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count # Print schematic of the graph structure, as in: # - # Inputs: [Tensor: (1, 3, 224, 224)@float32] + # Inputs: List[Tensor: (1, 3, 224, 224)@float32] # ... 
- # TRT Engine #1: _run_on_acc_0 - # Engine Inputs: [Tensor: (1, 3, 224, 224)@float32] - # Number of Operators in Engine: 1 - # Engine Outputs: [Tensor: (1, 64, 112, 112)@float32] + # TRT Engine #1 - Submodule name: _run_on_acc_0 + # Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32] + # Number of Operators in Engine: 1 + # Engine Outputs: Tensor: (1, 64, 112, 112)@float32 # ... - # Outputs: [Tensor: (1, 1000)@float32] + # Outputs: List[Tensor: (1, 1000)@float32] # formatted_stats += " " * 2 + "Graph Structure:\n\n" formatted_stats += ( " " * 3 - + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n" + + f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n" ) for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data): - assert len(trt_subgraph_data.subgraph_input_dtypes) == len( - trt_subgraph_data.subgraph_input_shapes - ) - assert len(trt_subgraph_data.subgraph_output_dtypes) == len( - trt_subgraph_data.subgraph_output_shapes - ) formatted_stats += " " * 4 + "...\n" formatted_stats += ( - " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n" + " " * 4 + + f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n" ) formatted_stats += ( " " * 5 - + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n" + + f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n" ) formatted_stats += ( " " * 5 @@ -118,13 +109,13 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> ) formatted_stats += ( " " * 5 - + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n" + + f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n" ) formatted_stats += " " * 4 + "...\n" formatted_stats += ( " " * 3 - + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n" + + f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n" ) # Print aggregate statistics about the graph structure, including recommended "min_block_size" options @@ -167,7 +158,7 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + " " * 3 + "- For minimal graph segmentation, select min_block_size=" + f"{most_ops_in_an_engine} which would generate " - + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines" + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)" ) if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine: formatted_stats += ( @@ -175,7 +166,7 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + " " * 3 + "- For moderate graph segmentation, select min_block_size=" + f"{math.ceil(avg_ops_per_engine)} which would generate " - + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines" + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)" ) formatted_stats += ( @@ -183,7 +174,7 @@ def 
dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + " " * 3 + "- The current level of graph segmentation is equivalent to selecting min_block_size=" + f"{min_ops_in_an_engine} which generates " - + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines" + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)" ) else: formatted_stats += ( @@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> + "Aggregate stats not available since no TRT Engines were generated." ) - dryrun_logger(formatted_stats) + # If user specified "dryrun=True", print to stdout, else debug + if dryrun_enabled: + print(formatted_stats) + else: + logger.debug(formatted_stats) -def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str: +def input_formatter(shapes: Any, dtypes: Any) -> str: """Format shapes and dtypes of input Tensors into a readable string""" - formatted_str = ", " - for shape, dtype in zip(shapes, dtypes): - formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, " + def input_formatter_helper(shapes: Any, dtypes: Any) -> str: + """Helper for input formatter""" + # Base case - single shape, single dtype + if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes): + return f"Tensor: {shapes}@{str(dtypes)[6:]}, " + + # Shapes is a sequence + elif isinstance(shapes, (list, tuple)): + formatted_str = "List[" if isinstance(shapes, list) else "Tuple(" + for shape, dtype in zip(shapes, dtypes): + formatted_str += input_formatter_helper(shape, dtype) + formatted_str = formatted_str[:-2] + ( + "], " if isinstance(shapes, list) else "), " + ) + return formatted_str + + # Shapes is a dictionary + elif isinstance(shapes, dict): + formatted_str = "Dict{" + + for key, shape in shapes.items(): + formatted_str += input_formatter_helper(shape, dtypes[key]) + + formatted_str = formatted_str[:-2] + "}, " + return formatted_str + + else: + raise ValueError( + f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing." 
+ ) - return formatted_str[2:-2] + return input_formatter_helper(shapes, dtypes)[:-2] diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 78942086a0..d382f26db6 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -50,6 +50,7 @@ from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions from torch_tensorrt.dynamo.utils import ( get_torch_inputs, + parse_complex_tensor_structs, prepare_inputs, set_log_level, to_torch_device, @@ -257,11 +258,13 @@ def compile_module( dryrun_tracker.total_ops_in_graph = total_ops dryrun_tracker.supported_ops_in_graph = num_supported_ops - dryrun_tracker.graph_input_shapes = [ - tuple(input_.shape) for input_ in sample_inputs - ] - dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs] - dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double + dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs( + sample_inputs, "shape", tuple + ) + dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( + sample_inputs, "torch_dtype" + ) + dryrun_tracker.compilation_settings = settings if settings.dryrun and settings.min_block_size > 1: logger.info( @@ -290,7 +293,7 @@ def compile_module( # If specified, try using the fast partitioner and fall back to the global one on failure if settings.use_fast_partitioner: try: - partitioned_module = partitioning.fast_partition( + partitioned_module, supported_ops = partitioning.fast_partition( gm, verbose=settings.debug, min_block_size=settings.min_block_size, @@ -307,13 +310,15 @@ def compile_module( settings.use_fast_partitioner = False if not settings.use_fast_partitioner: - partitioned_module = partitioning.global_partition( + partitioned_module, supported_ops = partitioning.global_partition( gm, verbose=settings.debug, min_block_size=settings.min_block_size, torch_executed_ops=settings.torch_executed_ops, ) + dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -360,25 +365,23 @@ def compile_module( name, ) - subgraph_data.subgraph_input_dtypes = [ - submodule_input.torch_dtype for submodule_input in submodule_inputs - ] - subgraph_data.subgraph_input_shapes = [ - tuple(submodule_input.shape) for submodule_input in submodule_inputs - ] + subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs( + submodule_inputs, "shape", tuple + ) + subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( + submodule_inputs, "torch_dtype" + ) submodule_outputs = submodule( *get_torch_inputs(submodule_inputs, to_torch_device(settings.device)) ) - if not isinstance(submodule_outputs, (list, tuple)): - submodule_outputs = [submodule_outputs] - subgraph_data.subgraph_output_dtypes = [ - submodule_output.dtype for submodule_output in submodule_outputs - ] - subgraph_data.subgraph_output_shapes = [ - tuple(submodule_output.shape) for submodule_output in submodule_outputs - ] + subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs( + submodule_outputs, "shape", tuple + ) + subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs( + submodule_outputs, "dtype" + ) dryrun_tracker.tensorrt_graph_count += 1 dryrun_tracker.per_subgraph_data.append(subgraph_data) @@ -401,10 +404,12 @@ def compile_module( if not isinstance(sample_outputs, (list, tuple)): sample_outputs = [sample_outputs] - 
dryrun_tracker.graph_output_shapes = [ - tuple(output_.shape) for output_ in sample_outputs - ] - dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs] + dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs( + sample_outputs, "shape", tuple + ) + dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs( + sample_outputs, "dtype" + ) # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 5bdbb8919b..5ec5293474 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -248,7 +248,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Collection[Target] = set(), require_full_compilation: bool = REQUIRE_FULL_COMPILATION, -) -> torch.fx.GraphModule: +) -> Tuple[torch.fx.GraphModule, OpSupportTester]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -259,7 +259,7 @@ def partition( torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage require_full_compilation: Require that all computational operators be run in TRT Returns: - torch.fx.GraphModule + torch.fx.GraphModule, OpSupportTester """ # Ensure graph is clean prior to partitioning gm.graph.eliminate_dead_code() @@ -280,4 +280,4 @@ def partition( if verbose: supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs) - return partitioned_graph + return partitioned_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index 092bdabfd0..4c8efb234e 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -1,5 +1,5 @@ import logging -from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set +from typing import Collection, Dict, List, Mapping, Optional, Sequence, Set, Tuple import torch from torch.fx.graph_module import GraphModule @@ -203,7 +203,7 @@ def partition( min_block_size: int = MIN_BLOCK_SIZE, torch_executed_ops: Optional[Set[str]] = None, require_full_compilation: bool = REQUIRE_FULL_COMPILATION, -) -> torch.fx.GraphModule: +) -> Tuple[torch.fx.GraphModule, TorchTensorRTOperatorSupport]: """Partition an FX GraphModule with aten ops into TRT engines Partitioning is based on converter operator support @@ -214,7 +214,7 @@ def partition( torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage require_full_compilation: Whether to require that all operators be run in TRT Returns: - torch.fx.GraphModule + torch.fx.GraphModule, TorchTensorRTOperatorSupport """ supported_ops = TorchTensorRTOperatorSupport( torch_executed_ops=torch_executed_ops @@ -236,4 +236,4 @@ def partition( if verbose: supported_ops.print_support_overview(len(partitions)) - return fused_graph + return fused_graph, supported_ops diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 26de1fcb27..22590fe73d 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -5,12 +5,12 @@ from typing import Any, Callable, Dict, Optional, Sequence, Union import torch -import torch_tensorrt from torch_tensorrt._Device 
import Device from torch_tensorrt._Input import Input from torch_tensorrt.dynamo._defaults import PRECISION from torch_tensorrt.dynamo._settings import CompilationSettings +import torch_tensorrt from packaging import version logger = logging.getLogger(__name__) @@ -114,7 +114,7 @@ def prepare_inputs( inputs, disable_memory_format_check=disable_memory_format_check ) - elif isinstance(inputs, list): + elif isinstance(inputs, (list, tuple)): torchtrt_input_list = [] for input_obj in inputs: torchtrt_input = prepare_inputs( @@ -122,24 +122,62 @@ def prepare_inputs( ) torchtrt_input_list.append(torchtrt_input) - return torchtrt_input_list + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) + ) - elif isinstance(inputs, tuple): - torchtrt_inputs_tup = [] - for input_obj in inputs: + elif isinstance(inputs, dict): + torchtrt_inputs_dict: Dict[Any, Any] = dict() + + for key, input_obj in inputs.items(): torchtrt_input = prepare_inputs( input_obj, disable_memory_format_check=disable_memory_format_check ) - torchtrt_inputs_tup.append(torchtrt_input) + torchtrt_inputs_dict[key] = torchtrt_input + + return torchtrt_inputs_dict - return tuple(torchtrt_inputs_tup) + else: + raise ValueError( + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" + ) + + +def parse_complex_tensor_structs( + inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], + attribute_to_extract: str, + apply_fn: Callable[[Any], Any] = lambda x: x, +) -> Any: + """Parses complex structures of Tensors and returns a mirrored structure + Extracts key attributes of each singular element, while reconstructing the struct + Optionally applies a function to each attribute before returning + """ + if isinstance(inputs, (torch.Tensor, Input)): + return apply_fn(getattr(inputs, attribute_to_extract, None)) + + elif isinstance(inputs, (list, tuple)): + torchtrt_input_list = [] + for input_obj in inputs: + torchtrt_input = parse_complex_tensor_structs( + input_obj, attribute_to_extract, apply_fn + ) + torchtrt_input_list.append(torchtrt_input) + + return ( + torchtrt_input_list + if isinstance(inputs, list) + else tuple(torchtrt_input_list) + ) elif isinstance(inputs, dict): torchtrt_inputs_dict: Dict[Any, Any] = dict() for key, input_obj in inputs.items(): - torchtrt_input = prepare_inputs( - input_obj, disable_memory_format_check=disable_memory_format_check + torchtrt_input = parse_complex_tensor_structs( + input_obj, attribute_to_extract, apply_fn ) torchtrt_inputs_dict[key] = torchtrt_input @@ -147,7 +185,7 @@ def prepare_inputs( else: raise ValueError( - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " + f"Invalid input type {type(inputs)} encountered in parse_complex_tensor_structs parsing. 
" + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 0038412c30..a958d03120 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -1,10 +1,11 @@ from copy import deepcopy import torch -import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests from torch_tensorrt.dynamo.partitioning import fast_partition +import torch_tensorrt + from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing @@ -19,7 +20,7 @@ def forward(self, x, y): return torch.mean(out, dim=1) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph, _ = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len( @@ -198,7 +199,7 @@ def forward(self, x, y): ) fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = fast_partition(deepcopy(fx_graph), min_block_size=3) + partitioned_graph, _ = fast_partition(deepcopy(fx_graph), min_block_size=3) self.assertEquals( len(list(partitioned_graph.named_children())), diff --git a/tests/py/dynamo/partitioning/test_fast_partitioning.py b/tests/py/dynamo/partitioning/test_fast_partitioning.py index a271df1e72..2ff5433b22 100644 --- a/tests/py/dynamo/partitioning/test_fast_partitioning.py +++ b/tests/py/dynamo/partitioning/test_fast_partitioning.py @@ -18,7 +18,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.fast_partition(deepcopy(fx_graph)) + partitioned_graph, _ = partitioning.fast_partition(deepcopy(fx_graph)) self.assertEquals( len( [ @@ -40,7 +40,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), require_full_compilation=True ) self.assertEquals( @@ -68,7 +68,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( @@ -97,7 +97,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partitioning.fast_partition( + partitioned_graph, _ = partitioning.fast_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( diff --git a/tests/py/dynamo/partitioning/test_global_partitioning.py b/tests/py/dynamo/partitioning/test_global_partitioning.py index 6701b33ccf..51d93a40f2 100644 --- a/tests/py/dynamo/partitioning/test_global_partitioning.py +++ b/tests/py/dynamo/partitioning/test_global_partitioning.py @@ -18,7 +18,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - partitioned_graph = partitioning.global_partition(deepcopy(fx_graph)) + partitioned_graph, _ = partitioning.global_partition(deepcopy(fx_graph)) self.assertEquals( len(list(partitioned_graph.named_children())), 0, @@ -34,7 +34,7 @@ def forward(self, x, y): return torch.ops.aten.add.Tensor(x, y) fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp()) - 
partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), require_full_compilation=True ) self.assertEquals( @@ -56,7 +56,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( @@ -79,7 +79,7 @@ def forward(self, x, y): return pow_ fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) - partitioned_graph = partitioning.global_partition( + partitioned_graph, _ = partitioning.global_partition( deepcopy(fx_graph), min_block_size=2 ) self.assertEquals( diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 9ec0fcf58e..742b9fc1a3 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -70,13 +70,13 @@ def compile_module_testing( ) -> torch.fx.GraphModule: """Helper compiler exclusively for testing""" if use_fast_partitioner: - partitioned_module = partitioning.fast_partition( + partitioned_module, _ = partitioning.fast_partition( gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, ) else: - partitioned_module = partitioning.global_partition( + partitioned_module, _ = partitioning.global_partition( gm, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, From 212019b1122d2c8f9ecaa37988443de83018aedf Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Tue, 14 Nov 2023 21:19:42 -0800 Subject: [PATCH 3/4] feat: Add optional filepath to save - Add detailed layer information for excluded ops --- py/torch_tensorrt/dynamo/_DryRunTracker.py | 56 ++++++++++++++++++++-- py/torch_tensorrt/dynamo/_compiler.py | 6 +++ py/torch_tensorrt/dynamo/_settings.py | 9 ++-- 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py index 85cc23165a..031fce2e73 100644 --- a/py/torch_tensorrt/dynamo/_DryRunTracker.py +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -1,9 +1,14 @@ import logging import math +import operator +import os from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +import torch from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry +from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name logger = logging.getLogger(__name__) @@ -44,6 +49,7 @@ class DryRunTracker: tensorrt_graph_count (int): Number of TensorRT engines to be generated compilation_settings (CompilationSettings): User Compilation Settings unsupported_ops (Dict[str, int]): Set of operators not supported in TRT + to_run_in_torch (List[str]): Set of nodes to run in Torch """ total_ops_in_graph: int = 0 @@ -58,9 +64,12 @@ class DryRunTracker: default_factory=CompilationSettings ) unsupported_ops: Dict[str, int] = field(default_factory=dict) + to_run_in_torch: List[str] = field(default_factory=list) -def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None: +def dryrun_stats_display( + dryrun_tracker: DryRunTracker, dryrun_enabled: Union[bool, str] +) -> None: """Displays statistics about the dryrun either to debug logs or stdout""" formatted_stats = "\n" @@ -71,7 +80,19 @@ def 
dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n" ) - formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n" + if dryrun_tracker.unsupported_ops: + parsed_ops = "\n".join( + [f"{str(k)}: {str(v)}" for k, v in dryrun_tracker.unsupported_ops.items()] + ) + formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n {parsed_ops}\n\n" + + if dryrun_tracker.to_run_in_torch: + formatted_nodes = "\n".join(dryrun_tracker.to_run_in_torch) + formatted_stats += ( + f"The following nodes are currently set to run in Torch:\n{formatted_nodes}\n" + "Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\n" + ) + formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n" assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count @@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> ) # If user specified "dryrun=True", print to stdout, else debug + # If user specified a filepath, save the output to the path as well if dryrun_enabled: print(formatted_stats) + if isinstance(dryrun_enabled, str): + if os.path.exists(dryrun_enabled): + logger.warning( + f"File already exists at path {dryrun_enabled}, not saving dryrun output" + ) + else: + with open(dryrun_enabled, "w+") as f: + f.write(formatted_stats) else: logger.debug(formatted_stats) @@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str: ) return input_formatter_helper(shapes, dtypes)[:-2] + + +def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]: + """Parses call_function and call_method nodes from a GraphModule + Excludes getitem nodes + + Returns a string representation of the nodes + """ + to_run_in_torch = [] + for node in graph_module.graph.nodes: + # getitem nodes are excluded since they are a Tensor-collection op + if ( + node.op in ("call_function", "call_method") + and node.target != operator.getitem + ): + to_run_in_torch.append( + f"Node: {ConverterRegistry.qualified_name_or_str(node.target)}, " + f"with layer location: {get_node_name(node)}" + ) + return to_run_in_torch diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d382f26db6..23e32e2b65 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -41,6 +41,7 @@ DryRunTracker, PerSubgraphData, dryrun_stats_display, + parse_non_trt_nodes, ) from torch_tensorrt.dynamo.conversion import ( CompilationSettings, @@ -319,6 +320,10 @@ def compile_module( dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + # The global partitioner leaves non-TRT nodes as-is + if not settings.use_fast_partitioner: + dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module)) + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -327,6 +332,7 @@ def compile_module( submodule = getattr(partitioned_module, name) # Criteria for a module to be convertible to TRT if settings.use_fast_partitioner and "_run_on_acc" not in name: + 
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule)) continue subgraph_data = PerSubgraphData() diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 00c3d95a0e..60990bda99 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Set +from typing import Optional, Set, Union import torch from tensorrt import EngineCapability @@ -64,8 +64,9 @@ class CompilationSettings: dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution - dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to - TRT Engines. Prints detailed logs of the graph structure and nature of partitioning + dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to + TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the + output to a file if a string path is specified """ precision: torch.dtype = PRECISION @@ -91,4 +92,4 @@ class CompilationSettings: dla_sram_size: int = DLA_SRAM_SIZE dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE - dryrun: bool = DRYRUN + dryrun: Union[bool, str] = DRYRUN From 3883062f8a2d710ca1f5c179e5a340344929a1c3 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 7 Dec 2023 16:16:32 -0800 Subject: [PATCH 4/4] fix: Add support for Dynamic Shapes --- py/torch_tensorrt/dynamo/_DryRunTracker.py | 15 +++++++++++++++ py/torch_tensorrt/dynamo/_compiler.py | 12 ++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py index 031fce2e73..46d99ffe31 100644 --- a/py/torch_tensorrt/dynamo/_DryRunTracker.py +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -229,6 +229,21 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str: if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes): return f"Tensor: {shapes}@{str(dtypes)[6:]}, " + # Base case - dynamic shape, single dtype + elif ( + isinstance(shapes, dict) + and len(shapes) == 3 + and all( + ( + isinstance(shape, tuple) + and all(isinstance(elt, int) for elt in shape) + and k in ("min_shape", "opt_shape", "max_shape") + ) + for k, shape in shapes.items() + ) + ): + return f"Tensor: {shapes}@{str(dtypes)[6:]}, " + # Shapes is a sequence elif isinstance(shapes, (list, tuple)): formatted_str = "List[" if isinstance(shapes, list) else "Tuple(" diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 23e32e2b65..ac7a323545 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -260,7 +260,7 @@ def compile_module( dryrun_tracker.total_ops_in_graph = total_ops dryrun_tracker.supported_ops_in_graph = num_supported_ops dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs( - sample_inputs, "shape", tuple + sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) ) dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs( sample_inputs, "torch_dtype" ) dryrun_tracker.compilation_settings = settings @@ -372,7 +372,9 @@ def
compile_module( ) subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs( - submodule_inputs, "shape", tuple + submodule_inputs, + "shape", + lambda x: dict(x) if isinstance(x, dict) else tuple(x), ) subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs( submodule_inputs, "torch_dtype" @@ -383,7 +385,9 @@ def compile_module( ) subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs( - submodule_outputs, "shape", tuple + submodule_outputs, + "shape", + lambda x: dict(x) if isinstance(x, dict) else tuple(x), ) subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs( submodule_outputs, "dtype" @@ -411,7 +415,7 @@ def compile_module( sample_outputs = [sample_outputs] dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs( - sample_outputs, "shape", tuple + sample_outputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x) ) dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs( sample_outputs, "dtype"
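
Reviewer note (not part of the patch series): a minimal usage sketch of the dryrun feature added above. The toy model, input shapes, and file path are illustrative assumptions; only the `dryrun` and `min_block_size` arguments come from these patches, and `torch_tensorrt.dynamo.compile` is the entry point modified in `_compiler.py`.

import torch
import torch_tensorrt

model = (
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
        torch.nn.ReLU(),
    )
    .eval()
    .cuda()
)
inputs = (torch.randn(1, 3, 224, 224).cuda(),)

# compile() takes an ExportedProgram (see the torch.export import in _compiler.py)
exported_program = torch.export.export(model, inputs)

# dryrun=True runs every phase short of TRT conversion and prints the report
# (graph schematic, per-engine stats, recommendations) to stdout; the patch
# itself suggests min_block_size=1 for the most thorough analysis
gm = torch_tensorrt.dynamo.compile(
    exported_program, inputs=inputs, dryrun=True, min_block_size=1
)

# As of PATCH 3/4, a string path additionally saves the report to that file
# (skipped with a warning if the file already exists)
gm = torch_tensorrt.dynamo.compile(
    exported_program, inputs=inputs, dryrun="dryrun_report.txt", min_block_size=1
)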
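
A small sketch of the `parse_complex_tensor_structs` helper that PATCH 2/4 adds to `py/torch_tensorrt/dynamo/utils.py`, which the tracker uses to mirror arbitrarily nested input structures; the tensor values here are illustrative.

import torch

from torch_tensorrt.dynamo.utils import parse_complex_tensor_structs

structs = {"a": torch.zeros(1, 3), "b": [torch.zeros(2), torch.zeros(4, 4)]}

# Mirrors the nesting while extracting one attribute per element and
# optionally applying a function to it; prints {'a': (1, 3), 'b': [(2,), (4, 4)]}
print(parse_complex_tensor_structs(structs, "shape", tuple))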
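
Likewise, a sketch of the dynamic-shape case handled by PATCH 4/4; the shape values are assumptions. A dynamic `torch_tensorrt.Input` carries its shape as a dict of min/opt/max shapes, which the new `input_formatter` base case renders directly.

import torch
import torch_tensorrt

# One graph input with a dynamic batch dimension
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(4, 3, 224, 224),
        max_shape=(8, 3, 224, 224),
        dtype=torch.float32,
    )
]

# In the dryrun report, this input is expected to render roughly as:
# Inputs: List[Tensor: {'min_shape': (1, 3, 224, 224), 'opt_shape': (4, 3, 224, 224), 'max_shape': (8, 3, 224, 224)}@float32]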