Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Model Zoo
* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)
* :ref:`_torch_export_flux_dev`: Compiling FLUX.1-dev model using AOT workflow (`ir=dynamo`)
* :ref:`debugger_example`: Debugging Torch-TensorRT Compilation
75 changes: 75 additions & 0 deletions examples/dynamo/debugger_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
.. _debugger_example:

Debugging Torch-TensorRT Compilation
===================================================================

TensorRT conversion can perform many graph transformations and backend specific
optimizations that are sometimes hard to inspect. Torch-TensorRT provides a
Debugger utility to help visualize FX graphs around lowering passes, monitor
engine building, and capture profiling or TensorRT API traces.

In this example, we demonstrate how to:

1. Enable the Torch-TensorRT Debugger context
2. Capture and visualize FX graphs before and/or after specific lowering passes
3. Configure logging directory and verbosity
"""

import os
import tempfile

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Directory (under the system temp dir) handed to the Debugger as its
# logging directory.
temp_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")

# Seed the RNGs so the random example input is reproducible across runs.
np.random.seed(0)
torch.manual_seed(0)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]


# `weights=None` builds a randomly initialized ResNet-18; it replaces the
# deprecated `pretrained=False` flag.
model = models.resnet18(weights=None).to("cuda").eval()
exp_program = torch.export.export(model, tuple(inputs))

# Compilation settings forwarded to `torch_trt.dynamo.compile` below.
enabled_precisions = {torch.float}
min_block_size = 0
use_python_runtime = False
torch_executed_ops = set()  # NOTE: `{}` would create an empty dict, not a set

# Everything inside this context runs with the Debugger settings applied.
with torch_trt.dynamo.Debugger(
    log_level="debug",
    logging_dir=temp_dir,
    engine_builder_monitor=False,  # whether to monitor the engine building process
    capture_fx_graph_after=[
        "complex_graph_detection"
    ],  # fx graph visualization after certain lowering pass
    capture_fx_graph_before=[
        "remove_detach"
    ],  # fx graph visualization before certain lowering pass
):

    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        immutable_weights=False,
        reuse_cached_engines=False,
    )

    # Run the compiled module once on the example input.
    # NOTE(review): kept inside the Debugger context; confirm this matches the
    # intended scope of the debug logging.
    trt_output = trt_gm(*inputs)


"""
The logging directory (``torch_tensorrt_debugger_example`` under the system
temp directory, e.g. ``/tmp`` on Linux) will contain the following files:

- /tmp/torch_tensorrt_debugger_example/
    - torch_tensorrt_logging.log
    - lowering_passes_visualization/
        - after_complex_graph_detection.svg
        - before_remove_detach.svg
"""
Loading