这个项目提供了一个基准测试脚本,用于比较不同 Attention 实现的性能表现:
- 原始 Attention (Vanilla Attention)
- PyTorch 内置的 Flash Attention
- Flash Attention 库实现 (flash-attn)
- 支持不同的批次大小(batch size)和序列长度(sequence length)
- 测量计算时间和内存使用情况
- 自动生成性能对比图表
- 详细的性能指标输出
pip install torch pip install flash-attn pip install matplotlib pip install numpy
直接运行脚本
脚本会自动:
- 检测可用的 Attention 实现
- 运行基准测试
- 生成性能对比图表 (attention_benchmark_full.png)
- 序列长度: [512, 1024, 2048]
- 批次大小: [8, 16, 32, 64]
- 头维度(head dimension): 64
- 注意力头数(number of heads): 8
- GPU: 4090
- CUDA Version: 12.1
- PyTorch Version: 2.1.0
原始的flash库,比torch集成的flash attention快。 flash attention可以极大的加速attention的计算,减少显存占用。
- 测试前请确保 GPU 资源充足且无其他重要任务运行
- 首次运行时会进行预热,可能需要一些时间
- 如果显存不足,可以调整序列长度和批次大小的测试范围
