diff --git a/py/torch_tensorrt/dynamo/_DryRunTracker.py b/py/torch_tensorrt/dynamo/_DryRunTracker.py
new file mode 100644
index 0000000000..625db8753d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/_DryRunTracker.py
@@ -0,0 +1,205 @@
+import logging
+import math
+from dataclasses import dataclass, field
+from typing import List, Tuple
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class PerSubgraphData:
+    """Class to track data on a per-subgraph level
+
+    Args:
+        subgraph_name (str): Name of the subgraph in the GraphModule
+        subgraph_op_count (int): Number of operations in the subgraph
+        subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
+        subgraph_input_dtypes (List[torch.dtype]): Input data types of the subgraph
+        subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
+        subgraph_output_dtypes (List[torch.dtype]): Output data types of the subgraph
+    """
+
+    subgraph_name: str = ""
+    subgraph_op_count: int = 0
+    subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
+    subgraph_input_dtypes: List[torch.dtype] = field(default_factory=list)
+    subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
+    subgraph_output_dtypes: List[torch.dtype] = field(default_factory=list)
+
+
+@dataclass
+class DryRunTracker:
+    """Class to track data on a graph-wide level
+
+    Args:
+        total_ops_in_graph (int): Total number of operators in the graph
+        supported_ops_in_graph (int): Number of supported operators in the graph
+        graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
+        graph_input_dtypes (List[torch.dtype]): Input data types of the graph
+        graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
+        graph_output_dtypes (List[torch.dtype]): Output data types of the graph
+        per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
+        tensorrt_graph_count (int): Number of TensorRT engines to be generated
+        truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
+    """
+
+    total_ops_in_graph: int = 0
+    supported_ops_in_graph: int = 0
+    graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
+    graph_input_dtypes: List[torch.dtype] = field(default_factory=list)
+    graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
+    graph_output_dtypes: List[torch.dtype] = field(default_factory=list)
+    per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
+    tensorrt_graph_count: int = 0
+    truncated_long_and_double: bool = False
+
+
+def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
+    """Displays statistics about the dryrun either to debug logs or info logs"""
+    # If the user specified "dryrun=True", print to info logs, else debug
+    if dryrun_enabled:
+        dryrun_logger = logger.info
+    else:
+        dryrun_logger = logger.debug
+
+    formatted_stats = "\n"
+
+    # Print overall stats about the graph, operator counts, etc.
+ formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n" + formatted_stats += ( + f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, " + f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, " + f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n" + ) + formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n" + formatted_stats += ( + f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n" + ) + + assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count + + # Print schematic of the graph structure, as in: + # + # Inputs: [Tensor: (1, 3, 224, 224)@float32] + # ... + # TRT Engine #1: _run_on_acc_0 + # Engine Inputs: [Tensor: (1, 3, 224, 224)@float32] + # Number of Operators in Engine: 1 + # Engine Outputs: [Tensor: (1, 64, 112, 112)@float32] + # ... + # Outputs: [Tensor: (1, 1000)@float32] + # + formatted_stats += " " * 2 + "Graph Structure:\n\n" + formatted_stats += ( + " " * 3 + + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n" + ) + + for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data): + assert len(trt_subgraph_data.subgraph_input_dtypes) == len( + trt_subgraph_data.subgraph_input_shapes + ) + assert len(trt_subgraph_data.subgraph_output_dtypes) == len( + trt_subgraph_data.subgraph_output_shapes + ) + formatted_stats += " " * 4 + "...\n" + formatted_stats += ( + " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n" + ) + formatted_stats += ( + " " * 5 + + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n" + ) + formatted_stats += ( + " " * 5 + + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n" + ) + formatted_stats += ( + " " * 5 + + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n" + ) + + formatted_stats += " " * 4 + "...\n" + formatted_stats += ( + " " * 3 + + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n" + ) + + # Print aggregate statistics about the graph structure, including recommended "min_block_size" options + if dryrun_tracker.tensorrt_graph_count > 0: + min_ops_in_an_engine = min( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + avg_ops_per_engine = ( + sum( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + / dryrun_tracker.tensorrt_graph_count + ) + avg_ops_per_engine = round(avg_ops_per_engine, 2) + most_ops_in_an_engine = max( + trt_subgraph.subgraph_op_count + for trt_subgraph in dryrun_tracker.per_subgraph_data + ) + + formatted_stats += "\n" + " " * 2 + "-" * 25 + " Aggregate Stats " + "-" * 25 + formatted_stats += ( + "\n\n" + + " " * 3 + + "Average Number of Operators per TRT Engine: " + + f"{avg_ops_per_engine}" + ) + + formatted_stats += ( + "\n" + + " " * 3 + + "Most Operators in a TRT Engine: " + + f"{most_ops_in_an_engine}" + ) + + formatted_stats += "\n\n" + " " * 2 + "*" * 10 + " Recommendations " + "*" * 10 + formatted_stats += ( + "\n\n" + + " " * 3 + + "- For minimal graph segmentation, select min_block_size=" + + f"{most_ops_in_an_engine} which would generate " + + 
f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines" + ) + if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine: + formatted_stats += ( + "\n" + + " " * 3 + + "- For moderate graph segmentation, select min_block_size=" + + f"{math.ceil(avg_ops_per_engine)} which would generate " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines" + ) + + formatted_stats += ( + "\n" + + " " * 3 + + "- The current level of graph segmentation is equivalent to selecting min_block_size=" + + f"{min_ops_in_an_engine} which generates " + + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines" + ) + else: + formatted_stats += ( + "\n" + + " " * 2 + + "Aggregate stats not available since no TRT Engines were generated." + ) + + dryrun_logger(formatted_stats) + + +def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str: + """Format shapes and dtypes of input Tensors into a readable string""" + formatted_str = ", " + + for shape, dtype in zip(shapes, dtypes): + formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, " + + return formatted_str[2:-2] diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index d31be8a413..7a543c3803 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -5,7 +5,6 @@ from typing import Any, List, Optional, Sequence, Set, Tuple, Union import torch -import torch_tensorrt from torch.export import ExportedProgram from torch_tensorrt._Device import Device from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum @@ -16,6 +15,7 @@ from torch_tensorrt.dynamo._defaults import ( DEBUG, DEVICE, + DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, @@ -29,6 +29,11 @@ VERSION_COMPATIBLE, WORKSPACE_SIZE, ) +from torch_tensorrt.dynamo._DryRunTracker import ( + DryRunTracker, + PerSubgraphData, + dryrun_stats_display, +) from torch_tensorrt.dynamo.conversion import ( CompilationSettings, convert_module, @@ -43,6 +48,8 @@ to_torch_tensorrt_device, ) +import torch_tensorrt + logger = logging.getLogger(__name__) @@ -75,6 +82,7 @@ def compile( use_python_runtime: bool = USE_PYTHON_RUNTIME, use_fast_partitioner: bool = USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + dryrun: bool = DRYRUN, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile a TorchScript module for NVIDIA GPUs using TensorRT @@ -131,6 +139,7 @@ def compile( use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optiminal. Use the global paritioner (``False``) if looking for best performance enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the grap easier to covert to TensorRT, potentially increasing the amount of graphs run in TensorRT. 
+        dryrun (bool): Toggle for "Dryrun" mode, which runs everything except conversion to TRT Engines, and logs detailed information about the graph structure and partitioning
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -192,6 +201,7 @@ def compile(
         "use_fast_partitioner": use_fast_partitioner,
         "enable_experimental_decompositions": enable_experimental_decompositions,
         "require_full_compilation": require_full_compilation,
+        "dryrun": dryrun,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -215,15 +225,32 @@ def compile_module(
     Returns:
         Compiled FX GraphModule
     """
+    dryrun_tracker = DryRunTracker()
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
         gm, settings.debug, settings.torch_executed_ops
     )
 
+    dryrun_tracker.total_ops_in_graph = total_ops
+    dryrun_tracker.supported_ops_in_graph = num_supported_ops
+    dryrun_tracker.graph_input_shapes = [
+        tuple(input_.shape) for input_ in sample_inputs
+    ]
+    dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
+    dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
+
+    if settings.dryrun and settings.min_block_size > 1:
+        logger.info(
+            "It is recommended to run `dryrun` mode with `min_block_size=1`, "
+            "for the most thorough analysis"
+        )
+
    # If the number of supported operations is 0 or less than the block size, skip the subgraph
    # TODO: Add condition to second expression below when require_full_compilation is added
-    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
+    if num_supported_ops == 0 or (
+        num_supported_ops < settings.min_block_size and not settings.dryrun
+    ):
         logger.warning(
             f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
" f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}" @@ -274,6 +301,16 @@ def compile_module( if settings.use_fast_partitioner and "_run_on_acc" not in name: continue + subgraph_data = PerSubgraphData() + subgraph_data.subgraph_name = name + subgraph_data.subgraph_op_count = len( + [ + node + for node in submodule.graph.nodes + if node.op in ("call_function", "call_method", "call_module") + ] + ) + # Get the submodule inputs for min, opt, max shapes of the graph inputs submodule_inputs = partitioning.get_submod_inputs( partitioned_module, @@ -300,15 +337,51 @@ def compile_module( name, ) - # Create TRT engines from submodule - trt_module = convert_module( - submodule, - submodule_inputs, - settings=settings, - name=name, + subgraph_data.subgraph_input_dtypes = [ + submodule_input.torch_dtype for submodule_input in submodule_inputs + ] + subgraph_data.subgraph_input_shapes = [ + tuple(submodule_input.shape) for submodule_input in submodule_inputs + ] + + submodule_outputs = submodule( + *get_torch_inputs(submodule_inputs, to_torch_device(settings.device)) ) + if not isinstance(submodule_outputs, (list, tuple)): + submodule_outputs = [submodule_outputs] - trt_modules[name] = trt_module + subgraph_data.subgraph_output_dtypes = [ + submodule_output.dtype for submodule_output in submodule_outputs + ] + subgraph_data.subgraph_output_shapes = [ + tuple(submodule_output.shape) for submodule_output in submodule_outputs + ] + + dryrun_tracker.tensorrt_graph_count += 1 + dryrun_tracker.per_subgraph_data.append(subgraph_data) + + # Create TRT engines from submodule + if not settings.dryrun: + trt_module = convert_module( + submodule, + submodule_inputs, + settings=settings, + name=name, + ) + + trt_modules[name] = trt_module + + sample_outputs = gm( + *get_torch_inputs(sample_inputs, to_torch_device(settings.device)) + ) + + if not isinstance(sample_outputs, (list, tuple)): + sample_outputs = [sample_outputs] + + dryrun_tracker.graph_output_shapes = [ + tuple(output_.shape) for output_ in sample_outputs + ] + dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs] # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): @@ -318,4 +391,6 @@ def compile_module( if fast_partitioner_failed: settings.use_fast_partitioner = True + dryrun_stats_display(dryrun_tracker, settings.dryrun) + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 103b5f7792..367abe790d 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -15,6 +15,7 @@ USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False REQUIRE_FULL_COMPILATION = False +DRYRUN = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index c9f4534cb8..540300eb48 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -5,6 +5,7 @@ from torch_tensorrt._Device import Device from torch_tensorrt.dynamo._defaults import ( DEBUG, + DRYRUN, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, @@ -46,6 +47,8 @@ class CompilationSettings: device (Device): GPU to compile the model on require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT. 
             Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
+            TRT Engines. Prints detailed logs of the graph structure and nature of partitioning
     """
 
     precision: torch.dtype = PRECISION
@@ -63,3 +66,4 @@ class CompilationSettings:
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
     device: Device = field(default_factory=default_device)
     require_full_compilation: bool = REQUIRE_FULL_COMPILATION
+    dryrun: bool = DRYRUN
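
Reviewer note: a minimal usage sketch of the new flag, assuming a CUDA-capable setup. The toy model, input shapes, and logging configuration are illustrative only; the `dryrun` keyword is forwarded to the dynamo frontend shown in this diff, and exact front-end wiring may differ by release.

import logging

import torch
import torch_tensorrt

logging.basicConfig(level=logging.INFO)  # dryrun=True reports via logger.info

model = (
    torch.nn.Sequential(
        torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
        torch.nn.ReLU(),
    )
    .eval()
    .cuda()
)
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# Partitioning and shape/dtype propagation run as usual, but no TRT engines
# are built and no conversion time is spent.
torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    min_block_size=1,  # recommended by the dryrun log for the most thorough analysis
    dryrun=True,
)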
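
The tracker and formatter in _DryRunTracker.py do not depend on the compiler internals, so the report format can be previewed directly. A small sketch, with made-up shapes and op counts:

import logging

import torch
from torch_tensorrt.dynamo._DryRunTracker import (
    DryRunTracker,
    PerSubgraphData,
    dryrun_stats_display,
    input_formatter,
)

logging.basicConfig(level=logging.INFO)

# input_formatter joins "Tensor: <shape>@<dtype>" entries, stripping the
# "torch." prefix from the dtype.
print(input_formatter([(1, 3, 224, 224)], [torch.float32]))
# Tensor: (1, 3, 224, 224)@float32

tracker = DryRunTracker(
    total_ops_in_graph=5,
    supported_ops_in_graph=4,
    graph_input_shapes=[(1, 3, 224, 224)],
    graph_input_dtypes=[torch.float32],
    graph_output_shapes=[(1, 1000)],
    graph_output_dtypes=[torch.float32],
    per_subgraph_data=[
        PerSubgraphData(
            subgraph_name="_run_on_acc_0",
            subgraph_op_count=4,
            subgraph_input_shapes=[(1, 3, 224, 224)],
            subgraph_input_dtypes=[torch.float32],
            subgraph_output_shapes=[(1, 1000)],
            subgraph_output_dtypes=[torch.float32],
        )
    ],
    tensorrt_graph_count=1,
    truncated_long_and_double=False,
)

# Emits the "Dry-Run Results for Graph" report via logger.info, including the
# 80.0% coverage line and the schematic for the single engine above.
dryrun_stats_display(tracker, dryrun_enabled=True)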
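
For reference, the Recommendations block is plain arithmetic over the per-engine operator counts collected in compile_module. A sketch with invented counts for three TRT-eligible subgraphs:

import math

op_counts = [1, 3, 7]  # hypothetical subgraph_op_count values

most_ops = max(op_counts)  # "minimal graph segmentation" threshold
avg_ops = round(sum(op_counts) / len(op_counts), 2)  # 3.67
min_ops = min(op_counts)  # current level of segmentation


def engines_at(threshold: int) -> int:
    """Engines that would remain if min_block_size were set to `threshold`."""
    return len([1 for count in op_counts if count >= threshold])


print(engines_at(most_ops))            # 1 engine at min_block_size=7
print(engines_at(math.ceil(avg_ops)))  # 1 engine at min_block_size=4
print(engines_at(min_ops))             # 3 engines at min_block_size=1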