🐛 Describe the bug
Currently, the MULTIHEAD_ATTENTION_OUTPUT ignore patterns for the ONNX and Torch backends only work for "decomposed" versions of attention, because they match against MATMUL and SOFTMAX nodes in particular arrangements.
This means that Torch models using the fused torch.nn.functional.scaled_dot_product_attention operator and ONNX models using the Attention node (opset 23 and newer) don't get matched.
The edge between MATMUL and SOFTMAX doesn't need to be matched, since it is already "hidden inside" the SDPA / Attention node. However, the other input of the final MATMUL should correspond to the V (third) input of SDPA / Attention.
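For illustration only (this is not NNCF code), a minimal sketch of the two graph shapes involved: the decomposed variant produces the MATMUL → SOFTMAX → MATMUL chain that the current ignore pattern is written against, while the fused call collapses the whole computation into a single node, so the pattern has nothing to match:

```python
import torch
import torch.nn.functional as F

def decomposed_attention(q, k, v):
    # MATMUL -> SOFTMAX -> MATMUL chain that the existing
    # MULTIHEAD_ATTENTION_OUTPUT ignore pattern recognizes.
    scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)  # MATMUL
    weights = torch.softmax(scores, dim=-1)                    # SOFTMAX
    return weights @ v                                         # MATMUL, other input is V

def fused_attention(q, k, v):
    # A single fused node in the traced graph; the MATMUL/SOFTMAX
    # sub-pattern never appears, so the ignore pattern does not fire.
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 8, 16)
assert torch.allclose(decomposed_attention(q, k, v), fused_attention(q, k, v), atol=1e-5)
```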
I am willing to look into contributing a fix for this, but I am not 100% sure if I can fully figure this out on my own.
Environment
nncf==2.18.0
torch==2.8.0
about-time==4.2.1
alive-progress==3.3.0
anyio==4.11.0
attrs==25.4.0
autograd==1.8.0
certifi==2025.11.12
click==8.3.0
cma==4.4.0
coloredlogs==15.0.1
contourpy==1.3.3
cycler==0.12.1
Deprecated==1.3.1
dill==0.4.0
filelock==3.20.0
flatbuffers==25.9.23
fonttools==4.60.1
fsspec==2025.10.0
graphemeu==0.7.2
h11==0.16.0
hf-xet==1.2.0
httpcore==1.0.9
httpx==0.28.1
huggingface_hub==1.1.4
humanfriendly==10.0
idna==3.11
Jinja2==3.1.6
joblib==1.5.2
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.7
mdurl==0.1.2
ml_dtypes==0.5.3
mpmath==1.3.0
natsort==8.4.0
networkx==3.4.2
ninja==1.13.0
nncf==2.18.0
numpy==2.2.6
nvidia-cublas-cu12==12.8.4.1
nvidia-cuda-cupti-cu12==12.8.90
nvidia-cuda-nvrtc-cu12==12.8.93
nvidia-cuda-runtime-cu12==12.8.90
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.3.83
nvidia-cufile-cu12==1.13.1.3
nvidia-curand-cu12==10.3.9.90
nvidia-cusolver-cu12==11.7.3.90
nvidia-cusparse-cu12==12.5.8.93
nvidia-cusparselt-cu12==0.7.1
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.8.93
nvidia-nvshmem-cu12==3.3.20
nvidia-nvtx-cu12==12.8.90
onnx==1.19.1
onnx-ir==0.1.12
onnxruntime==1.23.2
onnxscript @ git+https://github.com/ruro/onnxscript.git@ae22c2ff1f9816b3559f65b7019cd9f9ad4203ce
openvino-telemetry==2025.2.0
packaging==25.0
pandas==2.3.3
pillow==12.0.0
protobuf==6.33.1
psutil==7.1.3
pydot==3.0.4
Pygments==2.19.2
pymoo==0.6.1.5
pyparsing==3.2.5
python-dateutil==2.9.0.post0
pytz==2025.2
PyYAML==6.0.3
referencing==0.37.0
rich==14.2.0
rpds-py==0.29.0
safetensors==0.7.0
scikit-learn==1.7.2
scipy==1.16.3
setuptools==80.9.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
sympy==1.14.0
tabulate==0.9.0
threadpoolctl==3.6.0
timm==1.0.22
torch==2.8.0
torchvision==0.23.0
tqdm==4.67.1
triton==3.4.0
typer-slim==0.20.0
typing_extensions==4.15.0
tzdata==2025.2
wrapt==2.0.1
Additionally:
OS: NixOS 25.11
Python: 3.13.5
Install: PyPI
RAM: 32.00 GB
CPU: 12th Gen Intel(R) Core(TM) i9-12900HK
CUDA: 12.8
Minimal Reproducible Example
```python
import torch
import nncf
import timm

sdpa = timm.layers.attention.Attention(1, 1)
input_sample = {
    "x": torch.zeros(1, 1, 1),
}
sdpa = nncf.quantize(
    sdpa,
    calibration_dataset=nncf.Dataset([input_sample]),
    model_type=nncf.ModelType.TRANSFORMER,
    preset=nncf.QuantizationPreset.PERFORMANCE,
    target_device=nncf.TargetDevice.CPU,
)
```
Are you going to submit a PR?
- Yes, I'd like to help by submitting a PR!
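Side note, not verified: a possible stop-gap on the Torch side is to force timm's decomposed attention path, so the traced graph again contains the MATMUL/SOFTMAX chain that the existing pattern recognizes. A sketch, assuming the module exposes a fused_attn flag as in recent timm releases:

```python
import torch
import nncf
import timm

sdpa = timm.layers.attention.Attention(1, 1)
# Assumption: recent timm Attention modules select the fused SDPA path via a
# `fused_attn` flag; forcing it off restores the decomposed MATMUL/SOFTMAX graph.
sdpa.fused_attn = False

sdpa = nncf.quantize(
    sdpa,
    calibration_dataset=nncf.Dataset([{"x": torch.zeros(1, 1, 1)}]),
    model_type=nncf.ModelType.TRANSFORMER,
    preset=nncf.QuantizationPreset.PERFORMANCE,
    target_device=nncf.TargetDevice.CPU,
)
```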