Skip to content

RoversCode/flash_attention_benchmark

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 

Repository files navigation

Attention Mechanism Benchmark

这个项目提供了一个基准测试脚本,用于比较不同 Attention 实现的性能表现:

  • 原始 Attention (Vanilla Attention)
  • PyTorch 内置的 Flash Attention
  • Flash Attention 库实现 (flash-attn)

功能特点

  • 支持不同的批次大小(batch size)和序列长度(sequence length)
  • 测量计算时间和内存使用情况
  • 自动生成性能对比图表
  • 详细的性能指标输出

环境要求

pip install torch pip install flash-attn pip install matplotlib pip install numpy

使用方法

直接运行脚本

脚本会自动:

  1. 检测可用的 Attention 实现
  2. 运行基准测试
  3. 生成性能对比图表 (attention_benchmark_full.png)

测试参数

  • 序列长度: [512, 1024, 2048]
  • 批次大小: [8, 16, 32, 64]
  • 头维度(head dimension): 64
  • 注意力头数(number of heads): 8

实验结果

测试环境

  • GPU: 4090
  • CUDA Version: 12.1
  • PyTorch Version: 2.1.0

性能对比图

Attention Performance Comparison

主要发现

原始的flash库,比torch集成的flash attention快。 flash attention可以极大的加速attention的计算,减少显存占用。

注意事项

  1. 测试前请确保 GPU 资源充足且无其他重要任务运行
  2. 首次运行时会进行预热,可能需要一些时间
  3. 如果显存不足,可以调整序列长度和批次大小的测试范围

About

测试flash attention

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages