MULTIHEAD_ATTENTION_OUTPUT ignored patterns don't match "proper" SDPA / Attention #3750

@ruro

🐛 Describe the bug

Currently, the MULTIHEAD_ATTENTION_OUTPUT ignored patterns for the ONNX and torch backends only work for "decomposed" versions of attention: they match MATMUL and SOFTMAX nodes in particular arrangements.

This means that torch models using the fused torch.nn.functional.scaled_dot_product_attention operator and ONNX models using the Attention node (opset 23+) don't get matched.

For the fused operators, the edge between SOFTMAX and MATMUL doesn't need to be matched, since it is already "hidden inside" the SDPA / Attention node. However, the other MATMUL input (the one that does not come from SOFTMAX) should correspond to the V (third) input of SDPA / Attention.
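
To make the correspondence concrete, here is a toy sketch in plain Python. It does not use NNCF's actual pattern-matching API: the `Edge` class, the `value_input_edges` helper, and the node type names (`"matmul"`, `"softmax"`, `"sdpa"`) are all hypothetical and exist only for illustration. It finds the edge that carries V in both the decomposed and the fused form of attention.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Edge:
    src: str       # producer node name
    dst: str       # consumer node name
    dst_port: int  # input index on the consumer


def value_input_edges(node_types, edges):
    """Return the edges that carry V into the attention output (toy matcher)."""
    found = []
    for node, kind in node_types.items():
        inputs = [e for e in edges if e.dst == node]
        if kind == "matmul":
            # Decomposed form: one input comes from softmax, the other one is V.
            if any(node_types[e.src] == "softmax" for e in inputs):
                found += [e for e in inputs if node_types[e.src] != "softmax"]
        elif kind == "sdpa":
            # Fused form: softmax is internal, V is the third input (port 2).
            found += [e for e in inputs if e.dst_port == 2]
    return found


# Decomposed attention: out = matmul(softmax(matmul(q, k)), v)
DECOMPOSED_TYPES = {
    "q": "input", "k": "input", "v": "input",
    "qk": "matmul", "sm": "softmax", "out": "matmul",
}
DECOMPOSED_EDGES = [
    Edge("q", "qk", 0), Edge("k", "qk", 1),
    Edge("qk", "sm", 0),
    Edge("sm", "out", 0), Edge("v", "out", 1),
]

# Fused attention: attn = sdpa(q, k, v)
FUSED_TYPES = {"q": "input", "k": "input", "v": "input", "attn": "sdpa"}
FUSED_EDGES = [Edge("q", "attn", 0), Edge("k", "attn", 1), Edge("v", "attn", 2)]

print(value_input_edges(DECOMPOSED_TYPES, DECOMPOSED_EDGES))
print(value_input_edges(FUSED_TYPES, FUSED_EDGES))
```

In both graphs the only edge returned is the one originating at `v`, which is the point of the paragraph above: a fused pattern would only need to look at the third input of the SDPA / Attention node.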

I am willing to look into contributing a fix for this, but I am not 100% sure if I can fully figure this out on my own.

Environment

nncf==2.18.0
torch==2.8.0
about-time==4.2.1
alive-progress==3.3.0
anyio==4.11.0
attrs==25.4.0
autograd==1.8.0
certifi==2025.11.12
click==8.3.0
cma==4.4.0
coloredlogs==15.0.1
contourpy==1.3.3
cycler==0.12.1
Deprecated==1.3.1
dill==0.4.0
filelock==3.20.0
flatbuffers==25.9.23
fonttools==4.60.1
fsspec==2025.10.0
graphemeu==0.7.2
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.1.4
humanfriendly==10.0
idna==3.11
Jinja2==3.1.6
joblib==1.5.2
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.7
mdurl==0.1.2
ml_dtypes==0.5.3
mpmath==1.3.0
natsort==8.4.0
networkx==3.4.2
ninja==1.13.0
nncf==2.18.0
numpy==2.2.6
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.3.20
nvidia-nvtx-cu12==12.8.90
onnx==1.19.1
onnx-ir==0.1.12
onnxruntime==1.23.2
onnxscript @ git+https://github.com/ruro/onnxscript.git@ae22c2ff1f9816b3559f65b7019cd9f9ad4203ce
openvino-telemetry==2025.2.0
packaging==25.0
pandas==2.3.3
pillow==12.0.0
protobuf==6.33.1
psutil==7.1.3
pydot==3.0.4
Pygments==2.19.2
pymoo==0.6.1.5
pyparsing==3.2.5
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.3
referencing==0.37.0
rich==14.2.0
rpds-py==0.29.0
safetensors==0.7.0
scikit-learn==1.7.2
scipy==1.16.3
setuptools==80.9.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
sympy==1.14.0
tabulate==0.9.0
threadpoolctl==3.6.0
timm==1.0.22
torch==2.8.0
torchvision==0.23.0
tqdm==4.67.1
triton==3.4.0
typer-slim==0.20.0
typing_extensions==4.15.0
tzdata==2025.2
wrapt==2.0.1
Additionally:
OS                  NixOS 25.11
Python              3.13.5
Install             PyPI
RAM                 32.00 GB
CPU                 12th Gen Intel(R) Core(TM) i9-12900HK
CUDA                12.8

Minimal Reproducible Example

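import torch
import nncf
import timm

# timm's Attention block with dim=1 and num_heads=1; its forward pass calls
# F.scaled_dot_product_attention when timm's fused attention is enabled.
sdpa = timm.layers.attention.Attention(1, 1)
# A single calibration sample, keyed by the forward() argument name.
input_sample = {
    "x": torch.zeros(1, 1, 1),
}

# model_type=TRANSFORMER is what enables the transformer-specific ignored patterns.
sdpa = nncf.quantize(
    sdpa,
    calibration_dataset=nncf.Dataset([input_sample]),
    model_type=nncf.ModelType.TRANSFORMER,
    preset=nncf.QuantizationPreset.PERFORMANCE,
    target_device=nncf.TargetDevice.CPU,
)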

Are you going to submit a PR?

  • Yes I'd like to help by submitting a PR!
