Commit d5246f9

chore: updates
1 parent c4f8945 commit d5246f9

4 files changed: +50 -32 lines changed

docsrc/index.rst

+2
@@ -118,6 +118,8 @@ Tutorials
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
    tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
+   tutorials/_rendered_examples/dynamo/torch_export_gpt2
+   tutorials/_rendered_examples/dynamo/torch_export_llama2
 
 Python API Documentation
 ------------------------

examples/dynamo/README.rst

+24 -7
@@ -1,19 +1,36 @@
 .. _torch_compile:
 
-Dynamo / ``torch.compile``
-----------------------------
+Torch-TensorRT Examples
+====================================
 
-Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
-a number of ways you can leverage this backend to accelerate inference.
+Please refer to the following examples which demonstrate the usage of different features of Torch-TensorRT. We also provide
+examples of Torch-TensorRT compilation of select computer vision and language models.
 
-* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
-* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+Dependencies
+------------------------------------
+
+Please install the following external depencies (assuming you already have `torch_tensorrt` installed)
+
+.. code-block:: python
+
+    pip install -r requirements.txt
+
+
+Compiler Features
+------------------------------------
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
-* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
 * :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
 * :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
 * :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT
+
+Model Zoo
+------------------------------------
+* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
+* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
+* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
+* :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)

examples/dynamo/requirements.txt

+2 -2
@@ -1,4 +1,4 @@
 cupy==13.1.0
-torch>=2.4.0.dev20240503+cu121
-torch-tensorrt>=2.4.0.dev20240503+cu121
 triton==2.3.0
+diffusers==0.30.3
+transformers==4.44.2

examples/dynamo/torch_compile_gpt2.py

+22 -23
@@ -53,30 +53,29 @@
 # Compilation with `torch.compile` using tensorrt backend and generate TensorRT outputs
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-with torch_tensorrt.logging.debug():
-    # Compile the model and mark the input sequence length to be dynamic
-    torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
-    model.forward = torch.compile(
-        model.forward,
-        backend="tensorrt",
-        dynamic=None,
-        options={
-            "enabled_precisions": {torch.float32},
-            "disable_tf32": True,
-            "min_block_size": 1,
-            "debug": True,
-        },
-    )
+# Compile the model and mark the input sequence length to be dynamic
+torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
+model.forward = torch.compile(
+    model.forward,
+    backend="tensorrt",
+    dynamic=None,
+    options={
+        "enabled_precisions": {torch.float32},
+        "disable_tf32": True,
+        "min_block_size": 1,
+        "debug": True,
+    },
+)
 
-    # Auto-regressive generation loop for greedy decoding using TensorRT model
-    # The first token generation compiles the model using TensorRT and the second token
-    # encounters recompilation
-    trt_gen_tokens = model.generate(
-        inputs=input_ids,
-        max_length=MAX_TOKENS,
-        use_cache=False,
-        pad_token_id=tokenizer.eos_token_id,
-    )
+# Auto-regressive generation loop for greedy decoding using TensorRT model
+# The first token generation compiles the model using TensorRT and the second token
+# encounters recompilation
+trt_gen_tokens = model.generate(
+    inputs=input_ids,
+    max_length=MAX_TOKENS,
+    use_cache=False,
+    pad_token_id=tokenizer.eos_token_id,
+)
 
 # %%
 # Decode the output sentences of PyTorch and TensorRT
