Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add WarpSpec version for Flash Attention (#45)
Summary: Added variant triton_tutorial_flash_v2_ws, updated triton_tutorial_flash_v2_tma to contain warpspec Needs compiler support to run triton_tutorial_flash_v2_ws and triton_tutorial_flash_v2_tma Use hasattr(tl, 'async_task') so the benchmark runs on compilers without warpspec support CUDA_VISIBLE_DEVICES=5 TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_ws,triton_tutorial_flash_v2_tma,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics accuracy --batch 8 --n-heads 16 --d-head 128 --baseline triton_tutorial_flash_v2 Pull Request resolved: #45 Reviewed By: xuzhao9 Differential Revision: D65692351 Pulled By: manman-ren fbshipit-source-id: 4ab8e0b3424a1124524e1fba5894522330b3c4ce
- Loading branch information