# FlashBias: Fast Computation of Attention with Bias [Paper]
Attention with bias is widely used, for example as relative position encoding in Swin Transformer and as scientific prior encoding in AlphaFold (Nature 2024) and Pangu-Weather (Nature 2023). Surprisingly, despite its common use, targeted efficiency optimization for attention with bias remains absent.
This paper presents FlashBias, which builds on low-rank compressed sensing theory to provide fast and exact computation for many widely used attention biases, enabling a 1.5× speedup for the Pairformer in AlphaFold 3 and an over 2× speedup for attention with bias in vision and language models, without loss of accuracy.
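The underlying identity, also used in the SDPA example below: whenever the bias matrix factorizes (exactly or approximately) as B = Q_b K_b^T with a small rank, the bias can be folded into standard attention by concatenating the factors onto the pre-scaled queries and keys, so the O(N^2) bias tensor is never materialized and fused attention kernels apply directly. A minimal numerical sketch of this identity (tensor names and shapes here are illustrative, not the repo's API):

```python
import torch

bsz, heads, seqlen, headdim, rank = 2, 4, 128, 64, 8
q, k, v = (torch.randn(bsz, heads, seqlen, headdim) for _ in range(3))
q_bias, k_bias = torch.randn(bsz, heads, seqlen, rank), torch.randn(bsz, heads, seqlen, rank)
scale = headdim ** -0.5

# Reference: attention with an explicitly materialized low-rank bias B = q_bias @ k_bias^T
bias = q_bias @ k_bias.transpose(-1, -2)             # (bsz, heads, seqlen, seqlen) -- the tensor FlashBias avoids
ref = torch.softmax(q @ k.transpose(-1, -2) * scale + bias, dim=-1) @ v

# FlashBias-style rewrite: fold the bias factors into q and k, then run plain, bias-free attention
q_ext = torch.cat([q * scale, q_bias], dim=-1)       # (bsz, heads, seqlen, headdim + rank)
k_ext = torch.cat([k, k_bias], dim=-1)
out = torch.softmax(q_ext @ k_ext.transpose(-1, -2), dim=-1) @ v

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```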
Figure 1. Overview of FlashBias.
Check ./flash_bias for the Triton kernel and a comparison among FlashBias, FlashAttention, torch.compile, SDPA, and xFormers.
The following are some representative applications of FlashBias.
- Large language model with ALiBi bias: See ./1_Language_Model (an exact low-rank factorization of this bias is sketched right after this list)
- Swin Transformer V2 with relative position bias: See ./2_Vision_Transformer
- Transformer PDE Solver with spatial distance bias: See ./3_Neual_Solver
- AlphaFold 3 with pair representation bias: See ./4_AlphaFold3
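Many of these biases are exactly low-rank rather than merely compressible. For instance, the causal ALiBi bias -m·(i - j) admits an exact rank-2 factorization (entries above the diagonal differ, but those are discarded by the causal mask anyway). A minimal sketch, where the slope m and this particular factor construction are illustrative assumptions rather than the repo's implementation (see ./1_Language_Model for that):

```python
import torch

seqlen, m = 16, 0.5                                   # sequence length and per-head ALiBi slope (illustrative)
idx = torch.arange(seqlen, dtype=torch.float32)

# Causal ALiBi bias: bias[i, j] = -m * (i - j); positions with j > i are removed by the causal mask
alibi = -m * (idx[:, None] - idx[None, :])

# Exact rank-2 factorization: bias[i, j] = <q_bias_i, k_bias_j>
q_bias = torch.stack([-m * idx, torch.full_like(idx, m)], dim=-1)   # (seqlen, 2)
k_bias = torch.stack([torch.ones_like(idx), idx], dim=-1)           # (seqlen, 2)

torch.testing.assert_close(q_bias @ k_bias.T, alibi)
```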
Figure 2. Three types of implementations for FlashBias and corresponding representative applications.
## Triton-based FlashBias

See `flash_bias_triton.py` for more details.
```python
import numpy as np
from flash_bias_triton import flash_bias_func

output_flash = flash_bias_func(q, k, v, q_bias, k_bias, None, False, 1 / np.sqrt(headdim))
```
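For reference, a hypothetical input setup for the call above; the (batch, seqlen, nheads, headdim) layout, fp16 dtype, and CUDA placement are assumptions borrowed from FlashAttention-style kernels, so check `flash_bias_triton.py` for the actual expected interface:

```python
import torch

# Illustrative shapes only -- the layout convention is an assumption, not the repo's documented API.
batch, seqlen, nheads, headdim, rank = 2, 1024, 8, 64, 16
q, k, v = (torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16) for _ in range(3))
q_bias = torch.randn(batch, seqlen, nheads, rank, device='cuda', dtype=torch.float16)
k_bias = torch.randn(batch, seqlen, nheads, rank, device='cuda', dtype=torch.float16)
```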
## PyTorch-SDPA-based FlashBias
```python
import torch

# Note: the last dimension of concat([q, q_bias]) must be divisible by 8;
# otherwise the FlashAttention backend of SDPA cannot be activated.
output_flash_sdpa = torch.nn.functional.scaled_dot_product_attention(
    query=torch.concat([q * softmax_scale, q_bias], dim=-1),  # fold the bias factor into the pre-scaled queries
    key=torch.concat([k, k_bias], dim=-1),                    # fold the bias factor into the keys
    value=v,
    attn_mask=None,
    dropout_p=0.0,
    scale=1,  # scaling has already been applied to q above
    is_causal=causal,
)
```
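If `headdim + rank` is not already a multiple of 8, one simple way to satisfy the constraint noted in the comment above is to zero-pad the bias factors before concatenation; zero columns contribute nothing to the dot products, so the output is mathematically unchanged. A minimal sketch with a hypothetical helper (not part of the repo):

```python
import torch.nn.functional as F

def pad_bias_factors(q_bias, k_bias, headdim, multiple=8):
    # Hypothetical helper: zero-pad the last dim so that headdim + rank is divisible by `multiple`.
    # Zero columns add 0 to every q.k dot product, so attention outputs are unchanged.
    rank = q_bias.shape[-1]
    pad = (-(headdim + rank)) % multiple
    if pad:
        q_bias = F.pad(q_bias, (0, pad))
        k_bias = F.pad(k_bias, (0, pad))
    return q_bias, k_bias
```

Call it as `q_bias, k_bias = pad_bias_factors(q_bias, k_bias, q.shape[-1])` before building the concatenated query and key above.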
This yields a significant reduction in memory usage and running time compared to vanilla FlashAttention.
Figure 3. Efficiency comparison with vanilla FlashAttention.
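For a rough, hardware-dependent reproduction of this comparison with stock PyTorch (all names and shapes below are illustrative, and the actual speedup depends on which SDPA backend gets selected):

```python
import time
import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

def mk(*shape):
    return torch.randn(*shape, device='cuda', dtype=torch.float16)

def bench(fn, iters=20):
    # Crude CUDA timing helper (illustrative only)
    fn(); torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

bsz, heads, seqlen, headdim, rank = 4, 8, 2048, 64, 16
q, k, v = mk(bsz, heads, seqlen, headdim), mk(bsz, heads, seqlen, headdim), mk(bsz, heads, seqlen, headdim)
q_bias, k_bias = mk(bsz, heads, seqlen, rank), mk(bsz, heads, seqlen, rank)
scale = headdim ** -0.5

# Baseline: materialize the O(N^2) bias and pass it via attn_mask (blocks the FlashAttention backend)
t_mask = bench(lambda: sdpa(q, k, v, attn_mask=q_bias @ k_bias.transpose(-1, -2), scale=scale))
# FlashBias-style: fold the factors into q/k so a fused, bias-free kernel can be used
t_fold = bench(lambda: sdpa(torch.cat([q * scale, q_bias], -1), torch.cat([k, k_bias], -1), v, scale=1.0))
print(f'explicit bias: {t_mask * 1e3:.2f} ms | folded bias: {t_fold * 1e3:.2f} ms')
```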
If you find this repo useful, please cite our paper.
```bibtex
@inproceedings{wu2025flashbias,
  title={FlashBias: Fast Computation of Attention with Bias},
  author={Haixu Wu and Minghao Guo and Yuezhou Ma and Yuanxu Sun and Jianmin Wang and Wojciech Matusik and Mingsheng Long},
  booktitle={Advances in Neural Information Processing Systems},
  year={2025}
}
```
If you have any questions or want to use the code, please contact [email protected]
We are grateful to the following GitHub repos for their valuable code bases and datasets:
https://github.com/Dao-AILab/flash-attention
https://github.com/meta-pytorch/attention-gym