Torch-TensorRT 2.0 #1826
-
Comment on Recompilation + Dynamic Batch:
One thing to add to the caching section is refactoring the Guards/Caching implementation so that differing shapes in a pre-specified (batch) dimension do not cause recompilation. For example, if a user has specified:

```python
import torch
import torch_tensorrt
from torch import nn

class Sample(nn.Module):
    ...

    def forward(self, x):
        if x.sum() > 2:
            return torch.sum(x)
        else:
            return torch.sum(x**2)

model = torch_tensorrt.dynamo.compile(
    Sample(), ... [min_shape=(1,), opt_shape=(4,), max_shape=(8,)], ...
)

input_1 = torch.zeros(1)
input_2 = torch.zeros(4)
input_3 = torch.zeros(8)

# No recompilation should occur
model(input_1)
model(input_2)
model(input_3)
```
-
This seemed very abstract in our conversation with PyT. Do we have a clear way to represent the additional constraints of a TRT engine over Dynamo?
Most customers are very sensitive to device memory consumption, and often host memory too. FSCache is most analogous to how TRT is used today and seems reasonable, but a few thoughts / considerations:
Do we have an answer from Meta on preserving high-level ops? We don't want Torch-TRT to turn into another MHA pattern matcher.
Is there any consideration for how something like CCD would fit in here? Ideally we leverage dynamo to do much of the selection process required in CCD. Additionally, dynamic shapes, sources of dynamo overhead, the export workflow, and potentially additional workflows like extracting or loading TRT engines should be considered.
-
Unifying dynamo export and compile workflows
Both
Partitioning, conversion, etc. should be shared by both workflows.
Issues with Partitioner
The partitioner adds constants as inputs: functional programming to the extreme, no state, just inputs/outputs and a function. For a simple model like the following,
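(a minimal conv + batch-norm sketch; the exact snippet from the prototype is an assumption, but any module holding conv/bn weights exhibits the issue described next)

```python
import torch
from torch import nn

class ConvBN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        return self.bn(self.conv(x))

# dynamo.export keeps batch_norm as a single node here, and the partitioner
# lifts the conv/bn weights into graph placeholders, triggering the failure.
```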
This will fail in the batch norm conversion phase because the partitioner treats constants (the weights of conv/bn) as placeholder tensors which get added to the graph: https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/utils/fuser_utils.py#L142-L147
Question here: should we handle ITensors in the converters, (or) use a transformation pass over partitioning, (or) write a custom partitioner based off ...
Note: this is not a problem with torch.compile (the backend workflow) because there is no batch norm layer in the graph, as batch_norm gets split into add/mul layers, while dynamo.export has the batch_norm layer explicitly.
Issues with TRT splitter
torch.export
torch._dynamo.export doesn't do aot_autograd yet. From Meta's discussion, ...
Naming of APIs
Current prototype:
-
Yeah
Yeah sure. We might need utils or other stuff, so a submodule is good (similar to TorchScript).
Yeah, will do that. That would be great.
So the
This sounds good. Yeah we can express
Agreed. My intention was to pick one (ideally the partitioner). Since the partitioner had problems, I tried TRTSplitter as an option in this prototype to understand what the issue is. I think we should subclass the partitioner to fix the above-mentioned issue, and maybe also in the future to implement some advanced partitioning heuristics. The API naming seems good to me. However, one thing is not clear. Based on your comment, are you saying
-
UX Goals
APIs
torch.compile
torch.compile is a JIT compiler. This should be taken in the sense that the user's workflow is to take a PyTorch Module, wrap it with torch.compile, and have compilation happen on demand when the model is called.
torch.export
torch_tensorrt.compile
Workflows
JIT Optimization
The idea of this workflow is that a user will deploy a boxed version of their model tied to Torch-TensorRT via torch.compile. When users call their model, torch.compile will compile the graph and provide the torch_tensorrt.dynamo.backend with a set of example inputs, settings, and the graph. Conceptually, either dynamo or the backend will recognize when the currently compiled module is invalid due to a change in constraints (typically input size) and can recompile the target or call on a cache to pull up a previously compiled, serialized version.
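A minimal sketch of that boxed workflow (the registered backend name "torch_tensorrt" is an assumption here; the final exposure of torch_tensorrt.dynamo.backend may differ):

```python
import torch
import torch_tensorrt  # registers the Dynamo backend on import

class MyModel(torch.nn.Module):  # stand-in model for illustration
    def forward(self, x):
        return torch.relu(x) + 1

model = MyModel().eval().cuda()

# Boxed deployment: the user ships the torch.compile-wrapped module, and
# compilation happens just-in-time on the first call with real inputs.
compiled = torch.compile(model, backend="torch_tensorrt")  # backend name assumed
out = compiled(torch.randn(4, 8, device="cuda"))  # triggers compile, then runs
```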
Compile and Deploy on the Same Machine
AOT Optimization
Compile and Deploy on Separate Machines
Internals
Backends
ATen and AOTAutograd
https://github.com/pytorch/pytorch/blob/93d75568c7070942a59337dd83194c2fd5221adb/torch/_functorch/aot_autograd.py#L2837
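A rough sketch of how the backend can hook into AOTAutograd via aot_module_simplified from the file linked above (tensorrt_backend and fw_compiler are hypothetical names; the conversion step is elided):

```python
import torch
from torch._functorch.aot_autograd import aot_module_simplified

def tensorrt_backend(gm: torch.fx.GraphModule, example_inputs):
    # AOTAutograd flattens the graph to aten ops (applying any registered
    # decompositions), then hands the result to a forward compiler where
    # partitioning and TRT conversion would happen.
    def fw_compiler(aten_gm: torch.fx.GraphModule, inputs):
        # ... partition, convert supported subgraphs to TRT engines ...
        return aten_gm.forward  # fall back to eager in this sketch

    return aot_module_simplified(gm, example_inputs, fw_compiler=fw_compiler)
```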
Engine Cache
It seems like for torch.compile there is a strong requirement for an engine cache: some sort of store of compiled TensorRT engines tied to an identifying hash calculated from the source graph and provided inputs. This cache should be able to short-circuit the torch-tensorrt backend and deserialize and return the previously created engine.
There are a couple of methods we could think about for maintaining this cache.
Implementation
FSCache
Write engines to disk in some sort of temp directory with a file system convention for locating and matching files.
Advantages
Disadvantages
Additional User Configurations required
MemCache
Hold serialized engines in host memory (i.e. a dictionary) and deserialize on demand.
Advantages
Disadvantages
Additional User Configurations required
HotCache
As an addition to either of these cache options, we can include the ability to have a "Hot Cache", i.e. a number of engines which stay live and deserialized, the cost being additional VRAM and host memory usage.
Options for HotCache Rules
Saving and reloading caches
We need to come up with a format to store and load caches so that if, in a future run, dynamo detects an identical graph, we can load in the model.
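One possible format, sketched under assumptions (SHA-256 of the graph code plus the input signature as the key, one engine file per key in an FSCache-style directory; all names here are hypothetical):

```python
import hashlib
import os
import torch

CACHE_DIR = "/tmp/torch_trt_engine_cache"  # hypothetical, would be configurable

def cache_key(gm: torch.fx.GraphModule, example_inputs) -> str:
    # Identify an engine by the source graph and provided input shapes/dtypes,
    # per the hashing scheme described above.
    sig = str([(tuple(t.shape), str(t.dtype)) for t in example_inputs])
    return hashlib.sha256((gm.code + sig).encode()).hexdigest()

def load_engine(key: str):
    path = os.path.join(CACHE_DIR, key + ".engine")
    if os.path.isfile(path):
        with open(path, "rb") as f:
            return f.read()  # serialized engine bytes; caller deserializes
    return None  # cache miss: compile and store

def store_engine(key: str, serialized_engine: bytes) -> None:
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(os.path.join(CACHE_DIR, key + ".engine"), "wb") as f:
        f.write(serialized_engine)
```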
What causes a cache miss?
There are a number of reasons why a cache might be invalid.
Guards
Guards are the mechanism to detect if a subgraph is different from the target. Some of these guards are provided by dynamo. However, we need to provide guards that are TensorRT-specific. These may include changed weights, changed inputs, etc.
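For instance, a TensorRT-specific weight guard might look like the following hypothetical sketch: a fingerprint recorded at compile time and re-checked before reusing a cached engine.

```python
import hashlib
import torch

def weight_fingerprint(mod: torch.nn.Module) -> str:
    # Hash all parameters; if a weight changes, a cached engine with
    # baked-in weights is invalid and must be rebuilt (or refitted).
    h = hashlib.sha256()
    for p in mod.parameters():
        h.update(p.detach().cpu().contiguous().numpy().tobytes())
    return h.hexdigest()

def weights_guard(mod: torch.nn.Module, recorded: str) -> bool:
    return weight_fingerprint(mod) == recorded  # False triggers recompile
```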
Lowering
There are three classes of lowering in the dynamo backend: Decompositions, Subgraph Rewriting, and Module Level Lowering.
Decompositions
Decompositions are small functions which map an operator to a lowered form, similar to unpack passes in the TorchScript frontend, and serve to handle custom cases or reduce the opset that the converters need to handle.
Decompositions are run as part of the functorch.aot_autograd step.
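For example, a decomposition can be registered so the conversion phase never sees the composite op (the choice of aten.hardswish here is illustrative, not from the prototype):

```python
import torch
from torch._decomp import register_decomposition

# Map aten.hardswish onto ops the converters already handle (mul, add,
# clamp, div), shrinking the opset the conversion phase must support.
@register_decomposition(torch.ops.aten.hardswish)
def hardswish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.clamp(x + 3, 0, 6) / 6
```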
Subgraph Rewriting
Subgraph rewriting takes small repeating patterns and replaces them with one or many operations. See the support for Linear/AddMM in the TS frontend for an example of what this looks like.
Subgraph rewriting will be run post aot_autograd
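A hedged sketch with torch.fx.subgraph_rewriter; the matmul + add to addmm rewrite mirrors the Linear/AddMM case mentioned above (TinyLinear is an illustrative stand-in):

```python
import torch
import torch.fx
from torch.fx import subgraph_rewriter

class TinyLinear(torch.nn.Module):  # illustrative module, an assumption
    def forward(self, x, w, b):
        return torch.add(torch.matmul(x, w), b)

def pattern(x, w, b):
    return torch.add(torch.matmul(x, w), b)

def replacement(x, w, b):
    # A single fused op is simpler for converters than matmul + add.
    return torch.addmm(b, x, w)

gm = torch.fx.symbolic_trace(TinyLinear())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
```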
Module Level Lowering
Module level passes identify submodules in graphs to perform aliasing or other high level operations.
Module level passes will run pre aot autograd
Ex.
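As a hypothetical illustration (this concrete pass is an assumption): a pre-tracing pass could swap nn.MultiheadAttention submodules for a TRT-friendly equivalent, so they are lowered as a unit instead of being pattern-matched after decomposition.

```python
import torch

class TRTFriendlyMHA(torch.nn.Module):
    # Hypothetical stand-in that carries the original weights and would map
    # cleanly onto a single fused attention op at conversion time.
    def __init__(self, mha: torch.nn.MultiheadAttention):
        super().__init__()
        self.mha = mha

    def forward(self, q, k, v):
        return self.mha(q, k, v)

def lower_mha_modules(model: torch.nn.Module) -> torch.nn.Module:
    # Module-level pass: runs pre-aot_autograd, rewriting submodules in place.
    for name, child in model.named_children():
        if isinstance(child, torch.nn.MultiheadAttention):
            setattr(model, name, TRTFriendlyMHA(child))
        else:
            lower_mha_modules(child)
    return model
```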
Partitioning
Dynamo has a builtin capability partitioner that we have a prototype for:
This uses the default CapabilityBasedPartitioner. We would likely need to modify this to add support for features like min_block_size, torch_supported_ops, and torch_support_modules.
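A sketch of what subclassing could look like (the CONVERTERS registry and the min_block_size filtering shown here are assumptions for illustration):

```python
import torch
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport

# Hypothetical converter registry: targets we have TRT converters for.
CONVERTERS = {torch.ops.aten.relu.default, torch.ops.aten.add.Tensor}

class TRTOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # A node is TRT-supported if a converter exists for its target.
        return node.op == "call_function" and node.target in CONVERTERS

class TRTPartitioner(CapabilityBasedPartitioner):
    def __init__(self, gm: torch.fx.GraphModule, min_block_size: int = 3):
        super().__init__(gm, TRTOperatorSupport(), allows_single_node_partition=True)
        self.min_block_size = min_block_size

    def propose_partitions(self):
        # Drop partitions smaller than min_block_size so tiny TRT engines,
        # whose launch overhead outweighs the benefit, stay in PyTorch.
        parts = super().propose_partitions()
        return [p for p in parts if len(p.nodes) >= self.min_block_size]
```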
Overview
In short, the three levels of Lowering work in conjunction with partitioning. First, we have the pre-tracing Module Level Lowering for high-level modules. Then, we have during-tracing Decompositions for fusions, in-place ops, and operator simplification. Finally, we have post-tracing subgraph rewriting, which can also assist with fusions as well as other graph simplifications. This final, post-tracing pass can help reduce segmentation in partitioning, since we can use it to replace operations with their prims equivalents, which are lower-level and easier to implement converters for.
Conversion
FX Interpreter
https://github.com/pytorch/pytorch/blob/master/torch/fx/interpreter.py
Evaluation
There are some constants which we need at compile time that are produced by intermediate operations. The FX interpreter can execute these and store that data somewhere for converters to use.
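A hedged sketch using torch.fx.Interpreter (the storage convention, keyed by node name, is an assumption):

```python
import torch
import torch.fx

class ConstantEvaluator(torch.fx.Interpreter):
    # Execute the graph eagerly and record every intermediate result so
    # converters can look up compile-time constants (e.g. shapes) by node name.
    def __init__(self, gm: torch.fx.GraphModule):
        super().__init__(gm)
        self.values = {}

    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)
        self.values[n.name] = result
        return result

# Usage: evaluator = ConstantEvaluator(gm); evaluator.run(*example_inputs)
# then converters read evaluator.values["some_node"] during conversion.
```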
Converters
Consolidate the TRT Network, IR, etc. into some context object to pass around.
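For example, a context object along these lines (field names are assumptions, not the final API):

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import torch.fx

@dataclass
class ConversionContext:
    # Single object threaded through every converter call instead of
    # passing the network, value map, and settings separately.
    network: Any                                                       # TRT INetworkDefinition under construction
    value_map: Dict[torch.fx.Node, Any] = field(default_factory=dict)  # fx.Node -> ITensor/weight
    settings: Dict[str, Any] = field(default_factory=dict)             # compilation settings
```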
Dynamo Symbolic Shapes + Shape Tensors
Dynamic shape cases
Sym Shape / Shape Tensor Interop
Runtime / Callable
Torch-TensorRT Legacy Runtime
Inductor has this function: https://github.com/pytorch/pytorch/blob/d5aa4cec578f40afd43cc0f96ba0d0abaf38b1f4/torch/_inductor/compile_fx.py#L514
Python Runtime
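A minimal sketch of such a Python runtime, assuming a serialized engine, CUDA float32 tensors, the TensorRT 8.x binding API, and that input bindings precede outputs in binding order:

```python
import tensorrt as trt
import torch

class PythonTRTRuntime(torch.nn.Module):
    def __init__(self, serialized_engine: bytes):
        super().__init__()
        runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        self.engine = runtime.deserialize_cuda_engine(serialized_engine)
        self.context = self.engine.create_execution_context()

    def forward(self, *inputs: torch.Tensor):
        bindings = [0] * self.engine.num_bindings
        outputs, in_idx = [], 0
        for i in range(self.engine.num_bindings):
            if self.engine.binding_is_input(i):
                # Propagate the runtime shape for dynamic-shape engines.
                self.context.set_binding_shape(i, tuple(inputs[in_idx].shape))
                bindings[i] = inputs[in_idx].data_ptr()
                in_idx += 1
            else:
                shape = tuple(self.context.get_binding_shape(i))
                out = torch.empty(shape, dtype=torch.float32, device="cuda")
                outputs.append(out)
                bindings[i] = out.data_ptr()
        # Enqueue on the current torch CUDA stream to interop with PyTorch.
        self.context.execute_async_v2(
            bindings, torch.cuda.current_stream().cuda_stream
        )
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
```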