Skip to content

Commit

Permalink
feat: Add optional filepath to save
Browse files Browse the repository at this point in the history
- Add detailed layer information for excluded ops
  • Loading branch information
gs-olive committed Dec 7, 2023
1 parent c847dc5 commit 998e560
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 9 deletions.
56 changes: 53 additions & 3 deletions py/torch_tensorrt/dynamo/_DryRunTracker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import logging
import math
import operator
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterRegistry
from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,6 +49,7 @@ class DryRunTracker:
tensorrt_graph_count (int): Number of TensorRT engines to be generated
compilation_settings (CompilationSettings): User Compilation Settings
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
to_run_in_torch (List[str]): List of nodes to run in Torch
"""

total_ops_in_graph: int = 0
Expand All @@ -58,9 +64,12 @@ class DryRunTracker:
default_factory=CompilationSettings
)
unsupported_ops: Dict[str, int] = field(default_factory=dict)
to_run_in_torch: List[str] = field(default_factory=list)


def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
def dryrun_stats_display(
dryrun_tracker: DryRunTracker, dryrun_enabled: Union[bool, str]
) -> None:
"""Displays statistics about the dryrun either to debug logs or stdout"""
formatted_stats = "\n"

Expand All @@ -71,7 +80,19 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
)
formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
if dryrun_tracker.unsupported_ops:
parsed_ops = "\n".join(
[f"{str(k)}: {str(v)}" for k, v in dryrun_tracker.unsupported_ops.items()]
)
formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n {parsed_ops}\n\n"

if dryrun_tracker.to_run_in_torch:
formatted_nodes = "\n".join(dryrun_tracker.to_run_in_torch)
formatted_stats += (
f"The following nodes are currently set to run in Torch:\n{formatted_nodes}\n"
"Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n\n"
)

formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"

assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count
Expand Down Expand Up @@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
)

# If user specified "dryrun=True", print to stdout, else debug
# If user specified a filepath, save the output to the path as well
if dryrun_enabled:
print(formatted_stats)
if isinstance(dryrun_enabled, str):
if os.path.exists(dryrun_enabled):
logger.warning(
f"File already exists at path {dryrun_enabled}, not saving dryrun output"
)
else:
with open(dryrun_enabled, "w+") as f:
f.write(formatted_stats)
else:
logger.debug(formatted_stats)

Expand Down Expand Up @@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
)

return input_formatter_helper(shapes, dtypes)[:-2]


def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]:
    """Build display strings for the call nodes of a GraphModule.

    Only ``call_function`` and ``call_method`` nodes are reported, and
    nodes whose target is ``operator.getitem`` are left out.

    Args:
        graph_module: Module whose ``graph.nodes`` are scanned in order.

    Returns:
        One formatted string per reported node, containing the node's
        qualified target name and its layer location.
    """
    reportable_ops = ("call_function", "call_method")
    return [
        f"Node: {ConverterRegistry.qualified_name_or_str(fx_node.target)}, "
        f"with layer location: {get_node_name(fx_node)}"
        for fx_node in graph_module.graph.nodes
        # getitem is a Tensor-collection op, so it is not listed
        if fx_node.op in reportable_ops and fx_node.target != operator.getitem
    ]
6 changes: 6 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
DryRunTracker,
PerSubgraphData,
dryrun_stats_display,
parse_non_trt_nodes,
)
from torch_tensorrt.dynamo.conversion import (
CompilationSettings,
Expand Down Expand Up @@ -303,6 +304,10 @@ def compile_module(

dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators

# The global partitioner leaves non-TRT nodes as-is
if not settings.use_fast_partitioner:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(partitioned_module))

# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
Expand All @@ -311,6 +316,7 @@ def compile_module(
submodule = getattr(partitioned_module, name)
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
dryrun_tracker.to_run_in_torch.extend(parse_non_trt_nodes(submodule))
continue

subgraph_data = PerSubgraphData()
Expand Down
9 changes: 5 additions & 4 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Set
from typing import Optional, Set, Union

import torch
from torch_tensorrt._Device import Device
Expand Down Expand Up @@ -47,8 +47,9 @@ class CompilationSettings:
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
dryrun (bool): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning
dryrun (Union[bool, str]): Toggle "Dryrun" mode, which runs everything through partitioning, short of conversion to
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
output to a file if a string path is specified
"""

precision: torch.dtype = PRECISION
Expand All @@ -66,4 +67,4 @@ class CompilationSettings:
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
dryrun: bool = DRYRUN
dryrun: Union[bool, str] = DRYRUN
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def is_node_supported(
node_name = ConverterRegistry.qualified_name_or_str(node.target)

if (
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
node in CONVERTERS or (node.op == "get_attr")
) and node_name not in self.torch_executed_ops:
# If node is a proper, supported computational node, store the operator
if not node.is_impure():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def is_node_supported(
node_name = ConverterRegistry.qualified_name_or_str(node.target)

if (
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
node in CONVERTERS or (node.op == "get_attr")
) and node_name not in self.torch_executed_ops:
# If node is a proper, supported computational node, store the operator
if not node.is_impure():
Expand Down

0 comments on commit 998e560

Please sign in to comment.