chore: fix docs for export (#2447)
Signed-off-by: Dheeraj Peri <[email protected]>
peri044 authored Nov 9, 2023
1 parent da90d61 commit 504b39d
Showing 3 changed files with 25 additions and 29 deletions.
7 changes: 2 additions & 5 deletions docsrc/dynamo/dynamo_export.rst
@@ -1,6 +1,6 @@
 .. _dynamo_export:
 
-Compiling ``ExportedPrograms`` with Torch-TensorRT
+Compiling Exported Programs with Torch-TensorRT
 =============================================
 .. currentmodule:: torch_tensorrt.dynamo

@@ -9,8 +9,6 @@ Compiling ``ExportedPrograms`` with Torch-TensorRT
     :undoc-members:
     :show-inheritance:
 
-Using the Torch-TensorRT Frontend for ``torch.export.ExportedPrograms``
---------------------------------------------------------
 PyTorch 2.1 introduced ``torch.export`` APIs which
 can export graphs from PyTorch programs into ``ExportedProgram`` objects. The Torch-TensorRT dynamo
 frontend compiles these ``ExportedProgram`` objects and optimizes them using TensorRT. Here's a simple
@@ -43,8 +41,7 @@ Some of the frequently used options are as follows:

 The complete list of options can be found `here <https://github.com/pytorch/TensorRT/blob/123a486d6644a5bbeeec33e2f32257349acc0b8f/py/torch_tensorrt/dynamo/compile.py#L51-L77>`_
 
-.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in
-our Torchscript IR. We plan to implement similar support for dynamo in our next release.
+.. note:: We do not support INT precision currently in Dynamo. Support for this currently exists in our Torchscript IR. We plan to implement similar support for dynamo in our next release.
 
 Under the hood
 --------------
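The "simple example" that the hunk above leads into is collapsed in this diff view. As a rough sketch of the workflow that paragraph describes, assuming a placeholder ``MyModel`` module and illustrative option values (``enabled_precisions`` is one of the frequently used options referenced above; none of this is part of the commit):

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # MyModel is a placeholder nn.Module
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    # Compile through the dynamo frontend; the result is a
    # torch.fx.GraphModule with TensorRT-accelerated submodules.
    trt_gm = torch_tensorrt.compile(
        model, ir="dynamo", inputs=inputs, enabled_precisions={torch.float32}
    )
    trt_gm(*inputs)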
45 changes: 22 additions & 23 deletions docsrc/user_guide/saving_models.rst
@@ -29,14 +29,14 @@ The following code illustrates this approach.
     import torch_tensorrt
 
     model = MyModel().eval().cuda()
-    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
-    trt_script_model = torch.jit.trace(trt_gm, inputs)
-    torch.jit.save(trt_script_model, "trt_model.ts")
+    trt_traced_model = torch.jit.trace(trt_gm, inputs)
+    torch.jit.save(trt_traced_model, "trt_model.ts")
 
     # Later, you can load it and run inference
     model = torch.jit.load("trt_model.ts").cuda()
-    model(inputs)
+    model(*inputs)
 
 b) ExportedProgram
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
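Throughout these hunks the example input changes from a bare tensor to a list that is unpacked at call time. A minimal sketch of that calling convention, using a stand-in ``torch.nn.Conv2d`` in place of the compiled module:

.. code-block:: python

    import torch

    # Stand-in for the compiled module; any single-input module behaves the same.
    model = torch.nn.Conv2d(3, 16, kernel_size=3).cuda().eval()
    inputs = [torch.randn((1, 3, 224, 224)).cuda()]

    model(*inputs)  # unpacks the list: equivalent to model(inputs[0])
    # model(inputs) would pass the list itself as a single argument and fail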
@@ -50,40 +50,39 @@ b) ExportedProgram
     import torch_tensorrt
 
     model = MyModel().eval().cuda()
-    inputs = torch.randn((1, 3, 224, 224)).cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
     trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) # Output is a torch.fx.GraphModule
     # Transform and create an exported program
-    trt_gm = torch_tensorrt.dynamo.export(trt_gm, inputs)
-    trt_exp_program = create_exported_program(trt_gm, call_spec, trt_gm.state_dict())
-    torch._export.save(trt_exp_program, "trt_model.ep")
+    trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, inputs)
+    torch.export.save(trt_exp_program, "trt_model.ep")
 
     # Later, you can load it and run inference
-    model = torch._export.load("trt_model.ep")
-    model(inputs)
+    model = torch.export.load("trt_model.ep")
+    model(*inputs)
 
 `torch_tensorrt.dynamo.export` inlines the submodules within a GraphModule into their corresponding nodes and stitches all the nodes together.
 This is needed as `torch._export` serialization cannot handle serializing and deserializing of submodules (`call_module` nodes).
 
-NOTE: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue : https://github.com/pytorch/TensorRT/issues/2341
+.. note:: This way of saving the models using `ExportedProgram` is experimental. Here is a known issue: https://github.com/pytorch/TensorRT/issues/2341


 Torchscript IR
 --------------
 
-In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
-This behavior stays the same in 2.X versions as well.
+In Torch-TensorRT 1.X versions, the primary way to compile and run inference with Torch-TensorRT is using Torchscript IR.
+This behavior stays the same in 2.X versions as well.
 
-.. code-block:: python
+.. code-block:: python
 
-    import torch
-    import torch_tensorrt
+    import torch
+    import torch_tensorrt
 
-    model = MyModel().eval().cuda()
-    inputs = torch.randn((1, 3, 224, 224)).cuda()
-    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object
-    torch.jit.save(trt_ts, "trt_model.ts")
+    model = MyModel().eval().cuda()
+    inputs = [torch.randn((1, 3, 224, 224)).cuda()]
+    trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object
+    torch.jit.save(trt_ts, "trt_model.ts")
 
-    # Later, you can load it and run inference
-    model = torch.jit.load("trt_model.ts").cuda()
-    model(inputs)
+    # Later, you can load it and run inference
+    model = torch.jit.load("trt_model.ts").cuda()
+    model(*inputs)
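To make the inlining note above concrete: in ``torch.fx``, a submodule shows up in a traced graph as a ``call_module`` node, which is what ``torch_tensorrt.dynamo.export`` flattens away before serialization. A small, self-contained illustration, unrelated to TensorRT itself:

.. code-block:: python

    import torch
    import torch.fx

    class Wrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 4)  # leaf submodule

        def forward(self, x):
            return torch.relu(self.linear(x))

    gm = torch.fx.symbolic_trace(Wrapper())
    # `linear` is kept as a call_module node instead of being inlined:
    print([node.op for node in gm.graph.nodes])
    # ['placeholder', 'call_module', 'call_function', 'output']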
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_exporter.py
@@ -56,7 +56,7 @@ def export(
         return exp_program
     else:
         raise ValueError(
-            "Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
+            f"Invalid ir : {ir} provided for serialization. Options include torchscript | exported_program"
         )
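The one-character fix above adds the missing ``f`` prefix: without it, Python emits the ``{ir}`` placeholder literally instead of interpolating the variable. A quick illustration with an arbitrary value:

.. code-block:: python

    ir = "torch_ir"
    print("Invalid ir : {ir}")   # -> Invalid ir : {ir}  (no substitution)
    print(f"Invalid ir : {ir}")  # -> Invalid ir : torch_ir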


