Problem Description
Aiter FAv3 forward throws RuntimeError: invalid argument for fmha_fwd when the model is torch.compiled and the attention inputs are transposed from BHSD to BSHD layout before the attention call. The error occurs only under torch.compile and only when .transpose(1, 2).contiguous() is applied to the q, k, v inputs. It also reproduces on gfx950 but not on gfx942 (although on gfx942 the difference between eager and compiled FAv3 fwd outputs is suspiciously high, >5e-2).
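A minimal sketch of the failing pattern (my condensed paraphrase of the full reproducer under Steps to Reproduce, assuming the same aiter.flash_attn_func signature used there):

import torch
import aiter

def attn(q, k, v):
    # Inputs arrive in (B, H, S, D); convert to (B, S, H, D) for aiter
    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()
    v = v.transpose(1, 2).contiguous()
    # return_lse=True so that gfx950 takes the FAv3 forward path
    out = aiter.flash_attn_func(q, k, v, causal=False, return_lse=True)[0]
    return out.transpose(1, 2).contiguous()

q = k = v = torch.randn(1, 16, 32768, 128, device='cuda', dtype=torch.bfloat16)
attn(q, k, v)                   # eager: works
torch.compile(attn)(q, k, v)    # compiled: RuntimeError on gfx950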
Operating System
Ubuntu 24.04.2 LTS (Noble Numbat)
CPU
AMD EPYC 9575F 64-Core Processor
GPU
AMD Instinct MI355X
ROCm Version
ROCm 7.0.0rc20250820
ROCm Component
No response
Steps to Reproduce
Running pytest test_aiter_fa_compile.py runs 6 tests:

test_aiter_attention[Eager-BSHD order]: Passes
test_aiter_attention[Eager-BHSD order]: Passes
test_aiter_attention[Compiled-BSHD order]: Passes
test_aiter_attention[Compiled-BHSD order]: Fails on gfx950 with invalid argument for fmha_fwd
test_compare_eager_vs_compiled[BSHD order]: Passes
test_compare_eager_vs_compiled[BHSD order]: Fails on gfx950 with invalid argument for fmha_fwd
test_aiter_fa_compile.py:
import pytest
import torch
import numpy as np
import aiter


class SimpleAiterAttentionModel(torch.nn.Module):
    """Simple model with just flash attention."""

    def __init__(self, bhsd_order=False):
        super().__init__()
        self.flash_attn_func = aiter.flash_attn_func
        self.bhsd_order = bhsd_order  # If True, expects input in (B, H, S, D) order

    def forward(self, q, k, v):
        if self.bhsd_order:
            # Convert from (B, H, S, D) to (B, S, H, D) for flash attention
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()
            v = v.transpose(1, 2).contiguous()
        output = self.flash_attn_func(
            q, k, v,
            causal=False,
            deterministic=False,
            return_lse=True  # gfx950 requires this to use FAv3 fwd
        )[0]
        if self.bhsd_order:
            # Convert back to (B, H, S, D)
            output = output.transpose(1, 2).contiguous()
        return output


def get_test_data(batch_size=1, seq_len=32768, heads=16, dim=128, bhsd_order=False):
    """Generate test data for attention models."""
    torch.manual_seed(42)
    shape = (batch_size, seq_len, heads, dim)
    if bhsd_order:
        shape = (batch_size, heads, seq_len, dim)

    def _generate_data():
        return torch.randn(shape,
                           device='cuda',
                           dtype=torch.bfloat16,
                           requires_grad=False)

    q = _generate_data()
    k = _generate_data()
    v = _generate_data()
    return q, k, v


@pytest.mark.parametrize("bhsd_order", [False, True], ids=["BSHD order", "BHSD order"])
@pytest.mark.parametrize("use_compile", [False, True], ids=["Eager", "Compiled"])
def test_aiter_attention(bhsd_order, use_compile):
    """Test aiter flash attention with and without torch.compile."""
    q, k, v = get_test_data(bhsd_order=bhsd_order)
    # Create model and run forward pass
    model = SimpleAiterAttentionModel(bhsd_order=bhsd_order).eval()
    if use_compile:
        model = torch.compile(model, mode='default')
    with torch.inference_mode(True):
        output = model(q, k, v)
    # Check for NaNs, Infs, or all zeros
    assert not torch.isnan(output).any(), "Output contains NaNs"
    assert not torch.isinf(output).any(), "Output contains Infs"
    assert not torch.all(output == 0), "Output is all zeros"


@pytest.mark.parametrize("bhsd_order", [False, True], ids=["BSHD order", "BHSD order"])
def test_compare_eager_vs_compiled(bhsd_order):
    """Compare outputs from eager and compiled execution."""
    q, k, v = get_test_data(bhsd_order=bhsd_order)
    # Get outputs from eager and compiled versions
    model = SimpleAiterAttentionModel(bhsd_order=bhsd_order).eval()
    with torch.inference_mode(True):
        eager_out = model(q.clone(), k.clone(), v.clone())
    model_compiled = torch.compile(model)
    with torch.inference_mode(True):
        compiled_out = model_compiled(q.clone(), k.clone(), v.clone())
    np.testing.assert_allclose(
        eager_out.to(torch.float32).cpu().numpy(),
        compiled_out.to(torch.float32).cpu().numpy(),
        rtol=1e-1,  # generous tolerances due to bfloat16
        atol=5e-2,
        err_msg="Outputs from eager and compiled model do not match."
    )
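For triage, the sketch below is a workaround experiment I have not verified on gfx950 (an assumption, not a confirmed fix): it keeps the transposes inside the compiled graph but excludes the aiter kernel itself via torch.compiler.disable. If this variant passes under torch.compile, the failure is likely in how Inductor materializes the transposed/contiguous q, k, v that reach fmha_fwd.

import torch

class EagerAttnModel(SimpleAiterAttentionModel):
    # Hypothetical variant of the reproducer's model: run the aiter call
    # in eager mode even when the surrounding module is compiled.
    @torch.compiler.disable
    def _attn_eager(self, q, k, v):
        return self.flash_attn_func(q, k, v, causal=False,
                                    deterministic=False, return_lse=True)[0]

    def forward(self, q, k, v):
        if self.bhsd_order:
            q = q.transpose(1, 2).contiguous()
            k = k.transpose(1, 2).contiguous()
            v = v.transpose(1, 2).contiguous()
        output = self._attn_eager(q, k, v)
        if self.bhsd_order:
            output = output.transpose(1, 2).contiguous()
        return output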
(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support
No response
Additional Information
Torch version: 2.7.1+rocmsdk20250821
Aiter version: 0.1.5.dev157+g01330f6c6