Bug Description
When compiling a simple nn.MultiheadAttention module with Torch-TensorRT using the dynamo IR, compilation fails with a ValueError from aten.view.default because the tensor reaching view() (the output of scaled_dot_product_attention) is not contiguous.
Adding .contiguous() after the scaled_dot_product_attention call fixes the problem.
To Reproduce
Code sample
import torch
import torch_tensorrt


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            embed_dim=768, num_heads=8, kdim=768, vdim=768,
            dropout=0.0, bias=False, batch_first=True
        )

    def forward(self, x):
        att_out, _ = self.self_attn(x, x, x, need_weights=False)
        return att_out


model = Model()
model.cuda()
model.eval()

inputs = [torch.rand([1, 1024, 768], device='cuda', dtype=torch.float32)]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "tmp.ep", inputs=inputs)
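For reference, a quick eager-mode sanity check (added here for illustration, reusing model and inputs from the sample above) runs without error; the failure only shows up inside torch_tensorrt.compile:

# Eager-mode sanity check: the forward pass itself succeeds,
# so the problem is specific to the dynamo export/decomposition path.
with torch.no_grad():
    eager_out = model(inputs[0])
print(eager_out.shape)  # torch.Size([1, 1024, 768])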
Error
TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[09/17/2025-20:06:17] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] failed while attempting to run meta for aten.view.default
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] Traceback (most recent call last):
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] r = func(*args, **kwargs)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] ^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] return self._op(*args, **kwargs)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] ^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 4671, in view
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] return _reshape_view_helper(a, *shape, allow_copy=False)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3800, in _reshape_view_helper
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] raise ValueError(msg)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] ValueError: Cannot view a tensor with shape torch.Size([1024, 1, 8, 96]) and strides (96, 786432, 98304, 1) as a tensor with shape (1024, 768)!
Traceback (most recent call last):
File "/home/lilinzhou/code/Head/folder/issue.py", line 23, in <module>
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch_tensorrt/_compile.py", line 289, in compile
trt_graph_module = dynamo_compile(
^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 682, in compile
exported_program = exported_program.run_decompositions(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 121, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
return _decompose_exported_program(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
) = _decompose_and_get_gm_with_new_signature_constants(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
aten_export_artifact = _export_to_aten_ir(
^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/_trace.py", line 816, in _export_to_aten_ir
gm, graph_signature = transform(aot_export_module)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1355, in aot_export_module
fx_g, metadata, in_spec, out_spec = _aot_export_function(
^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1594, in _aot_export_function
fx_g, meta = create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
return _create_aot_dispatcher_function(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
fw_metadata = run_functionalized_fw_and_collect_metadata(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
flat_f_outs = f(*flat_f_args)
^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 899, in functional_call
out = PropagateUnbackedSymInts(mod).run(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 171, in run
self.env[node] = self.run_node(node)
^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7183, in run_node
result = super().run_node(n)
^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 240, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 320, in call_function
return target(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 525, in __torch_dispatch__
outs_unwrapped = func._op_dk(
^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1384, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
r = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 4671, in view
return _reshape_view_helper(a, *shape, allow_copy=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3800, in _reshape_view_helper
raise ValueError(msg)
ValueError: Cannot view a tensor with shape torch.Size([1024, 1, 8, 96]) and strides (96, 786432, 98304, 1) as a tensor with shape (1024, 768)!
While executing %view_6 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%permute, [1024, 768]), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[1, 1024, 768][786432, 768, 1]";
        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # No stacktrace found for following nodes
        self_attn_in_proj_weight: "f32[2304, 768][768, 1]" = self.self_attn.in_proj_weight
        self_attn_out_proj_weight: "f32[768, 768][768, 1]" = self.self_attn.out_proj.weight
        # File: /home/lilinzhou/code/Head/folder/issue.py:13 in forward, code: att_out, _ = self.self_attn(x, x, x, need_weights=False)
        transpose: "f32[1024, 1, 768][768, 786432, 1]" = torch.ops.aten.transpose.int(x, 1, 0); x = None
        linear: "f32[1024, 1, 2304][2304, 2304, 1]" = torch.ops.aten.linear.default(transpose, self_attn_in_proj_weight); transpose = self_attn_in_proj_weight = None
        unflatten: "f32[1024, 1, 3, 768][2304, 2304, 768, 1]" = torch.ops.aten.unflatten.int(linear, -1, [3, 768]); linear = None
        unsqueeze: "f32[1, 1024, 1, 3, 768][2359296, 2304, 2304, 768, 1]" = torch.ops.aten.unsqueeze.default(unflatten, 0); unflatten = None
        transpose_1: "f32[3, 1024, 1, 1, 768][768, 2304, 2304, 2359296, 1]" = torch.ops.aten.transpose.int(unsqueeze, 0, -2); unsqueeze = None
        squeeze: "f32[3, 1024, 1, 768][768, 2304, 2304, 1]" = torch.ops.aten.squeeze.dim(transpose_1, -2); transpose_1 = None
        contiguous: "f32[3, 1024, 1, 768][786432, 768, 768, 1]" = torch.ops.aten.contiguous.default(squeeze); squeeze = None
        select: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 0)
        select_1: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 1)
        select_2: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 2); contiguous = None
        view: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select, [1024, 8, 96]); select = None
        transpose_2: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view, 0, 1); view = None
        view_1: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select_1, [1024, 8, 96]); select_1 = None
        transpose_3: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view_1, 0, 1); view_1 = None
        view_2: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select_2, [1024, 8, 96]); select_2 = None
        transpose_4: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view_2, 0, 1); view_2 = None
        view_3: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_2, [1, 8, 1024, 96]); transpose_2 = None
        view_4: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_3, [1, 8, 1024, 96]); transpose_3 = None
        view_5: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_4, [1, 8, 1024, 96]); transpose_4 = None
        scaled_dot_product_attention: "f32[1, 8, 1024, 96][786432, 96, 768, 1]" = torch.ops.aten.scaled_dot_product_attention.default(view_3, view_4, view_5); view_3 = view_4 = view_5 = None
        permute: "f32[1024, 1, 8, 96][768, 786432, 96, 1]" = torch.ops.aten.permute.default(scaled_dot_product_attention, [2, 0, 1, 3]); scaled_dot_product_attention = None
        view_6: "f32[1024, 768][768, 1]" = torch.ops.aten.view.default(permute, [1024, 768]); permute = None
        linear_1: "f32[1024, 768][768, 1]" = torch.ops.aten.linear.default(view_6, self_attn_out_proj_weight); view_6 = self_attn_out_proj_weight = None
        view_7: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.view.default(linear_1, [1024, 1, 768]); linear_1 = None
        transpose_5: "f32[1, 1024, 768][768, 768, 1]" = torch.ops.aten.transpose.int(view_7, 1, 0); view_7 = None
        return pytree.tree_unflatten((transpose_5,), self._out_spec)
Original traceback:
File "/home/lilinzhou/code/Head/folder/issue.py", line 13, in forward
att_out, _ = self.self_attn(x, x, x, need_weights=False)
Workaround
If I manually patch the code by adding .contiguous() right after scaled_dot_product_attention(), the problem goes away:
https://github.com/pytorch/pytorch/blob/89a6dbe73af4ca64ee26f4e46219e163b827e698/torch/nn/functional.py#L6487-L6489

attn_output = scaled_dot_product_attention(
    q, k, v, attn_mask, dropout_p, is_causal
).contiguous()
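A user-side variant of the same workaround (a sketch only, not verified end to end): wrap torch.nn.functional.scaled_dot_product_attention before compiling, so the traced graph records a contiguous output without editing the installed PyTorch sources. multi_head_attention_forward looks the function up at module level, so the wrapper should be picked up while export traces through it (that tracing behavior is an assumption here).

import torch.nn.functional as F

# Sketch: force a contiguous SDPA output during tracing so the later
# aten.view decomposition succeeds; mirrors the .contiguous() patch above.
_orig_sdpa = F.scaled_dot_product_attention

def _contiguous_sdpa(*args, **kwargs):
    return _orig_sdpa(*args, **kwargs).contiguous()

F.scaled_dot_product_attention = _contiguous_sdpa

# ...then call torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) as before.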
Expected behavior
MultiheadAttention should work out of the box when compiled with Torch-TensorRT, without requiring manual .contiguous() hacks.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.7.0+cu118
- PyTorch Version (e.g. 1.0): 2.7.1+cu118
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 20.04
- How you installed PyTorch (conda, pip, libtorch, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives: building from archives
- Python version: 3.12
- CUDA version: 11.8
- GPU models and configuration: RTX3090 (Driver Version 550.135)
- Any other relevant information:
Additional context
It seems the tensor returned by scaled_dot_product_attention can be non-contiguous, but the later .view() assumes a contiguous memory layout. Adding .contiguous() fixes the issue.
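For illustration, the shape/stride combination from the error message can be reproduced stand-alone (independent of Torch-TensorRT): the permuted tensor cannot be view()-ed, while a contiguous copy can.

import torch

# Minimal illustration of the failure mode (not the Torch-TensorRT code path):
# a permuted attention-style output is non-contiguous, so view() cannot flatten it.
t = torch.randn(1, 8, 1024, 96).permute(2, 0, 1, 3)
print(t.shape, t.stride())   # torch.Size([1024, 1, 8, 96]) (96, 786432, 98304, 1)
print(t.is_contiguous())     # False

try:
    t.view(1024, 768)        # the flattened dims are not contiguous in memory
except RuntimeError as e:
    print("view failed:", e)

print(t.contiguous().view(1024, 768).shape)   # torch.Size([1024, 768])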