Bug Description
I have compiled a model and it works as expected: it produces results that are very close to those of the original PyTorch model.
However, when I try to save the compiled model, it fails with several exceptions that suggest the input tensors cannot be serialized. I get the following output:
<...>compile logs<...>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: ((mean, sym_size_int_81)) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:05.828617
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 1 inputs and 2 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 202960
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 125829120
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 145 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 1.44263ms to assign 4 blocks to 145 nodes requiring 163577856 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 163577856
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 1686827264
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 65.3723 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 1588 MiB, GPU 1609 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 11880 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:01:06.079919
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 1699159860 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 7756 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 3043944 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 947 timing cache entries
INFO:torch_tensorrt [TensorRT Conversion Context]:The logger passed into createInferBuilder differs from one already provided for an existing builder, runtime, or refitter. Uses of the global logger, returned by nvinfer1::getLogger(), will return the existing value.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 2519, GPU 9453 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2086, GPU +382, now: CPU 4605, GPU 9835 (MiB)
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.global_pool/as_strided [as_strided] (Inputs: () | Outputs: (as_strided: (s0, 3072, 1, 1)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.norm/permute_168 [aten.permute.default] (Inputs: (as_strided: (s0, 3072, 1, 1)@torch.float32, [0, 2, 3, 1]) | Outputs: (permute_168: (s0, 1, 1, 3072)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk_head_norm_weight [trunk.head.norm.weight] (Inputs: () | Outputs: (trunk_head_norm_weight: (3072,)@float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk_head_norm_bias [trunk.head.norm.bias] (Inputs: () | Outputs: (trunk_head_norm_bias: (3072,)@float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.norm/native_layer_norm_44 [aten.native_layer_norm.default] (Inputs: (permute_168: (s0, 1, 1, 3072)@torch.float32, [3072], trunk_head_norm_weight: (3072,)@float32, trunk_head_norm_bias: (3072,)@float32, 1e-05) | Outputs: (native_layer_norm_44: ((s0, 1, 1, 3072)@torch.float32, (s0, 1, 1, 1)@torch.float32, (s0, 1, 1, 1)@torch.float32)))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.norm/getitem_132 [<built-in function getitem>] (Inputs: (native_layer_norm_44: ((s0, 1, 1, 3072)@torch.float32, (s0, 1, 1, 1)@torch.float32, (s0, 1, 1, 1)@torch.float32), 0) | Outputs: (getitem_132: (s0, 1, 1, 3072)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.norm/permute_169 [aten.permute.default] (Inputs: (getitem_132: (s0, 1, 1, 3072)@torch.float32, [0, 3, 1, 2]) | Outputs: (permute_169: (s0, 3072, 1, 1)@torch.float32))
WARNING:torch_tensorrt.dynamo.conversion.converter_utils:Detected unparsable type in node formatting: <class 'torch.SymInt'>
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node sym_size_int_81 [sym_size_int_81] (Inputs: () | Outputs: (sym_size_int_81: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.flatten/reshape_default_160 [aten.reshape.default] (Inputs: (permute_169: (s0, 3072, 1, 1)@torch.float32, [sym_size_int_81, 3072]) | Outputs: (reshape_default_160: (s0, 3072)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node trunk.head.drop/clone_80 [aten.clone.default] (Inputs: (reshape_default_160: (s0, 3072)@torch.float32) | Outputs: (clone_80: (s0, 3072)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node head.drop/clone_81 [aten.clone.default] (Inputs: (clone_80: (s0, 3072)@torch.float32) | Outputs: (clone_81: (s0, 3072)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node _frozen_param120 [_frozen_param120] (Inputs: () | Outputs: (_frozen_param120: (3072, 1024)@float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node head.proj/mm [aten.mm.default] (Inputs: (clone_81: (s0, 3072)@torch.float32, _frozen_param120: (3072, 1024)@float32) | Outputs: (mm: (s0, 1024)@torch.float32))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (mm: (s0, 1024)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.010138
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt [TensorRT Conversion Context]:Global timing cache in use. Profiling results in this builder pass will be stored.
INFO:torch_tensorrt [TensorRT Conversion Context]:Detected 2 inputs and 1 output network tensors.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Host Persistent Memory: 32
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Device Persistent Memory: 0
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Scratch Memory: 24576
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Started assigning block shifts. This will take 2 steps to complete.
INFO:torch_tensorrt [TensorRT Conversion Context]:[BlockAssignment] Algorithm ShiftNTopDown took 0.015751ms to assign 2 blocks to 2 nodes requiring 49152 bytes.
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Activation Memory: 49152
INFO:torch_tensorrt [TensorRT Conversion Context]:Total Weights Memory: 6303744
INFO:torch_tensorrt [TensorRT Conversion Context]:Engine generation completed in 3.72328 seconds.
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage of TRT CPU/GPU memory allocators: CPU 1588 MiB, GPU 1788 MiB
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageStats] Peak memory usage during Engine building and serialization: CPU: 12529 MiB
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:03.727700
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 6447340 bytes of Memory
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 8637 bytes of code generator cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 3164135 bytes of compilation cache.
INFO:torch_tensorrt [TensorRT Conversion Context]:Serialized 950 timing cache entries
WARNING:py.warnings:/proj/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:387: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
engine_node = gm.graph.get_attr(engine_name)
WARNING:py.warnings:/proj/.venv/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
WARNING:py.warnings:/proj/.venv/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_2_engine target _run_on_acc_2_engine _run_on_acc_2_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
Traceback (most recent call last):
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1314, in serialize_graph
getattr(self, f"handle_{node.op}")(node)
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 521, in handle_call_function
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 663, in serialize_inputs
arg=self.serialize_input(args[i], schema_arg.type),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 893, in serialize_input
raise SerializeError(
torch._export.serde.serialize.SerializeError: Unsupported list/tuple argument type: [<class 'torch.fx.node.Node'>, <class 'int'>, <class 'int'>, <class 'int'>]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/proj/test_algos/trt/torch_tensorrt_compile_save.py", line 33, in <module>
torch_tensorrt.save(model, "saved_model.ep", inputs=[test_batch])
File "/proj/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 529, in save
torch.export.save(exp_program, file_path)
File "/proj/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 341, in save
artifact: SerializedArtifact = serialize(ep, opset_version)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 2374, in serialize
serialized_program = ExportedProgramSerializer(opset_version).serialize(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1373, in serialize
serialized_graph_module = gm_serializer.serialize(exported_program.graph_module)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1344, in serialize
graph = self.serialize_graph(graph_module)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1316, in serialize_graph
raise SerializeError(
torch._export.serde.serialize.SerializeError: Failed serializing node as_strided in graph: %as_strided : [num_users=1] = call_function[target=torch.ops.aten.as_strided.default](args = (%getitem, [%getitem_1, 3072, 1, 1], [3072, 1, 3072, 3072]), kwargs = {})
Original exception Traceback (most recent call last):
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1314, in serialize_graph
getattr(self, f"handle_{node.op}")(node)
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 521, in handle_call_function
inputs=self.serialize_inputs(node.target, node.args, node.kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 663, in serialize_inputs
arg=self.serialize_input(args[i], schema_arg.type),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/proj/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 893, in serialize_input
raise SerializeError(
torch._export.serde.serialize.SerializeError: Unsupported list/tuple argument type: [<class 'torch.fx.node.Node'>, <class 'int'>, <class 'int'>, <class 'int'>]

To Reproduce
Steps to reproduce the behavior:
- Use Python 3.11
- pip install torch==2.5.1 torch_tensorrt==2.5.0 open-clip-torch==2.29.0
- Run the snippet below:
```python
import logging

import torch
import torch_tensorrt
import open_clip

logging.basicConfig(level=logging.INFO)
torch.set_float32_matmul_precision('high')
# Default is 2, which fails to save such large models (4.8GB),
# see https://github.com/pytorch/TensorRT/issues/3294
torch.serialization.DEFAULT_PROTOCOL = 4

device = "cuda:0"

model, _, preprocess = open_clip.create_model_and_transforms(
    "convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup"
)
model = model.visual.to(device).eval()

image_size = model.image_size
input_shape = (1, 3, *image_size)
test_batch = torch.randn(*input_shape, dtype=torch.float32).to(device)

trt_input = torch_tensorrt.Input(
    min_shape=input_shape,
    opt_shape=(4, 3, *image_size),
    max_shape=(4, 3, *image_size),
    dtype=torch.float32,
    name="input",
)
model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[trt_input],
    enabled_precisions={torch_tensorrt.dtype.half},
)
model(test_batch)

torch_tensorrt.save(model, "saved_model.ep", inputs=[test_batch])
```

Expected behavior
I expected to see an optimized model saved at saved_model.ep, so that I do not have to recompile it on every run of the program.
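A possible fallback, so that recompilation is at least avoided between runs, would be to serialize via the TorchScript output format instead of an ExportedProgram. This is only a sketch and unverified for this model; it assumes torch_tensorrt.save() in 2.5.0 accepts output_format="torchscript" and that the traced module reloads with torch.jit.load:

```python
# Unverified workaround sketch: serialize through TorchScript instead of
# torch.export, so the SymInt-bearing as_strided arguments never reach
# torch/_export/serde/serialize.py. Assumes torch_tensorrt.save() in
# 2.5.0 supports output_format="torchscript".
torch_tensorrt.save(
    model,
    "saved_model.ts",
    output_format="torchscript",
    inputs=[test_batch],
)

# Later runs reload the compiled engines without recompiling. Note that
# torch_tensorrt must still be imported so the TensorRT runtime ops are
# registered before loading.
import torch_tensorrt  # noqa: F401
loaded = torch.jit.load("saved_model.ts").to(device)
loaded(test_batch)
```

The trade-off is that the saved artifact would be a traced TorchScript module rather than an ExportedProgram.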
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.5.0
- PyTorch Version (e.g. 1.0): 2.5.1
- CPU Architecture: x86
- OS (e.g., Linux): Linux Ubuntu 22.04
- How you installed PyTorch (conda, pip, libtorch, source): uv add torch
- Build command you used (if compiling from source): n/a
- Are you using local sources or building from archives: n/a
- Python version: 3.11
- CUDA version: 12.6
- GPU models and configuration: RTX 3060
- Any other relevant information:
Additional context
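The traceback seems to point at the dynamic batch dimension: the size argument of the failing aten.as_strided node is a Python list mixing an fx node (apparently the symbolic batch size s0 / sym_size_int_81) with plain ints, which torch/_export/serde/serialize.py refuses to serialize. In case it helps triage, below is a much smaller module exercising the same shape pattern (a dynamic batch dimension flowing through a global pool into a flatten), without downloading the 4.8GB checkpoint. This is a sketch only; I have not confirmed that it lowers to the same aten.as_strided call or triggers the same SerializeError:

```python
import torch
import torch_tensorrt

# Hypothetical minimal-repro sketch (unverified): a tiny module with the
# same head pattern as the convnext model above (global average pool,
# then a flatten whose output shape depends on the symbolic batch size).
class TinyHead(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.proj = torch.nn.Linear(8, 4)

    def forward(self, x):
        x = self.pool(x)                # (batch, 8, 1, 1); batch is dynamic
        return self.proj(x.flatten(1))  # (batch, 4)

device = "cuda:0"
model = TinyHead().to(device).eval()

trt_input = torch_tensorrt.Input(
    min_shape=(1, 8, 16, 16),
    opt_shape=(4, 8, 16, 16),
    max_shape=(4, 8, 16, 16),
    dtype=torch.float32,
    name="input",
)
compiled = torch_tensorrt.compile(model, ir="dynamo", inputs=[trt_input])

test = torch.randn(1, 8, 16, 16, device=device)
compiled(test)
# Unverified: if the pattern matches, this should raise the same
# SerializeError as above.
torch_tensorrt.save(compiled, "tiny.ep", inputs=[test])
```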