
[Issue]: invalid argument for fmha_fwd with torch.compile on gfx950 #912

@jjuvonen-amd

Description


Problem Description

Aiter FAv3 forward throws RuntimeError: invalid argument for fmha_fwd when the model is compiled with torch.compile and the attention inputs are transposed from BHSD to BSHD format before the attention function.

The error occurs only with torch.compile and only when .transpose(1,2).contiguous() is applied to the q, k, v inputs. It also happens only on gfx950, not on gfx942 (although the difference between eager and compiled FAv3 forward outputs is suspiciously large, >5e-2, on gfx942).
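A condensed sketch of the failing pattern, using the same aiter.flash_attn_func arguments and tensor shapes as the full reproducer below (the attn wrapper name is just for illustration):

import torch
import aiter

def attn(q, k, v):
    # Inputs arrive as (B, H, S, D); convert to (B, S, H, D) before the call.
    q, k, v = (t.transpose(1, 2).contiguous() for t in (q, k, v))
    return aiter.flash_attn_func(q, k, v, causal=False, deterministic=False,
                                 return_lse=True)[0]

shape = (1, 16, 32768, 128)  # (B, H, S, D)
q, k, v = (torch.randn(shape, device='cuda', dtype=torch.bfloat16) for _ in range(3))

attn(q, k, v)                  # eager: runs fine
torch.compile(attn)(q, k, v)   # compiled: RuntimeError: invalid argument for fmha_fwd on gfx950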

Operating System

Ubuntu 24.04.2 LTS (Noble Numbat)

CPU

AMD EPYC 9575F 64-Core Processor

GPU

AMD Instinct MI355X

ROCm Version

ROCm 7.0.0rc20250820

ROCm Component

No response

Steps to Reproduce

Running pytest test_aiter_fa_compile.py executes the following 6 tests:

  • test_aiter_attention[Eager-BSHD order] Passes
  • test_aiter_attention[Eager-BHSD order] Passes
  • test_aiter_attention[Compiled-BSHD order] Passes
  • test_aiter_attention[Compiled-BHSD order] Fails on gfx950 with invalid argument for fmha_fwd
  • test_compare_eager_vs_compiled[BSHD order] Passes
  • test_compare_eager_vs_compiled[BHSD order] Fails on gfx950 with invalid argument for fmha_fwd

test_aiter_fa_compile.py:

import pytest
import torch
import numpy as np
import aiter

class SimpleAiterAttentionModel(torch.nn.Module):
    """Simple model with just flash attention."""
    
    def __init__(self, bhsd_order=False):
        super().__init__()
        self.flash_attn_func = aiter.flash_attn_func
        self.bhsd_order = bhsd_order # If True, expects input in (B, H, S, D) order
            
    def forward(self, q, k, v):
        if self.bhsd_order:
            # Convert from (B, H, S, D) to (B, S, H, D) for flash attention
            q = q.transpose(1,2).contiguous()
            k = k.transpose(1,2).contiguous()
            v = v.transpose(1,2).contiguous()

        output = self.flash_attn_func(
            q, k, v, 
            causal=False, 
            deterministic=False,
            return_lse=True # gfx950 requires this to use FAv3 fwd
        )[0]
        
        if self.bhsd_order:
            # Convert back to (B, H, S, D)
            output = output.transpose(1,2).contiguous()
        
        return output

def get_test_data(batch_size=1, seq_len=32768, heads=16, dim=128, bhsd_order=False):
    """Generate test data for attention models."""
    torch.manual_seed(42)

    shape = (batch_size, seq_len, heads, dim)
    if bhsd_order:
        shape = (batch_size, heads, seq_len, dim)
    
    def _generate_data():
        return torch.randn(shape, 
                           device='cuda', 
                           dtype=torch.bfloat16, 
                           requires_grad=False)
    q = _generate_data()
    k = _generate_data()
    v = _generate_data()
    return q, k, v

@pytest.mark.parametrize("bhsd_order", [False, True], ids=["BSHD order", "BHSD order"])
@pytest.mark.parametrize("use_compile", [False, True], ids=["Eager", "Compiled"])
def test_aiter_attention(bhsd_order, use_compile):
    """Test aiter flash attention with and without torch.compile."""
    q, k, v = get_test_data(bhsd_order=bhsd_order)
    
    # Create model and run forward pass
    model = SimpleAiterAttentionModel(bhsd_order=bhsd_order).eval()
    
    if use_compile:
        model = torch.compile(model, mode='default')

    with torch.inference_mode(True):
        output = model(q, k, v)
    
    # check for nans or infs or all zeros
    assert not torch.isnan(output).any(), "Output contains NaNs"
    assert not torch.isinf(output).any(), "Output contains Infs"
    assert not torch.all(output == 0), "Output is all zeros"

@pytest.mark.parametrize("bhsd_order", [False, True], ids=["BSHD order", "BHSD order"])
def test_compare_eager_vs_compiled(bhsd_order):
    """Compare outputs from eager and compiled execution."""
    q, k, v = get_test_data(bhsd_order=bhsd_order)
    
    # Get outputs from eager and compiled versions
    model = SimpleAiterAttentionModel(bhsd_order=bhsd_order).eval()
    with torch.inference_mode(True):
        eager_out = model(q.clone(), k.clone(), v.clone())

    model_compiled = torch.compile(model)
    with torch.inference_mode(True):
        compiled_out = model_compiled(q.clone(), k.clone(), v.clone())

    np.testing.assert_allclose(
        eager_out.to(torch.float32).cpu().numpy(),
        compiled_out.to(torch.float32).cpu().numpy(),
        rtol=1e-1, # generous tolerances due to bfloat16
        atol=5e-2,
        err_msg="Outputs from eager and compiled model do not match."
    )

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

Torch version: 2.7.1+rocmsdk20250821
Aiter version: 0.1.5.dev157+g01330f6c6
