
🐛 [Bug] RuntimeError When Converting Model with Int64 Buffers in TensorRT 10 #3354

Open · korkland opened this issue Jan 12, 2025 · 7 comments
Labels: bug (Something isn't working)

korkland commented Jan 12, 2025

Bug Description

I am experiencing an issue when converting a model to a TensorRT engine with the TorchScript (ts) frontend. The error occurs during conversion, specifically when I call torch_tensorrt.ts.convert_method_to_trt_engine. The error message is as follows:

RuntimeError: [Error thrown at core/compiler.cpp:319] Expected !dtype || dtype.value() != at::kLong to be true but got false
Cannot specify Int64 input for a model fully compiled in TRT

The script I'm using:

import torch_tensorrt
import torch

class DictToTupleWrapper(torch.nn.Module):
    def __init__(self, model, input_keys):
        super(DictToTupleWrapper, self).__init__()
        self.model = model
        self.input_keys = input_keys

    def forward(self, *inputs):
        # Convert inputs to dictionary using input_keys
        inputs_dict = {key: inputs[i] for i, key in enumerate(self.input_keys)}
        outputs_dict = self.model(inputs_dict)
        # Convert outputs dictionary to tuple
        return tuple(outputs_dict[key] for key in sorted(outputs_dict.keys()))

def create_inputs(inputs_dict):
    inputs_cuda = {k: v.detach().clone().cuda() for k, v in inputs_dict.items()}
    input_tuple = tuple(inputs_cuda[key] for key in inputs_cuda.keys())
    return input_tuple, list(inputs_cuda.keys())

def compile_torch_to_trt(traced_model, inputs_dict, output_path):
    cuda_model = traced_model.cuda().eval()  # Ensure the model is in eval mode and on GPU
    inputs, input_keys = create_inputs(inputs_dict)

    # Wrap the model to handle dict-style inputs
    wrapped_model = DictToTupleWrapper(cuda_model, input_keys)
    original_outputs = wrapped_model(*inputs)

    jit_model = torch.jit.trace(wrapped_model, inputs, strict=False)
    model_trt = torch_tensorrt.compile(
        jit_model,
        inputs=inputs,
        ir="ts",
        truncate_long_and_double=True,
    )

    # Convert to a serialized TensorRT engine (this is where the error occurs)
    torch_tensorrt.ts.convert_method_to_trt_engine(
        jit_model,
        inputs=inputs,
        truncate_long_and_double=True,
        debug=True)

    return model_trt

torch_tensorrt.compile passed with the following warnings, in case they help:

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
WARNING: [Torch-TensorRT] - For input getitem_7, found user specified input dtype as Float, however when inspecting the graph, the input type expected was inferred to be Long
The compiler is going to use the user setting Float
This conflict may cause an error at runtime due to partial compilation being enabled and therefore
compatibility with PyTorch's data type convention is required.
If you do indeed see errors at runtime either:
- Remove the dtype spec for getitem_7
- Disable partial compilation by setting require_full_compilation to True
WARNING: [Torch-TensorRT] - Dilation not used in Max pooling converter
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - If indices include negative values, the exported graph will produce incorrect results.
WARNING: [Torch-TensorRT] - Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.

When I search for int64 in the model, the only int64 tensors I can find are the num_batches_tracked buffers in BatchNorm layers, and even after explicitly converting them I still get the error.
The error occurs even if I explicitly specify the inputs:

trt_engine = torch_tensorrt.convert_method_to_trt_engine(
    jit_model,
    inputs=[
        torch_tensorrt.Input((1, 4, 32, 96), dtype=torch.float),
        torch_tensorrt.Input((1, 2, 32, 96), dtype=torch.float),
        torch_tensorrt.Input((1, 3, 768, 1356), dtype=torch.float),
    ],
    ir="ts",
    truncate_long_and_double=True,
)

Additional Information:

  • The model contains num_batches_tracked buffers in several BatchNorm layers, which are of type torch.int64.

  • I confirmed that I converted the relevant tensors (buffers and parameters) from torch.int64 to torch.int32 before tracing the model with torch.jit.trace (see the sketch after this list).

  • The issue occurs even though TensorRT 10 supports Int64 types and truncate_long_and_double=True is set.
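
For reference, a minimal sketch of the buffer conversion described above (the helper name downcast_int64_buffers is illustrative and not part of the original script):

import torch

def downcast_int64_buffers(model: torch.nn.Module) -> None:
    # Replace every int64 buffer (e.g. BatchNorm's num_batches_tracked)
    # with an int32 copy; run this before torch.jit.trace.
    for module in model.modules():
        for name, buf in list(module.named_buffers(recurse=False)):
            if buf is not None and buf.dtype == torch.int64:
                module.register_buffer(name, buf.to(torch.int32))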

To Reproduce

Unfortunately, I can't share the model due to company restrictions.

Expected behavior

The model should successfully convert to a TensorRT engine without errors related to Int64 types, and the conversion should work with Int32 or Float32 inputs.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.4.0
  • PyTorch Version (e.g. 1.0): 2.4.0+cu118
  • CPU Architecture: X86
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10
  • CUDA version: 11.8
  • GPU models and configuration: A40
  • Any other relevant information:
korkland added the bug label Jan 12, 2025
narendasan (Collaborator) commented

Since you are using torch.jit.trace, my recommendation would be to try the dynamo frontend, which likely solves this issue and supports native int64. If you still want to use TorchScript as the deployment format, you can trace the resulting compiled program with torch.jit.trace or torch_tensorrt.save. You also might not need the DictToTupleWrapper.
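
A minimal sketch of that flow (the toy model and shapes are illustrative, and it assumes a torch_tensorrt version where torch_tensorrt.save supports output_format="torchscript"):

import torch
import torch_tensorrt

class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = ToyModel().cuda().eval()
example_inputs = (torch.randn(1, 3, 224, 224, device="cuda"),)

# Compile with the dynamo frontend, which handles int64 natively.
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=example_inputs)

# Optionally re-export the compiled program as TorchScript for deployment.
torch_tensorrt.save(trt_mod, "trt_model.ts", output_format="torchscript", inputs=example_inputs)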

korkland (Author) commented Jan 15, 2025

Since you are using torch.jit.trace, my recommendation would be to try the dynamo frontend, which likely solves this issue and supports native int64. If you still want to use TorchScript as the deployment format, you can trace the resulting compiled program with torch.jit.trace or torch_tensorrt.save. You also might not need the DictToTupleWrapper.

I've tried using the Dynamo backend, and while it works, I can't find a way to pass a dictionary to the model. If I pass it as arg_inputs, the dictionary gets converted into a list of its keys. For example, {x: tensor(...), y: tensor(...)} becomes ['x', 'y'], which seems buggy to me, though I'm not entirely sure. This behavior can be traced to _compiler.py (version 2.5.0), line 622:

arg_input_list = list(prepare_inputs(arg_inputs))

Here, prepare_inputs returns the original dictionary, but list() converts it into a list of keys.
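
A minimal illustration of that Python behavior (plain Python, not Torch-TensorRT code):

inputs = {"x": 1, "y": 2}
print(list(inputs))  # ['x', 'y'] -- list() over a dict yields only its keys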

On the other hand, if I pass the dictionary as kwarg_inputs, it expects the forward signature to match the dictionary's fields, as demonstrated in test_export_kwargs_serde.py.

I would appreciate clarification on whether passing a dictionary to the model for conversion using trt.dynamo.convert_exported_program_to_serialized_trt_engine is supported.

Additionally, in case passing a dictionary isn't supported, how can I assign the keys as names to the input tensors? For example, assigning names like "image" or "grid" instead of defaulting to "input_1", "input_2", etc., would make deployment much easier by allowing inputs to be matched to their intended names.

I would appreciate any guidance on this.

narendasan (Collaborator) commented

If you are using trt.dynamo.convert_exported_program_to_serialized_trt_engine, then you must use a flat set of inputs, since there will be no Python / PyTorch intermediary to format the data for the TensorRT engine. The output of torchtrt.dynamo.compile will be a PyTorch module that uses the same forward as the source module and flattens the data for the engine for you; that is the case where arg_inputs and kwarg_inputs will be respected.
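
A minimal sketch of the torchtrt.dynamo.compile path described here (assuming the 2.5-style arg_inputs/kwarg_inputs keywords; the toy model is illustrative):

import torch
import torch_tensorrt as torchtrt

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

model = Add().cuda().eval()
kwargs = {"x": torch.randn(1, 3, device="cuda"), "y": torch.randn(1, 3, device="cuda")}

ep = torch.export.export(model, args=(), kwargs=kwargs, strict=False)
trt_mod = torchtrt.dynamo.compile(ep, arg_inputs=(), kwarg_inputs=kwargs)

# The compiled module keeps the source module's forward signature and
# flattens the keyword inputs for the underlying engine itself.
out = trt_mod(**kwargs)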

korkland (Author) commented

If you are using trt.dynamo.convert_exported_program_to_serialized_trt_engine, then you must use a flat set of inputs, since there will be no Python / PyTorch intermediary to format the data for the TensorRT engine. The output of torchtrt.dynamo.compile will be a PyTorch module that uses the same forward as the source module and flattens the data for the engine for you; that is the case where arg_inputs and kwarg_inputs will be respected.

OK, and in the case of trt.dynamo.convert_exported_program_to_serialized_trt_engine, how can I assign names to the tensors? When using ONNX export, the names are assigned.

narendasan (Collaborator) commented

Input names are set via the target names of the input nodes in the FX graph; these come from Dynamo.
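
A small sketch of how to inspect those placeholder names on an exported program (the toy model is illustrative):

import torch

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

ep = torch.export.export(Add(), args=(torch.randn(1, 3), torch.randn(1, 3)))
# The placeholder nodes carry the forward-parameter names that become
# the TensorRT input tensor names.
print([n.name for n in ep.graph.nodes if n.op == "placeholder"])  # typically ['x', 'y']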

korkland (Author) commented

Thanks for your responses! However, I'm still unclear about how to set names for inputs and outputs. In the following code example, the input names are correctly derived from the forward signature, but the output remains as "output0".

Could you explain how, in this flow, I can set the output name to be something like "z"? For example, while I attempt to assign "z" as the output name here:

py_trt_module = PythonTorchTensorRTModule(serialized_engine, list(dict_input.keys()), ["z"])

I get an error (Given invalid tensor name: z) because the output tensor name defaults to "output0". Here's the complete code for reference:

import torch
import torch_tensorrt as trt
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

# Step 1: Define your model
class MyModel(torch.nn.Module):
    def forward(self, x, y):
        z = x + y
        return z

model = MyModel().cuda().eval()

# Step 2: Prepare inputs with explicit names
input_1 = torch.randn(1, 3, 224, 224, device="cuda").detach().clone()
input_2 = torch.randn(1, 3, 224, 224, device="cuda").detach().clone()
dict_input = {"x": input_1, "y": input_2}

# Step 3: Export the model
exported_program = torch.export.export(model, args=(), kwargs=dict_input, strict=False)

# Step 4: Convert to TensorRT engine
serialized_engine = trt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported_program,
    arg_inputs=(),
    kwarg_inputs=dict_input,
    enabled_precisions={torch.float32, torch.float16},
)

# Error (Given invalid tensor name: z) - output name is output0
py_trt_module = PythonTorchTensorRTModule(serialized_engine, list(dict_input.keys()), ["z"])
trt_output = py_trt_module(input_1, input_2)

# Step 5: Save the engine
with open("named_inputs_model_trt.engine", "wb") as f:
    f.write(serialized_engine)

I’d greatly appreciate your help in understanding how to correctly name the output tensor. Thank you!
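
For reference, a small sketch (assuming the tensorrt Python package is importable) of how to list the I/O tensor names baked into the serialized engine, which shows the output defaulting to "output0":

import tensorrt

runtime = tensorrt.Runtime(tensorrt.Logger(tensorrt.Logger.WARNING))
engine = runtime.deserialize_cuda_engine(serialized_engine)
# Prints something like ['x', 'y', 'output0'].
print([engine.get_tensor_name(i) for i in range(engine.num_io_tensors)])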

narendasan (Collaborator) commented

We don't support setting your own custom names through any method other than the parameter names of the forward function; we just inherit those names directly from Dynamo.
