
Commit cee4914

Implemented basic Mutable torch tensorrt module pipeline (#2981)
Signed-off-by: Naren Dasan <[email protected]>
Co-authored-by: Naren Dasan <[email protected]>
1 parent d7f68c4 commit cee4914

File tree

8 files changed: +1247 −2 lines changed


docsrc/py_api/torch_tensorrt.rst (+3)

@@ -32,6 +32,9 @@ Functions
 
 Classes
 ---------
+.. autoclass:: MutableTorchTensorRTModule
+   :members:
+   :special-members: __init__
 
 .. autoclass:: Input
    :members:

examples/dynamo/README.rst (+1)

@@ -13,4 +13,5 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
+* :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``

examples/dynamo/mutable_torchtrt_module_example.py (new file, +115)

@@ -0,0 +1,115 @@
"""
.. _mutable_torchtrt_module_example:

Mutable Torch TensorRT Module
===================================================================

We demonstrate how to use the Mutable Torch TensorRT Module to compile a model, interact with it, and modify the resulting TensorRT Graph Module.

Compiling a Torch-TensorRT module is straightforward, but modifying the compiled module can be challenging, especially when it comes to maintaining the state and connection between the PyTorch module and the corresponding Torch-TensorRT module.
In Ahead-of-Time (AoT) scenarios, integrating Torch-TensorRT with complex pipelines, such as the Hugging Face Stable Diffusion pipeline, becomes even more difficult.
The Mutable Torch TensorRT Module is designed to address these challenges, making interaction with the Torch-TensorRT module easier than ever.

In this tutorial, we walk through:

1. A sample workflow of the Mutable Torch TensorRT Module with ResNet 18
2. Saving a Mutable Torch TensorRT Module
3. Integration with a Hugging Face pipeline in a LoRA use case
"""

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(5)
torch.manual_seed(5)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# %%
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
settings = {
    "use_python_runtime": False,
    "enabled_precisions": {torch.float32},
    "make_refitable": True,
}

model = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original PyTorch module.
# Compilation happens the first time you call the mutable module.
mutable_module(*inputs)
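
# %%
# An illustrative sanity check (an addition for this walkthrough, assuming the
# model returns a single tensor): after the first-call compilation, the mutable
# module should match the eager PyTorch model within tolerance.
with torch.no_grad():
    eager_output = model(*inputs)
    trt_output = mutable_module(*inputs)
assert torch.allclose(eager_output, trt_output, rtol=1e-2, atol=1e-2)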

# %%
# Make modifications to the mutable module.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# %%
# Making changes to the mutable module can trigger refitting or re-compilation.
# For example, loading a different state_dict and setting new weight values
# triggers refitting, while adding a module to the model triggers re-compilation.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())

# Check the output.
# The refit happens when you call the mutable module again.
expected_outputs, refitted_outputs = model2(*inputs), mutable_module(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
    assert torch.allclose(
        expected_output, refitted_output, rtol=1e-2, atol=1e-2
    ), "Refit result is not correct. Refit failed."

print("Refit succeeded!")
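
# %%
# The steps above exercise only the refit path. A hypothetical sketch of the
# re-compilation path (assumes attribute assignment on the mutable module is
# proxied to the wrapped model; the replacement layer is illustrative, not part
# of the original example):
mutable_module.fc = torch.nn.Linear(512, 10).to("cuda")  # structural change
recompiled_outputs = mutable_module(*inputs)  # re-compilation expected here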

# %%
# Saving Mutable Torch TensorRT Module
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Currently, saving is only enabled for the C++ runtime, not the Python runtime.
torch_trt.MutableTorchTensorRTModule.save(mutable_module, "mutable_module.pkl")
reload = torch_trt.MutableTorchTensorRTModule.load("mutable_module.pkl")
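
# A hypothetical sanity check (assumes the loaded module is immediately callable
# for inference): the reloaded module should reproduce the outputs of the module
# it was saved from.
assert torch.allclose(
    mutable_module(*inputs), reload(*inputs), rtol=1e-2, atol=1e-2
)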

# %%
# Stable Diffusion with Hugging Face
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# The LoRA checkpoint is from https://civitai.com/models/12597/moxin

from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refitable": True,
    }

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Hugging Face LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="moxin.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit is triggered on the next call, since fusing the LoRA weights
    # changed the UNet's parameters in place.
    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./with_LoRA_mutable.jpg")

py/torch_tensorrt/__init__.py (+3)

@@ -134,3 +134,6 @@ def _register_with_torch() -> None:
 from torch_tensorrt import dynamo  # noqa: F401
 
 from torch_tensorrt._compile import *  # noqa: F403
+from torch_tensorrt.dynamo.runtime._MutableTorchTensorRTModule import (
+    MutableTorchTensorRTModule,
+)
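
With this re-export, user code can import the class from the package root rather than from the private runtime path. A minimal sketch (assumes a CUDA-enabled Torch-TensorRT install; the toy model is illustrative):

import torch
import torch_tensorrt

# Any nn.Module on the GPU can be wrapped; compilation happens on the first call.
model = torch.nn.Sequential(torch.nn.Linear(8, 4)).eval().to("cuda")
trt_module = torch_tensorrt.MutableTorchTensorRTModule(
    model, enabled_precisions={torch.float32}, make_refitable=True
)
out = trt_module(torch.rand(2, 8, device="cuda"))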

py/torch_tensorrt/_compile.py (+1 −1)

@@ -534,4 +534,4 @@ def save(
     exp_program = torch.export.export(
         module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
     )
-    torch.export.save(exp_program, file_path)
+    torch.export.save(exp_program, file_path)
