feat: Add dryrun feature to Dynamo paths
- Enables building of TRT engines with "dryrun" capabilities, meaning
all of the phases except conversion are run, and verbose logs of the
graph structure and composition are printed for the user
- Improves general-purpose debug logging by printing dryrun stats to the
debug logs regardless of whether the option is specified
- Provides an intuitive schematic of the graph's engines, inputs, and code
path through the course of the graph
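
A minimal usage sketch of the new flag (the model, input shapes, and `ir="dynamo"` selection are illustrative assumptions, not taken from this commit):

import torch
import torch_tensorrt
import torchvision.models as models

# Illustrative module; any Dynamo-traceable model works the same way
model = models.resnet18().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# With dryrun=True, partitioning and shape analysis run as usual, but no
# TRT engines are built; the graph schematic is logged at INFO level.
# min_block_size=1 follows the recommendation logged by compile_module.
gm = torch_tensorrt.compile(
    model, ir="dynamo", inputs=inputs, dryrun=True, min_block_size=1
)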
gs-olive committed Dec 13, 2023
1 parent 8554782 commit 635e26a
Showing 4 changed files with 294 additions and 9 deletions.
205 changes: 205 additions & 0 deletions py/torch_tensorrt/dynamo/_DryRunTracker.py
@@ -0,0 +1,205 @@
import logging
import math
from dataclasses import dataclass, field
from typing import List, Tuple

import torch

logger = logging.getLogger(__name__)


@dataclass
class PerSubgraphData:
    """Class to track data on a per-subgraph level

    Args:
        subgraph_name (str): Name of the subgraph in the GraphModule
        subgraph_op_count (int): Number of operations in the subgraph
        subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
        subgraph_input_dtypes (List[torch.dtype]): Input data types of the subgraph
        subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
        subgraph_output_dtypes (List[torch.dtype]): Output data types of the subgraph
    """

    subgraph_name: str = ""
    subgraph_op_count: int = 0
    subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    subgraph_input_dtypes: List[torch.dtype] = field(default_factory=list)
    subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    subgraph_output_dtypes: List[torch.dtype] = field(default_factory=list)


@dataclass
class DryRunTracker:
    """Class to track data on a graph-wide level

    Args:
        total_ops_in_graph (int): Total number of operators in graph
        supported_ops_in_graph (int): Number of supported operators in graph
        graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
        graph_input_dtypes (List[torch.dtype]): Input data types of the graph
        graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
        graph_output_dtypes (List[torch.dtype]): Output data types of the graph
        per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
        tensorrt_graph_count (int): Number of TensorRT engines to be generated
        truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
    """

    total_ops_in_graph: int = 0
    supported_ops_in_graph: int = 0
    graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    graph_input_dtypes: List[torch.dtype] = field(default_factory=list)
    graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
    graph_output_dtypes: List[torch.dtype] = field(default_factory=list)
    per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
    tensorrt_graph_count: int = 0
    truncated_long_and_double: bool = False


def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
    """Displays statistics about the dryrun either to debug logs or info logs"""
    # If the user specified "dryrun=True", print to info logs, else debug
    if dryrun_enabled:
        dryrun_logger = logger.info
    else:
        dryrun_logger = logger.debug

    formatted_stats = "\n"

    # Print overall stats about the graph, operator counts, etc.
    formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
    formatted_stats += (
        f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
        f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
        f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
    )
    formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not '}truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
    formatted_stats += (
        f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
    )

    assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count

    # Print schematic of the graph structure, as in:
    #
    #   Inputs: [Tensor: (1, 3, 224, 224)@float32]
    #     ...
    #     TRT Engine #1: _run_on_acc_0
    #      Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
    #      Number of Operators in Engine: 1
    #      Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
    #     ...
    #   Outputs: [Tensor: (1, 1000)@float32]
    #
    formatted_stats += " " * 2 + "Graph Structure:\n\n"
    formatted_stats += (
        " " * 3
        + f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
    )

    for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
        assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
            trt_subgraph_data.subgraph_input_shapes
        )
        assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
            trt_subgraph_data.subgraph_output_shapes
        )
        formatted_stats += " " * 4 + "...\n"
        formatted_stats += (
            " " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
        )
        formatted_stats += (
            " " * 5
            + f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
        )
        formatted_stats += (
            " " * 5
            + f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
        )
        formatted_stats += (
            " " * 5
            + f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
        )

    formatted_stats += " " * 4 + "...\n"
    formatted_stats += (
        " " * 3
        + f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
    )

    # Print aggregate statistics about the graph structure, including recommended "min_block_size" options
    if dryrun_tracker.tensorrt_graph_count > 0:
        min_ops_in_an_engine = min(
            trt_subgraph.subgraph_op_count
            for trt_subgraph in dryrun_tracker.per_subgraph_data
        )
        avg_ops_per_engine = (
            sum(
                trt_subgraph.subgraph_op_count
                for trt_subgraph in dryrun_tracker.per_subgraph_data
            )
            / dryrun_tracker.tensorrt_graph_count
        )
        avg_ops_per_engine = round(avg_ops_per_engine, 2)
        most_ops_in_an_engine = max(
            trt_subgraph.subgraph_op_count
            for trt_subgraph in dryrun_tracker.per_subgraph_data
        )

        formatted_stats += "\n" + " " * 2 + "-" * 25 + " Aggregate Stats " + "-" * 25
        formatted_stats += (
            "\n\n"
            + " " * 3
            + "Average Number of Operators per TRT Engine: "
            + f"{avg_ops_per_engine}"
        )

        formatted_stats += (
            "\n"
            + " " * 3
            + "Most Operators in a TRT Engine: "
            + f"{most_ops_in_an_engine}"
        )

        formatted_stats += "\n\n" + " " * 2 + "*" * 10 + " Recommendations " + "*" * 10
        formatted_stats += (
            "\n\n"
            + " " * 3
            + "- For minimal graph segmentation, select min_block_size="
            + f"{most_ops_in_an_engine} which would generate "
            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
        )
        if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
            formatted_stats += (
                "\n"
                + " " * 3
                + "- For moderate graph segmentation, select min_block_size="
                + f"{math.ceil(avg_ops_per_engine)} which would generate "
                + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
            )

        formatted_stats += (
            "\n"
            + " " * 3
            + "- The current level of graph segmentation is equivalent to selecting min_block_size="
            + f"{min_ops_in_an_engine} which generates "
            + f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
        )
    else:
        formatted_stats += (
            "\n"
            + " " * 2
            + "Aggregate stats not available since no TRT Engines were generated."
        )

    dryrun_logger(formatted_stats)


def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
    """Format shapes and dtypes of input Tensors into a readable string"""
    formatted_str = ", "

    for shape, dtype in zip(shapes, dtypes):
        # str(dtype) has the form "torch.float32"; [6:] strips the "torch." prefix
        formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, "

    # Strip the leading ", " sentinel and the trailing ", " separator
    return formatted_str[2:-2]
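
A self-contained sketch of how the tracker feeds the display (the values are illustrative; the import path follows the file above):

import logging
import torch
from torch_tensorrt.dynamo._DryRunTracker import (
    DryRunTracker,
    PerSubgraphData,
    dryrun_stats_display,
    input_formatter,
)

logging.basicConfig(level=logging.INFO)  # dryrun=True logs at INFO

# The formatter strips the "torch." prefix from dtypes:
# prints "Tensor: (1, 3, 224, 224)@float32"
print(input_formatter([(1, 3, 224, 224)], [torch.float32]))

# Illustrative single-engine graph
tracker = DryRunTracker(
    total_ops_in_graph=10,
    supported_ops_in_graph=10,
    graph_input_shapes=[(1, 3, 224, 224)],
    graph_input_dtypes=[torch.float32],
    graph_output_shapes=[(1, 1000)],
    graph_output_dtypes=[torch.float32],
    per_subgraph_data=[
        PerSubgraphData(
            subgraph_name="_run_on_acc_0",
            subgraph_op_count=10,
            subgraph_input_shapes=[(1, 3, 224, 224)],
            subgraph_input_dtypes=[torch.float32],
            subgraph_output_shapes=[(1, 1000)],
            subgraph_output_dtypes=[torch.float32],
        )
    ],
    tensorrt_graph_count=1,
    truncated_long_and_double=False,
)

# Prints the full schematic and recommendations shown in dryrun mode
dryrun_stats_display(tracker, dryrun_enabled=True)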
93 changes: 84 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -5,7 +5,6 @@
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
-import torch_tensorrt
 from torch.export import ExportedProgram
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probably be the TRT EngineCapability Enum
@@ -20,6 +19,7 @@
     DLA_GLOBAL_DRAM_SIZE,
     DLA_LOCAL_DRAM_SIZE,
     DLA_SRAM_SIZE,
+    DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENGINE_CAPABILITY,
     MAX_AUX_STREAMS,
@@ -37,6 +37,11 @@
     VERSION_COMPATIBLE,
     WORKSPACE_SIZE,
 )
+from torch_tensorrt.dynamo._DryRunTracker import (
+    DryRunTracker,
+    PerSubgraphData,
+    dryrun_stats_display,
+)
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     convert_module,
@@ -51,6 +56,8 @@
     to_torch_tensorrt_device,
 )
 
+import torch_tensorrt
+
 logger = logging.getLogger(__name__)
 
 
@@ -84,6 +91,7 @@ def compile(
     use_python_runtime: bool = USE_PYTHON_RUNTIME,
     use_fast_partitioner: bool = USE_FAST_PARTITIONER,
     enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    dryrun: bool = DRYRUN,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -140,6 +148,7 @@ def compile(
         use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
         use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
         enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the number of graphs run in TensorRT.
+        dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -215,6 +224,7 @@ def compile(
         "dla_sram_size": dla_sram_size,
         "dla_local_dram_size": dla_local_dram_size,
         "dla_global_dram_size": dla_global_dram_size,
+        "dryrun": dryrun,
     }
 
     settings = CompilationSettings(**compilation_options)
@@ -238,15 +248,32 @@ def compile_module(
     Returns:
         Compiled FX GraphModule
     """
+    dryrun_tracker = DryRunTracker()
 
     # Check the number of supported operations in the graph
     num_supported_ops, total_ops = partitioning.get_graph_converter_support(
         gm, settings.debug, settings.torch_executed_ops
     )
 
+    dryrun_tracker.total_ops_in_graph = total_ops
+    dryrun_tracker.supported_ops_in_graph = num_supported_ops
+    dryrun_tracker.graph_input_shapes = [
+        tuple(input_.shape) for input_ in sample_inputs
+    ]
+    dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
+    dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
+
+    if settings.dryrun and settings.min_block_size > 1:
+        logger.info(
+            "It is recommended to run `dryrun` mode with `min_block_size=1`, "
+            "for the most thorough analysis"
+        )
+
     # If the number of supported operations is 0 or less than the block size, skip the subgraph
     # TODO: Add condition to second expression below when require_full_compilation is added
-    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
+    if num_supported_ops == 0 or (
+        num_supported_ops < settings.min_block_size and not settings.dryrun
+    ):
         logger.warning(
             f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
             f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
@@ -297,6 +324,16 @@ def compile_module(
         if settings.use_fast_partitioner and "_run_on_acc" not in name:
             continue
 
+        subgraph_data = PerSubgraphData()
+        subgraph_data.subgraph_name = name
+        subgraph_data.subgraph_op_count = len(
+            [
+                node
+                for node in submodule.graph.nodes
+                if node.op in ("call_function", "call_method", "call_module")
+            ]
+        )
+
         # Get the submodule inputs for min, opt, max shapes of the graph inputs
         submodule_inputs = partitioning.get_submod_inputs(
             partitioned_module,
@@ -323,15 +360,51 @@ def compile_module(
             name,
         )
 
-        # Create TRT engines from submodule
-        trt_module = convert_module(
-            submodule,
-            submodule_inputs,
-            settings=settings,
-            name=name,
+        subgraph_data.subgraph_input_dtypes = [
+            submodule_input.torch_dtype for submodule_input in submodule_inputs
+        ]
+        subgraph_data.subgraph_input_shapes = [
+            tuple(submodule_input.shape) for submodule_input in submodule_inputs
+        ]
+
+        submodule_outputs = submodule(
+            *get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
         )
+        if not isinstance(submodule_outputs, (list, tuple)):
+            submodule_outputs = [submodule_outputs]
 
-        trt_modules[name] = trt_module
+        subgraph_data.subgraph_output_dtypes = [
+            submodule_output.dtype for submodule_output in submodule_outputs
+        ]
+        subgraph_data.subgraph_output_shapes = [
+            tuple(submodule_output.shape) for submodule_output in submodule_outputs
+        ]
+
+        dryrun_tracker.tensorrt_graph_count += 1
+        dryrun_tracker.per_subgraph_data.append(subgraph_data)
+
+        # Create TRT engines from submodule
+        if not settings.dryrun:
+            trt_module = convert_module(
+                submodule,
+                submodule_inputs,
+                settings=settings,
+                name=name,
+            )
+
+            trt_modules[name] = trt_module
+
+    sample_outputs = gm(
+        *get_torch_inputs(sample_inputs, to_torch_device(settings.device))
+    )
+
+    if not isinstance(sample_outputs, (list, tuple)):
+        sample_outputs = [sample_outputs]
+
+    dryrun_tracker.graph_output_shapes = [
+        tuple(output_.shape) for output_ in sample_outputs
+    ]
+    dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]
 
     # Replace all FX Modules with TRT Modules
     for name, trt_module in trt_modules.items():
Expand All @@ -341,4 +414,6 @@ def compile_module(
if fast_partitioner_failed:
settings.use_fast_partitioner = True

dryrun_stats_display(dryrun_tracker, settings.dryrun)

return partitioned_module
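
The per-engine operator count above tallies only computational nodes ("call_function", "call_method", "call_module"), skipping placeholders and outputs. A standalone sketch of the same counting on a toy traced function (illustrative, not part of the commit):

import torch
import torch.fx


def toy(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1


gm = torch.fx.symbolic_trace(toy)

# Mirrors the subgraph_op_count computation in compile_module
op_count = len(
    [
        node
        for node in gm.graph.nodes
        if node.op in ("call_function", "call_method", "call_module")
    ]
)
print(op_count)  # 2: relu and add; placeholder/output nodes are excluded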
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -24,6 +24,7 @@
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
 REFIT = False
 REQUIRE_FULL_COMPILATION = False
+DRYRUN = False
 
 
 def default_device() -> Device:
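
Because DRYRUN defaults to False, the statistics still reach the debug logs on every compile; a sketch of surfacing them without enabling dryrun (assumes standard Python logging; the logger name follows the module path above):

import logging

# dryrun_stats_display routes to logger.debug when dryrun=False, so raising
# this logger's level exposes the stats during a normal compilation
logging.basicConfig()
logging.getLogger("torch_tensorrt.dynamo._DryRunTracker").setLevel(logging.DEBUG)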
(Diff for the fourth changed file not shown.)