diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py index 85cc23165a..031fce2e73 100644 --- a/py/torch_tensorrt/dynamo/_DryRunTracker.py +++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py @@ -1,9 +1,14 @@ import logging import math +import operator +import os from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any, Dict, List, Union +import torch from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry +from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name logger = logging.getLogger(__name__) @@ -44,6 +49,7 @@ class DryRunTracker: tensorrt_graph_count (int): Number of TensorRT engines to be generated compilation_settings (CompilationSettings): User Compilation Settings unsupported_ops (Dict[str, int]): Set of operators not supported in TRT + to_run_in_torch (List[str]): Set of nodes to run in Torch """ total_ops_in_graph: int = 0 @@ -58,9 +64,12 @@ class DryRunTracker: default_factory=CompilationSettings ) unsupported_ops: Dict[str, int] = field(default_factory=dict) + to_run_in_torch: List[str] = field(default_factory=list) -def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None: +def dryrun_stats_display( + dryrun_tracker: DryRunTracker, dryrun_enabled: Union[bool, str] +) -> None: """Displays statistics about the dryrun either to debug logs or stdout""" formatted_stats = "\n" @@ -71,7 +80,19 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n" ) - formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n" + if dryrun_tracker.unsupported_ops: + parsed_ops = "\n".join( + [f"{str(k)}: {str(v)}" for k, v in dryrun_tracker.unsupported_ops.items()] + ) + formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n {parsed_ops}\n\n" + + if dryrun_tracker.to_run_in_torch: + formatted_nodes = "\n".join(dryrun_tracker.to_run_in_torch) + formatted_stats += ( + f"The following nodes are currently set to run in Torch:\n{formatted_nodes}\n" + "Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\n" + ) + formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n" assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count @@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> ) # If user specified "dryrun=True", print to stdout, else debug + # If user specified a filepath, save the output to the path as well if dryrun_enabled: print(formatted_stats) + if isinstance(dryrun_enabled, str): + if os.path.exists(dryrun_enabled): + logger.warning( + f"File already exists at path {dryrun_enabled}, not saving dryrun output" + ) + else: + with open(dryrun_enabled, "w+") as f: + f.write(formatted_stats) else: logger.debug(formatted_stats) @@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str: ) return input_formatter_helper(shapes, dtypes)[:-2] + + +def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]: + """Parses call_function and call_method nodes from a GraphModule + Excludes getitem nodes + + Returns a string representation of the nodes + """ + to_run_in_torch = [] + for node in graph_module.graph.nodes: + # getitem nodes are excluded since they are a Tensor-collection op + if ( + node.op in ("call_function", "call_method") + and node.target != operator.getitem + ): + to_run_in_torch.append( + f"Node: {ConverterRegistry.qualified_name_or_str(node.target)}, " + f"with layer location: {get_node_name(node)}" + ) + return to_run_in_torch diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 95e943db4d..0ba80d3615 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -33,6 +33,7 @@ DryRunTracker, PerSubgraphData, dryrun_stats_display, + parse_non_trt_nodes, ) from torch_tensorrt.dynamo.conversion import ( CompilationSettings, @@ -296,6 +297,10 @@ def compile_module( dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators + # The global partitioner leaves non-TRT nodes as-is + if not settings.use_fast_partitioner: + dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module)) + # Store TRT replicas of Torch subgraphs trt_modules = {} # Iterate over all components that can be accelerated @@ -304,6 +309,7 @@ def compile_module( submodule = getattr(partitioned_module, name) # Criteria for a module to be convertible to TRT if settings.use_fast_partitioner and "_run_on_acc" not in name: + dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule)) continue subgraph_data = PerSubgraphData() diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 540300eb48..267eca95f5 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Set +from typing import Optional, Set, Union import torch from torch_tensorrt._Device import Device @@ -47,8 +47,9 @@ class CompilationSettings: device (Device): GPU to compile the model on require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path - dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to - TRT Engines. Prints detailed logs of the graph structure and nature of partitioning + dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to + TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the + ouptut to a file if a string path is specified """ precision: torch.dtype = PRECISION @@ -66,4 +67,4 @@ class CompilationSettings: enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS device: Device = field(default_factory=default_device) require_full_compilation: bool = REQUIRE_FULL_COMPILATION - dryrun: bool = DRYRUN + dryrun: Union[bool, str] = DRYRUN diff --git a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py index 4cc1ba9401..8eceefe0bf 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py @@ -42,7 +42,7 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name) + node in CONVERTERS or (node.op == "get_attr") ) and node_name not in self.torch_executed_ops: # If node is a proper, supported computational node, store the operator if not node.is_impure(): diff --git a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py index f078d888d3..2c6674da74 100644 --- a/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py +++ b/py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py @@ -150,7 +150,7 @@ def is_node_supported( node_name = ConverterRegistry.qualified_name_or_str(node.target) if ( - node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name) + node in CONVERTERS or (node.op == "get_attr") ) and node_name not in self.torch_executed_ops: # If node is a proper, supported computational node, store the operator if not node.is_impure():