Skip to content

Commit

Permalink
fix: Address PR comments and update logging scheme
Browse files Browse the repository at this point in the history
- Fix test case failures
  • Loading branch information
gs-olive committed Dec 13, 2023
1 parent 635e26a commit d121545
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 115 deletions.
140 changes: 81 additions & 59 deletions py/torch_tensorrt/dynamo/_DryRunTracker.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import logging
import math
from dataclasses import dataclass, field
from typing import List, Tuple
from typing import Any, Dict, List

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings

logger = logging.getLogger(__name__)

Expand All @@ -15,18 +15,18 @@ class PerSubgraphData:
Args:
subgraph_name (str): Name of the subgraph in the GraphModule
subgraph_op_count (int): Number of operations in the subgraph
subgraph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the subgraph
subgraph_input_dtypes (List[torch.device]): Input data types of the subgraph
subgraph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the subgraph
subgraph_output_dtypes (List[torch.device]): Output data types of the subgraph
subgraph_input_shapes (Any): Shapes of input Tensors of the subgraph
subgraph_input_dtypes (Any): Input data types of the subgraph
subgraph_output_shapes (Any): Shapes of output Tensors of the subgraph
subgraph_output_dtypes (Any): Output data types of the subgraph
"""

subgraph_name: str = ""
subgraph_op_count: int = 0
subgraph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
subgraph_input_dtypes: List[torch.device] = field(default_factory=list)
subgraph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
subgraph_output_dtypes: List[torch.device] = field(default_factory=list)
subgraph_input_shapes: Any = field(default_factory=list)
subgraph_input_dtypes: Any = field(default_factory=list)
subgraph_output_shapes: Any = field(default_factory=list)
subgraph_output_dtypes: Any = field(default_factory=list)


@dataclass
Expand All @@ -36,95 +36,86 @@ class DryRunTracker:
Args:
total_ops_in_graph (int): Total number of operators in graph
supported_ops_in_graph (int): Number of supported operators in graph
graph_input_shapes (List[Tuple[int, ...]]): Shapes of input Tensors of the graph
graph_input_dtypes (List[torch.device]): Input data types of the graph
graph_output_shapes (List[Tuple[int, ...]]): Shapes of output Tensors of the graph
graph_output_dtypes (List[torch.device]): Output data types of the graph
graph_input_shapes (Any): Shapes of input Tensors of the graph
graph_input_dtypes (Any): Input data types of the graph
graph_output_shapes (Any): Shapes of output Tensors of the graph
graph_output_dtypes (Any): Output data types of the graph
per_subgraph_data (List[PerSubgraphData]): Per-subgraph data, see above class
tensorrt_graph_count (int): Number of TensorRT engines to be generated
truncated_long_and_double (bool): Whether truncate_long_and_double was enabled
compilation_settings (CompilationSettings): User Compilation Settings
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
"""

total_ops_in_graph: int = 0
supported_ops_in_graph: int = 0
graph_input_shapes: List[Tuple[int, ...]] = field(default_factory=list)
graph_input_dtypes: List[torch.device] = field(default_factory=list)
graph_output_shapes: List[Tuple[int, ...]] = field(default_factory=list)
graph_output_dtypes: List[torch.device] = field(default_factory=list)
graph_input_shapes: Any = field(default_factory=list)
graph_input_dtypes: Any = field(default_factory=list)
graph_output_shapes: Any = field(default_factory=list)
graph_output_dtypes: Any = field(default_factory=list)
per_subgraph_data: List[PerSubgraphData] = field(default_factory=list)
tensorrt_graph_count: int = 0
truncated_long_and_double: bool = False
compilation_settings: CompilationSettings = field(
default_factory=CompilationSettings
)
unsupported_ops: Dict[str, int] = field(default_factory=dict)


def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) -> None:
"""Displays statistics about the dryrun either to debug logs or info logs"""
# If user specified "dryrun=True", print to info logs, else debug
if dryrun_enabled:
dryrun_logger = logger.info
else:
dryrun_logger = logger.debug

"""Displays statistics about the dryrun either to debug logs or stdout"""
formatted_stats = "\n"

# Print overall stats about the graph, operator counts, etc.
formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n"
formatted_stats += "+" * 50 + " Dry-Run Results for Graph " + "+" * 50 + "\n\n"
formatted_stats += (
f"The graph consists of {dryrun_tracker.total_ops_in_graph} Total Operators, "
f"of which {dryrun_tracker.supported_ops_in_graph} operators are supported, "
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n"
)
formatted_stats += f"Long and double inputs were {'' if dryrun_tracker.truncated_long_and_double else 'not'} truncated (truncate_long_and_double={dryrun_tracker.truncated_long_and_double})\n"
formatted_stats += (
f"{dryrun_tracker.tensorrt_graph_count} TRT Engine(s) were generated\n"
f"{round(dryrun_tracker.supported_ops_in_graph*100/dryrun_tracker.total_ops_in_graph, 2)}% coverage\n\n"
)
formatted_stats += f"The following ops are currently unsupported and set to run in Torch: {dryrun_tracker.unsupported_ops}\n\n"
formatted_stats += f"Compiled with: {dryrun_tracker.compilation_settings}\n\n"

