feat: Improve Logging in Dynamo (#2194)
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 authored Aug 16, 2023
1 parent b57d83e commit 08a2ee4
Showing 5 changed files with 142 additions and 150 deletions.
32 changes: 17 additions & 15 deletions py/torch_tensorrt/_compile.py
@@ -1,12 +1,12 @@
from __future__ import annotations

+import logging
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set

import torch
import torch.fx
import torch_tensorrt.ts
-from torch_tensorrt import logging
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
@@ -16,6 +16,13 @@
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from typing_extensions import TypeGuard

+logger = logging.getLogger(__name__)
+
+__all__ = [
+    "compile",
+    "convert_method_to_trt_engine",
+]
+

def _non_fx_input_interface(
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
@@ -30,7 +37,7 @@ def _fx_input_interface(


class _IRType(Enum):
-    """Enum to set the minimum required logging level to print a message to stdout"""
+    """Enum to determine the type of IR selected for model compilation"""

ts = 0
fx = 1
@@ -39,7 +46,7 @@ class _IRType(Enum):


class _ModuleType(Enum):
-    """Enum to set the minimum required logging level to print a message to stdout"""
+    """Enum to determine the type of model provided as input"""

nn = 0
ts = 1
@@ -81,14 +88,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
if ir == "default":
# Options are listed in order of preference
if module_is_fxable:
-            logging.log(
-                logging.Level.Info, "ir was set to default, using dynamo as ir"
-            )
+            logger.info("ir was set to default, using dynamo as ir")
return _IRType.dynamo
elif module_is_tsable:
-            logging.log(
-                logging.Level.Warning,
-                "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript",
+            logger.warning(
+                "Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
)
return _IRType.ts
else:
@@ -151,9 +155,8 @@ def compile(
if target_ir == _IRType.ts:
ts_mod = module
if module_type == _ModuleType.nn:
-            logging.log(
-                logging.Level.Info,
-                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
+            logger.info(
+                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
)
ts_mod = torch.jit.script(module)
assert _non_fx_input_interface(input_list)
@@ -274,9 +277,8 @@ def convert_method_to_trt_engine(
if target_ir == _IRType.ts:
ts_mod = module
if module_type == _ModuleType.nn:
-            logging.log(
-                logging.Level.Info,
-                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
+            logger.info(
+                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
)
ts_mod = torch.jit.script(module)
return torch_tensorrt.ts.convert_method_to_trt_engine( # type: ignore[no-any-return]
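With the torch_tensorrt.logging.log(...) calls in _compile.py replaced by a standard module-level logger, the visibility of these messages is now governed by Python's built-in logging configuration rather than by torch_tensorrt.logging.Level. A minimal sketch of how a caller could surface the new INFO and WARNING messages; the model, input shape, and level choices are illustrative and not part of this commit:

import logging

import torch
import torch_tensorrt

# Attach a handler to the root logger, then opt in to INFO messages from the package.
# Logger names follow getLogger(__name__), e.g. "torch_tensorrt._compile".
logging.basicConfig(level=logging.WARNING)
logging.getLogger("torch_tensorrt").setLevel(logging.INFO)

# Hypothetical module and input; with ir left at its "default" value, _compile.py now
# logs "ir was set to default, using dynamo as ir" through logger.info().
model = torch.nn.Linear(8, 4).eval().cuda()
trt_model = torch_tensorrt.compile(model, inputs=[torch.randn(1, 8).cuda()])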
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -1,7 +1,11 @@
+import logging
+
from torch_tensorrt._utils import sanitized_torch_version

from packaging import version

+logger = logging.getLogger(__name__)
+
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from ._settings import * # noqa: F403
from ._SourceIR import SourceIR # noqa: F403
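Because each module now calls logging.getLogger(__name__), the new loggers are named after the import path ("torch_tensorrt.dynamo", "torch_tensorrt.dynamo.aten_tracer", and so on) and form a hierarchy under "torch_tensorrt". A short illustrative snippet, with levels chosen arbitrarily:

import logging

# Verbose output for the whole dynamo subpackage...
logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)
# ...while keeping one of its modules quiet.
logging.getLogger("torch_tensorrt.dynamo.aten_tracer").setLevel(logging.WARNING)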
7 changes: 5 additions & 2 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
+import logging
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -26,6 +27,8 @@

Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]

+logger = logging.getLogger(__name__)
+

class DynamoConfig:
"""
@@ -145,13 +148,13 @@ def trace(
]

fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
-    print(fx_module.graph)
+
for passes in passes_list:
pr: PassResult = passes(fx_module)
fx_module = pr.graph_module

fx_module(*inputs)

fx_module = run_const_fold(fx_module)
-    print(fx_module.graph)
+    logger.info("Post export graph : %s\n", fx_module.graph)
return fx_module
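The two print(fx_module.graph) calls are replaced by a %-style logger.info call. With the standard logging module, format arguments are only rendered into the message when the record is actually emitted, so the potentially large graph string is never built when INFO output is disabled. A small self-contained sketch of that behaviour; the ExpensiveGraph stand-in is hypothetical and not part of the repository:

import logging

logger = logging.getLogger("torch_tensorrt.dynamo.aten_tracer")


class ExpensiveGraph:
    """Hypothetical stand-in for an FX graph whose __str__ is costly."""

    def __str__(self) -> str:
        print("stringifying graph...")  # visible side effect showing when formatting runs
        return "graph(...)"


logging.basicConfig(level=logging.WARNING)
logger.info("Post export graph : %s\n", ExpensiveGraph())  # filtered out; __str__ never runs
logger.setLevel(logging.INFO)
logger.info("Post export graph : %s\n", ExpensiveGraph())  # emitted; __str__ runs here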
118 changes: 3 additions & 115 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -7,11 +7,8 @@
import torch
import torch._dynamo as td
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
-from torch_tensorrt.dynamo import CompilationSettings, partitioning
-from torch_tensorrt.dynamo.conversion import (
-    convert_module,
-    repair_long_or_double_inputs,
-)
+from torch_tensorrt.dynamo import CompilationSettings
+from torch_tensorrt.dynamo.compile import compile_module
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
@@ -69,7 +66,7 @@ def _pretraced_backend(
try:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-        trt_compiled = _compile_module(
+        trt_compiled = compile_module(
gm,
sample_inputs,
settings=settings,
@@ -92,112 +89,3 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise


-def _compile_module(
-    gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
-    settings: CompilationSettings = CompilationSettings(),
-) -> torch.fx.GraphModule:
-    """Compile a traced FX module
-    Includes: Partitioning + Conversion Phases
-    Args:
-        module: FX GraphModule to convert
-        inputs: Inputs to the module
-        settings: Compilation settings
-    Returns:
-        Compiled FX GraphModule
-    """
-    # Check the number of supported operations in the graph
-    num_supported_ops, total_ops = partitioning.get_graph_converter_support(
-        gm, settings.debug, settings.torch_executed_ops
-    )
-
-    # If the number of supported operations is 0 or less than the block size, skip the subgraph
-    # TODO: Add condition to second expression below when require_full_compilation is added
-    if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
-        logger.warning(
-            f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
-            f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
-        )
-        return gm
-    else:
-        logger.debug(
-            f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
-        )
-
-    # Partition module into components that can be TRT-accelerated
-    fast_partitioner_failed = False
-
-    # If specified, try using the fast partitioner and fall back to the global one on failure
-    if settings.use_fast_partitioner:
-        try:
-            partitioned_module = partitioning.fast_partition(
-                gm,
-                verbose=settings.debug,
-                min_block_size=settings.min_block_size,
-                torch_executed_ops=settings.torch_executed_ops,
-            )
-        except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
-            logger.error(
-                "Partitioning failed on the subgraph with fast partition. See trace above. "
-                + "Retrying with global partition.",
-                exc_info=True,
-            )
-
-            fast_partitioner_failed = True
-            settings.use_fast_partitioner = False
-
-    if not settings.use_fast_partitioner:
-        partitioned_module = partitioning.global_partition(
-            gm,
-            verbose=settings.debug,
-            min_block_size=settings.min_block_size,
-            torch_executed_ops=settings.torch_executed_ops,
-        )
-
-    # Store TRT replicas of Torch subgraphs
-    trt_modules = {}
-
-    # Iterate over all components that can be accelerated
-    # Generate the corresponding TRT Module for those
-    for name, _ in partitioned_module.named_children():
-        # Criteria for a module to be convertible to TRT
-        if settings.use_fast_partitioner and "_run_on_acc" not in name:
-            continue
-
-        submodule = getattr(partitioned_module, name)
-
-        # Get submodule inputs
-        submodule_inputs = partitioning.get_submod_inputs(
-            partitioned_module, submodule, sample_inputs
-        )
-
-        assert submodule_inputs is not None
-        # Handle long/double inputs if requested by the user
-        if settings.truncate_long_and_double:
-            submodule_inputs = repair_long_or_double_inputs(
-                partitioned_module, submodule, submodule_inputs, name
-            )
-
-        # Create TRT Module from submodule
-        trt_mod = convert_module(
-            submodule,
-            submodule_inputs,
-            settings=settings,
-            name=name,
-        )
-
-        trt_modules[name] = trt_mod
-
-    # Replace all FX Modules with TRT Modules
-    for name, trt_mod in trt_modules.items():
-        setattr(partitioned_module, name, trt_mod)
-
-    # Reset settings object to user specification after fallback to global partitioning mode
-    if fast_partitioner_failed:
-        settings.use_fast_partitioner = True
-
-    return partitioned_module
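The partitioning and conversion pipeline deleted here is not removed from the project: per the import change above, it moves into compile_module in torch_tensorrt.dynamo.compile (presumably the fifth changed file, whose diff is not shown in this view), and _pretraced_backend now calls that relocated function. A rough sketch of the two entry points that should reach it; the model, shapes, and the "torch_tensorrt" backend name are assumptions made for illustration:

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).eval().cuda()
example = torch.randn(8, 32).cuda()

# Ahead-of-time path: torch_tensorrt.compile with ir="dynamo" goes through the dynamo
# front end, which after this refactor presumably delegates to the same compile_module.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[example])

# JIT path: the torch.compile backend registered in backends.py routes through
# _pretraced_backend, which calls compile_module directly (see the hunk above).
# The backend name here is an assumption, not taken from this diff.
compiled = torch.compile(model, backend="torch_tensorrt")
compiled(example)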