🐛 [Bug] nn.MultiheadAttention fails with Torch-TensorRT due to non-contiguous tensor before view() #3823

@LinzhouLi

Bug Description

When compiling a simple nn.MultiheadAttention module with Torch-TensorRT using the dynamo IR, I get a runtime error related to view() because the tensor returned by scaled_dot_product_attention is not contiguous.

Adding .contiguous() fixes the problem.

To Reproduce

Code sample

import torch
import torch_tensorrt

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = torch.nn.MultiheadAttention(
            embed_dim=768, num_heads=8, kdim=768, vdim=768,
            dropout=0.0, bias=False, batch_first=True
        )
    
    def forward(self, x):
        att_out, _ = self.self_attn(x, x, x, need_weights=False)
        return att_out

model = Model()
model.cuda()
model.eval()

inputs = [torch.rand([1, 1024, 768], device='cuda', dtype=torch.float32)]
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
torch_tensorrt.save(trt_gm, "tmp.ep", inputs=inputs)

Error

TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[09/17/2025-20:06:17] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] failed while attempting to run meta for aten.view.default
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] Traceback (most recent call last):
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     r = func(*args, **kwargs)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]         ^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     return self._op(*args, **kwargs)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]            ^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 4671, in view
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     return _reshape_view_helper(a, *shape, allow_copy=False)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]   File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3800, in _reshape_view_helper
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431]     raise ValueError(msg)
E0917 20:06:18.037000 4088413 site-packages/torch/_subclasses/fake_tensor.py:2431] ValueError: Cannot view a tensor with shape torch.Size([1024, 1, 8, 96]) and strides (96, 786432, 98304, 1) as a tensor with shape (1024, 768)!
Traceback (most recent call last):
  File "/home/lilinzhou/code/Head/folder/issue.py", line 23, in <module>
    trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch_tensorrt/_compile.py", line 289, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_compiler.py", line 682, in compile
    exported_program = exported_program.run_decompositions(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
    return _decompose_exported_program(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
                           ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/export/_trace.py", line 816, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1355, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
                                        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1594, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
                  ^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 899, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7183, in run_node
    result = super().run_node(n)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 240, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/fx/interpreter.py", line 320, in call_function
    return target(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/functional_tensor.py", line 525, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
                     ^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1384, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 4671, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/lilinzhou/.conda/envs/myenv/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3800, in _reshape_view_helper
    raise ValueError(msg)
ValueError: Cannot view a tensor with shape torch.Size([1024, 1, 8, 96]) and strides (96, 786432, 98304, 1) as a tensor with shape (1024, 768)!

While executing %view_6 : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%permute, [1024, 768]), kwargs = {})
GraphModule: class GraphModule(torch.nn.Module):
    def forward(self, x):
        x: "f32[1, 1024, 768][786432, 768, 1]"; 
    
        x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
        # No stacktrace found for following nodes
        self_attn_in_proj_weight: "f32[2304, 768][768, 1]" = self.self_attn.in_proj_weight
        self_attn_out_proj_weight: "f32[768, 768][768, 1]" = self.self_attn.out_proj.weight
        
         # File: /home/lilinzhou/code/Head/folder/issue.py:13 in forward, code: att_out, _ = self.self_attn(x, x, x, need_weights=False)
        transpose: "f32[1024, 1, 768][768, 786432, 1]" = torch.ops.aten.transpose.int(x, 1, 0);  x = None
        linear: "f32[1024, 1, 2304][2304, 2304, 1]" = torch.ops.aten.linear.default(transpose, self_attn_in_proj_weight);  transpose = self_attn_in_proj_weight = None
        unflatten: "f32[1024, 1, 3, 768][2304, 2304, 768, 1]" = torch.ops.aten.unflatten.int(linear, -1, [3, 768]);  linear = None
        unsqueeze: "f32[1, 1024, 1, 3, 768][2359296, 2304, 2304, 768, 1]" = torch.ops.aten.unsqueeze.default(unflatten, 0);  unflatten = None
        transpose_1: "f32[3, 1024, 1, 1, 768][768, 2304, 2304, 2359296, 1]" = torch.ops.aten.transpose.int(unsqueeze, 0, -2);  unsqueeze = None
        squeeze: "f32[3, 1024, 1, 768][768, 2304, 2304, 1]" = torch.ops.aten.squeeze.dim(transpose_1, -2);  transpose_1 = None
        contiguous: "f32[3, 1024, 1, 768][786432, 768, 768, 1]" = torch.ops.aten.contiguous.default(squeeze);  squeeze = None
        select: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 0)
        select_1: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 1)
        select_2: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.select.int(contiguous, 0, 2);  contiguous = None
        view: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select, [1024, 8, 96]);  select = None
        transpose_2: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view, 0, 1);  view = None
        view_1: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select_1, [1024, 8, 96]);  select_1 = None
        transpose_3: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view_1, 0, 1);  view_1 = None
        view_2: "f32[1024, 8, 96][768, 96, 1]" = torch.ops.aten.view.default(select_2, [1024, 8, 96]);  select_2 = None
        transpose_4: "f32[8, 1024, 96][96, 768, 1]" = torch.ops.aten.transpose.int(view_2, 0, 1);  view_2 = None
        view_3: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_2, [1, 8, 1024, 96]);  transpose_2 = None
        view_4: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_3, [1, 8, 1024, 96]);  transpose_3 = None
        view_5: "f32[1, 8, 1024, 96][768, 96, 768, 1]" = torch.ops.aten.view.default(transpose_4, [1, 8, 1024, 96]);  transpose_4 = None
        scaled_dot_product_attention: "f32[1, 8, 1024, 96][786432, 96, 768, 1]" = torch.ops.aten.scaled_dot_product_attention.default(view_3, view_4, view_5);  view_3 = view_4 = view_5 = None
        permute: "f32[1024, 1, 8, 96][768, 786432, 96, 1]" = torch.ops.aten.permute.default(scaled_dot_product_attention, [2, 0, 1, 3]);  scaled_dot_product_attention = None
        view_6: "f32[1024, 768][768, 1]" = torch.ops.aten.view.default(permute, [1024, 768]);  permute = None
        linear_1: "f32[1024, 768][768, 1]" = torch.ops.aten.linear.default(view_6, self_attn_out_proj_weight);  view_6 = self_attn_out_proj_weight = None
        view_7: "f32[1024, 1, 768][768, 768, 1]" = torch.ops.aten.view.default(linear_1, [1024, 1, 768]);  linear_1 = None
        transpose_5: "f32[1, 1024, 768][768, 768, 1]" = torch.ops.aten.transpose.int(view_7, 1, 0);  view_7 = None
        return pytree.tree_unflatten((transpose_5,), self._out_spec)
        

Original traceback:
  File "/home/lilinzhou/code/Head/folder/issue.py", line 13, in forward
    att_out, _ = self.self_attn(x, x, x, need_weights=False)

Workaround

If I manually patch multi_head_attention_forward in torch/nn/functional.py by adding .contiguous() right after the scaled_dot_product_attention() call, the problem goes away:
https://github.com/pytorch/pytorch/blob/89a6dbe73af4ca64ee26f4e46219e163b827e698/torch/nn/functional.py#L6487-L6489

attn_output = scaled_dot_product_attention(
    q, k, v, attn_mask, dropout_p, is_causal
).contiguous()
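
For reference, a possible workaround on the user side that avoids patching torch/nn/functional.py is to replace nn.MultiheadAttention with a small hand-rolled self-attention block that calls scaled_dot_product_attention directly and merges the heads with reshape() instead of view(). This is only a sketch; the SelfAttention class below is hypothetical and not part of PyTorch or Torch-TensorRT:

import torch
import torch.nn.functional as F

class SelfAttention(torch.nn.Module):
    # Hypothetical drop-in replacement for the nn.MultiheadAttention usage above
    # (batch_first=True, no bias, no dropout), written so that merging the heads
    # uses reshape(), which copies when a view is not possible.
    def __init__(self, embed_dim=768, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = torch.nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):  # x: [batch, seq, embed_dim]
        b, s, e = x.shape
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        # split heads: [batch, num_heads, seq, head_dim]
        q = q.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)
        # merge heads: reshape() tolerates a non-contiguous SDPA output
        out = out.transpose(1, 2).reshape(b, s, e)
        return self.out_proj(out)

This sidesteps the stock nn.MultiheadAttention path entirely, so it is only meant as a temporary workaround, not a fix.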

Expected behavior

MultiheadAttention should work out of the box when compiled with Torch-TensorRT without requiring manual .contiguous() hacks.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.7.0+cu118
  • PyTorch Version (e.g. 1.0): 2.7.1+cu118
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: building from archives
  • Python version: 3.12
  • CUDA version: 11.8
  • GPU models and configuration: RTX3090 (Driver Version 550.135)
  • Any other relevant information:

Additional context

It seems the tensor returned by scaled_dot_product_attention can be non-contiguous (at least with the layout the fake-tensor meta reports), while the subsequent .view() requires a stride-compatible memory layout and never copies. Adding .contiguous() before the view fixes the issue.
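
A minimal standalone sketch of the stride problem, with no TensorRT involved, using the layout implied by the strides in the error message above:

import torch

# Rebuild the layout from the error message: a contiguous [1, 8, 1024, 96]
# attention output, permuted to [1024, 1, 8, 96] with strides (96, 786432, 98304, 1).
attn = torch.empty(1, 8, 1024, 96)
permuted = attn.permute(2, 0, 1, 3)
print(permuted.shape, permuted.stride())  # torch.Size([1024, 1, 8, 96]) (96, 786432, 98304, 1)

try:
    # view() cannot merge the head and head-dim axes here (strides 98304 and 1
    # are not adjacent in memory), so PyTorch raises a RuntimeError.
    permuted.view(1024, 768)
except RuntimeError as e:
    print(e)

out1 = permuted.reshape(1024, 768)            # works: reshape() copies when a view is impossible
out2 = permuted.contiguous().view(1024, 768)  # works: same effect as the .contiguous() workaround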