assert len(dryrun_tracker.per_subgraph_data) == dryrun_tracker.tensorrt_graph_count

# Print schematic of the graph structure, as in:
#
# Inputs: [Tensor: (1, 3, 224, 224)@float32]
# Inputs: List[Tensor: (1, 3, 224, 224)@float32]
# ...
# TRT Engine #1: _run_on_acc_0
# Engine Inputs: [Tensor: (1, 3, 224, 224)@float32]
# Number of Operators in Engine: 1
# Engine Outputs: [Tensor: (1, 64, 112, 112)@float32]
# TRT Engine #1 - Submodule name: _run_on_acc_0
# Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32]
# Number of Operators in Engine: 1
# Engine Outputs: Tensor: (1, 64, 112, 112)@float32
# ...
# Outputs: [Tensor: (1, 1000)@float32]
# Outputs: List[Tensor: (1, 1000)@float32]
#
formatted_stats += " " * 2 + "Graph Structure:\n\n"
formatted_stats += (
" " * 3
+ f"Inputs: [{input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}]\n"
+ f"Inputs: {input_formatter(dryrun_tracker.graph_input_shapes, dryrun_tracker.graph_input_dtypes)}\n"
)

for i, trt_subgraph_data in enumerate(dryrun_tracker.per_subgraph_data):
assert len(trt_subgraph_data.subgraph_input_dtypes) == len(
trt_subgraph_data.subgraph_input_shapes
)
assert len(trt_subgraph_data.subgraph_output_dtypes) == len(
trt_subgraph_data.subgraph_output_shapes
)
formatted_stats += " " * 4 + "...\n"
formatted_stats += (
" " * 4 + f"TRT Engine #{i+1}: {trt_subgraph_data.subgraph_name}\n"
" " * 4
+ f"TRT Engine #{i+1} - Submodule name: {trt_subgraph_data.subgraph_name}\n"
)
formatted_stats += (
" " * 5
+ f"Engine Inputs: [{input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}]\n"
+ f"Engine Inputs: {input_formatter(trt_subgraph_data.subgraph_input_shapes, trt_subgraph_data.subgraph_input_dtypes)}\n"
)
formatted_stats += (
" " * 5
+ f"Number of Operators in Engine: {trt_subgraph_data.subgraph_op_count}\n"
)
formatted_stats += (
" " * 5
+ f"Engine Outputs: [{input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}]\n"
+ f"Engine Outputs: {input_formatter(trt_subgraph_data.subgraph_output_shapes, trt_subgraph_data.subgraph_output_dtypes)}\n"
)

formatted_stats += " " * 4 + "...\n"
formatted_stats += (
" " * 3
+ f"Outputs: [{input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}]\n"
+ f"Outputs: {input_formatter(dryrun_tracker.graph_output_shapes, dryrun_tracker.graph_output_dtypes)}\n"
)

# Print aggregate statistics about the graph structure, including recommended "min_block_size" options
Expand Down Expand Up @@ -167,23 +158,23 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
+ " " * 3
+ "- For minimal graph segmentation, select min_block_size="
+ f"{most_ops_in_an_engine} which would generate "
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engines"
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= most_ops_in_an_engine])} TRT engine(s)"
)
if math.ceil(avg_ops_per_engine) != most_ops_in_an_engine:
formatted_stats += (
"\n"
+ " " * 3
+ "- For moderate graph segmentation, select min_block_size="
+ f"{math.ceil(avg_ops_per_engine)} which would generate "
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engines"
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= math.ceil(avg_ops_per_engine)])} TRT engine(s)"
)

formatted_stats += (
"\n"
+ " " * 3
+ "- The current level of graph segmentation is equivalent to selecting min_block_size="
+ f"{min_ops_in_an_engine} which generates "
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engines"
+ f"{len([1 for trt_subgraph in dryrun_tracker.per_subgraph_data if trt_subgraph.subgraph_op_count >= min_ops_in_an_engine])} TRT engine(s)"
)
else:
formatted_stats += (
Expand All @@ -192,14 +183,45 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
+ "Aggregate stats not available since no TRT Engines were generated."
)

dryrun_logger(formatted_stats)
# If user specified "dryrun=True", print to stdout, else debug
if dryrun_enabled:
print(formatted_stats)
else:
logger.debug(formatted_stats)


def input_formatter(shapes: List[Tuple[int, ...]], dtypes: List[torch.dtype]) -> str:
def input_formatter(shapes: Any, dtypes: Any) -> str:
"""Format shapes and dtypes of input Tensors into a readable string"""
formatted_str = ", "

for shape, dtype in zip(shapes, dtypes):
formatted_str += f"Tensor: {shape}@{str(dtype)[6:]}, "
def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
"""Helper for input formatter"""
# Base case - single shape, single dtype
if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
return f"Tensor: {shapes}@{str(dtypes)[6:]}, "

