Add WarpSpec version for Flash Attention (#45)
Summary:
Added the variant triton_tutorial_flash_v2_ws and updated triton_tutorial_flash_v2_tma to use warp specialization (warpspec).
Running triton_tutorial_flash_v2_ws and triton_tutorial_flash_v2_tma requires compiler support for warpspec.
The benchmark checks hasattr(tl, 'async_task') so it still runs on compilers without warpspec support.
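
A minimal sketch of how that guard can gate the warpspec-dependent variants (the helper name and variant list here are illustrative, not the commit's actual code):

import triton.language as tl

# Warp specialization is only available on compilers that expose tl.async_task.
HAS_WARP_SPEC = hasattr(tl, "async_task")

def available_flash_variants():
    # Hypothetical helper: always report the baseline kernel, and only add the
    # warpspec-dependent variants when the compiler supports them.
    variants = ["triton_tutorial_flash_v2"]
    if HAS_WARP_SPEC:
        variants += ["triton_tutorial_flash_v2_ws", "triton_tutorial_flash_v2_tma"]
    return variants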

CUDA_VISIBLE_DEVICES=5 TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_ws,triton_tutorial_flash_v2_tma,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics accuracy --batch 8 --n-heads 16 --d-head 128 --baseline triton_tutorial_flash_v2

Pull Request resolved: #45

Reviewed By: xuzhao9

Differential Revision: D65692351

Pulled By: manman-ren

fbshipit-source-id: 4ab8e0b3424a1124524e1fba5894522330b3c4ce
manman-ren authored and facebook-github-bot committed Nov 9, 2024
1 parent 7e19437 commit 7b4a0eb
Showing 2 changed files with 514 additions and 88 deletions.