🐛 [Bug] RuntimeError When Converting Model with Int64 Buffers in TensorRT 10 #3354
Comments
Since you are using |
I've tried using the Dynamo backend, and while it works, I can't find a way to pass a dictionary to the model. If I pass it as arg_inputs, the dictionary gets converted into a list of its keys. For example, {x: tensor(...), y: tensor(...)} becomes ['x', 'y'], which looks like a bug to me, though I'm not entirely sure. This behavior can be traced to _compiler.py (version 2.5.0), line 622:
Here, prepare_inputs returns the original dictionary, but list() converts it into a list of its keys. On the other hand, if I pass the dictionary as kwarg_inputs, it expects the forward signature to match the dictionary's fields, as demonstrated in test_export_kwargs_serde.py. I would appreciate clarification on whether passing a dictionary to the model for conversion via trt.dynamo.convert_exported_program_to_serialized_trt_engine is supported. And if passing a dictionary isn't supported, how can I assign the keys as names to the input tensors? Assigning names like "image" or "grid" instead of the defaults "input_1", "input_2", etc. would make deployment much easier by letting inputs be matched to their intended names. I would appreciate any guidance on this.
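The key-flattening described above is plain Python semantics rather than anything TensorRT-specific: iterating a dict (which is what list() does) yields its keys, so the tensor values are silently dropped. A minimal illustration with stand-in values (the names here are hypothetical):

```python
# list() iterates a dict's keys, not its values -- so a dict passed
# where a list of inputs is expected turns into a list of key names.
inputs = {"x": "tensor_x", "y": "tensor_y"}  # stand-ins for real tensors
as_list = list(inputs)
print(as_list)  # ['x', 'y'] -- the values are gone
```

This is why `list(prepare_inputs(...))` produces `['x', 'y']` when given a dict: the call would need `list(inputs.values())` (or dedicated dict handling) to preserve the tensors.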
if you are using |
OK, and in the case of trt.dynamo.convert_exported_program_to_serialized_trt_engine, how can I assign names to the tensors? When using ONNX conversion, the names are assigned.
Input names are set via the target names of the input nodes in the FX graph; this comes from dynamo.
Thanks for your responses! However, I'm still unclear about how to set names for inputs and outputs. In the following code example, the input names are correctly derived from the forward signature, but the output remains "output0". Could you explain how, in this flow, I can set the output name to something like "z"? When I attempt to assign "z" as the output name here:
I get an error (Given invalid tensor name: z) because the output tensor name defaults to "output0". Here's the complete code for reference:
I'd greatly appreciate your help in understanding how to correctly name the output tensor. Thank you!
We don't support setting custom names through any method other than naming the parameters of the forward function; we just inherit those names directly from dynamo.
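Since the names are inherited from the forward signature, the practical workaround is to name forward's parameters the way you want the engine inputs named. A minimal sketch (the class and parameter names are hypothetical) showing where those names live; in a real nn.Module, these are the names dynamo records for the input nodes:

```python
import inspect

class Model:
    # In an nn.Module, the parameter names of forward ("image", "grid")
    # are what dynamo uses as input node targets, and hence what the
    # exported TensorRT engine inherits as input names.
    def forward(self, image, grid):
        return image

# Read the parameter names straight off the signature, skipping 'self'.
params = list(inspect.signature(Model.forward).parameters)[1:]
print(params)  # ['image', 'grid']
```

Output names, as noted above, cannot be customized this way: they default to "output0", "output1", etc.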
Bug Description
I am experiencing an issue when converting a model to a TensorRT engine from TorchScript (ts). The error occurs during the conversion process, specifically when I attempt to compile the model using torch_tensorrt.ts.convert_method_to_trt_engine. The error message is as follows:
The script I'm using:
torch_tensorrt.compile passes with the following warnings, in case that helps:
When I search for int64 in the model, the only int64 tensors I can find are the num_batches_tracked buffers in BatchNorm layers, and even after explicitly converting them I still get the error.
The error occurs even if I explicitly specify the inputs:
Additional Information:
- The model contains num_batches_tracked buffers in several BatchNorm layers, which are of type torch.int64.
- I confirmed that I converted the relevant tensors (buffers and parameters) from torch.int64 to torch.int32 before tracing the model with torch.jit.trace.
- The issue occurs even though TensorRT 10 supports Int64 types and truncate_long_and_double=True is set.
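The conversion attempt described above can be sketched as follows. This is a minimal sketch using a toy Sequential model; cast_int64_buffers_to_int32 is a hypothetical helper name, not a Torch-TensorRT API. Note that casting num_batches_tracked to int32 is generally harmless for inference, but would interfere with BatchNorm's batch counting if the model were trained afterwards.

```python
import torch
import torch.nn as nn

def cast_int64_buffers_to_int32(model: nn.Module) -> nn.Module:
    """Cast every int64 buffer (e.g. BatchNorm's num_batches_tracked)
    to int32, as attempted before torch.jit.trace in the report above."""
    for module in model.modules():
        # Snapshot the buffer list before re-registering entries.
        for name, buf in list(module.named_buffers(recurse=False)):
            if buf.dtype == torch.int64:
                module.register_buffer(name, buf.to(torch.int32))
    return model

# Toy model standing in for the real (unshareable) network.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
model = cast_int64_buffers_to_int32(model)
dtypes = {b.dtype for b in model.buffers()}
print(torch.int64 in dtypes)  # False after the cast
```

Even with this cast applied before tracing, the report above states that the int64-related error persists in the TorchScript path.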
To Reproduce
Unfortunately, I can't share the model due to company restrictions.
Expected behavior
The model should successfully convert to a TensorRT engine without errors related to Int64 types, and the conversion should work with Int32 or Float32 inputs.
Environment
How you installed PyTorch (conda, pip, libtorch, source): pip