# Shapes is a sequence
elif isinstance(shapes, (list, tuple)):
formatted_str = "List[" if isinstance(shapes, list) else "Tuple("
for shape, dtype in zip(shapes, dtypes):
formatted_str += input_formatter_helper(shape, dtype)
formatted_str = formatted_str[:-2] + (
"], " if isinstance(shapes, list) else "), "
)
return formatted_str

# Shapes is a dictionary
elif isinstance(shapes, dict):
formatted_str = "Dict{"

for key, shape in shapes.items():
formatted_str += input_formatter_helper(shape, dtypes[key])

formatted_str = formatted_str[:-2] + "}, "
return formatted_str

else:
raise ValueError(
f"Invalid input type {type(shapes)} encountered in parse_complex_tensor_structs parsing."
)

return formatted_str[2:-2]
return input_formatter_helper(shapes, dtypes)[:-2]
55 changes: 30 additions & 25 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
from torch_tensorrt.dynamo.utils import (
get_torch_inputs,
parse_complex_tensor_structs,
prepare_inputs,
set_log_level,
to_torch_device,
Expand Down Expand Up @@ -257,11 +258,13 @@ def compile_module(

dryrun_tracker.total_ops_in_graph = total_ops
dryrun_tracker.supported_ops_in_graph = num_supported_ops
dryrun_tracker.graph_input_shapes = [
tuple(input_.shape) for input_ in sample_inputs
]
dryrun_tracker.graph_input_dtypes = [input_.torch_dtype for input_ in sample_inputs]
dryrun_tracker.truncated_long_and_double = settings.truncate_long_and_double
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
sample_inputs, "shape", tuple
)
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
sample_inputs, "torch_dtype"
)
dryrun_tracker.compilation_settings = settings

if settings.dryrun and settings.min_block_size > 1:
logger.info(
Expand Down Expand Up @@ -290,7 +293,7 @@ def compile_module(
# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
partitioned_module = partitioning.fast_partition(
partitioned_module, supported_ops = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
Expand All @@ -307,13 +310,15 @@ def compile_module(
settings.use_fast_partitioner = False

if not settings.use_fast_partitioner:
partitioned_module = partitioning.global_partition(
partitioned_module, supported_ops = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

dryrun_tracker.unsupported_ops = supported_ops.unsupported_operators

# Store TRT replicas of Torch subgraphs
trt_modules = {}
# Iterate over all components that can be accelerated
Expand Down Expand Up @@ -360,25 +365,23 @@ def compile_module(
name,
)

subgraph_data.subgraph_input_dtypes = [
submodule_input.torch_dtype for submodule_input in submodule_inputs
]
subgraph_data.subgraph_input_shapes = [
tuple(submodule_input.shape) for submodule_input in submodule_inputs
]
subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs(
submodule_inputs, "shape", tuple
)
subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs(
submodule_inputs, "torch_dtype"
)

submodule_outputs = submodule(
*get_torch_inputs(submodule_inputs, to_torch_device(settings.device))
)
if not isinstance(submodule_outputs, (list, tuple)):
submodule_outputs = [submodule_outputs]

subgraph_data.subgraph_output_dtypes = [
submodule_output.dtype for submodule_output in submodule_outputs
]
subgraph_data.subgraph_output_shapes = [
tuple(submodule_output.shape) for submodule_output in submodule_outputs
]
subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs(
submodule_outputs, "shape", tuple
)
subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs(
submodule_outputs, "dtype"
)

dryrun_tracker.tensorrt_graph_count += 1
dryrun_tracker.per_subgraph_data.append(subgraph_data)
Expand All @@ -401,10 +404,12 @@ def compile_module(
if not isinstance(sample_outputs, (list, tuple)):
sample_outputs = [sample_outputs]

dryrun_tracker.graph_output_shapes = [
tuple(output_.shape) for output_ in sample_outputs
]
dryrun_tracker.graph_output_dtypes = [output_.dtype for output_ in sample_outputs]
dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs(
sample_outputs, "shape", tuple
)
dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs(
sample_outputs, "dtype"
)

# Replace all FX Modules with TRT Modules
for name, trt_module in trt_modules.items():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def partition(
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Collection[Target] = set(),
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
) -> torch.fx.GraphModule:
) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support
Expand All @@ -259,7 +259,7 @@ def partition(
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Require that all computational operators be run in TRT
Returns:
torch.fx.GraphModule
torch.fx.GraphModule, OpSupportTester
"""
# Ensure graph is clean prior to partitioning
gm.graph.eliminate_dead_code()
Expand All @@ -280,4 +280,4 @@ def partition(
if verbose:
supported_ops.print_support_overview(partitioner.num_trt_accelerated_subgraphs)

return partitioned_graph
return partitioned_graph, supported_ops
Loading

0 comments on commit d121545

Please sign in to comment.