Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 74 additions & 33 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ if(NOT SKIP_BUILD_FA)
# Build-time switches; each defaults to OFF.
option(DISABLE_PACKGQA "Disable PackGQA" OFF)
option(DISABLE_BACKWARD "Disable Backward" OFF)
option(DISABLE_SM8X "Disable SM8x" OFF)
# When ON, the forward SM90 instantiations for the head dimensions listed in
# HEAD_DIMENSIONS_DIFF64_FWD are not appended to sources_fwd_sm90.
option(DISABLE_HDIMDIFF64 "Disable HDIMDIFF64" OFF)
# When ON, the forward SM90 instantiations for the head dimensions listed in
# HEAD_DIMENSIONS_DIFF192_FWD are not appended to sources_fwd_sm90.
option(DISABLE_HDIMDIFF192 "Disable HDIMDIFF192" OFF)

if(DISABLE_FP16)
add_compile_definitions(FLASHATTENTION_DISABLE_FP16)
Expand Down Expand Up @@ -240,26 +242,29 @@ if(NOT SKIP_BUILD_FA)
list(APPEND DTYPE_BWD "fp16")
endif()

set(HEAD_DIMENSIONS_BWD)
if(NOT DISABLE_HDIM64)
list(APPEND HEAD_DIMENSIONS_BWD 64)
endif()
if(NOT DISABLE_HDIM96)
list(APPEND HEAD_DIMENSIONS_BWD 96)
endif()
if(NOT DISABLE_HDIM128)
list(APPEND HEAD_DIMENSIONS_BWD 128)
endif()
if(NOT DISABLE_HDIM192)
list(APPEND HEAD_DIMENSIONS_BWD 192)
endif()
if(NOT DISABLE_HDIM256)
list(APPEND HEAD_DIMENSIONS_BWD 256)
# Half-precision dtypes used for the SM90 forward instantiations:
# bf16 is always present, fp16 is included unless DISABLE_FP16 is set.
if(DISABLE_FP16)
  set(HALF_DTYPE_FWD_SM90 "bf16")
else()
  set(HALF_DTYPE_FWD_SM90 "bf16" "fp16")
endif()

set(HEAD_DIMENSIONS_FWD "all" "diff")
# Collect every supported head dimension whose DISABLE_HDIM<dim> variable is
# not set to a true value.
set(SUPPORTED_HEAD_DIMENSIONS 64 96 128 192 256)
set(HEAD_DIMENSIONS_BWD "")
foreach(dim IN LISTS SUPPORTED_HEAD_DIMENSIONS)
  # Skip dimensions that were explicitly switched off.
  if(DISABLE_HDIM${dim})
    continue()
  endif()
  list(APPEND HEAD_DIMENSIONS_BWD ${dim})
endforeach()

# Both forward head-dimension lists mirror the backward list.
set(HEAD_DIMENSIONS_FWD ${HEAD_DIMENSIONS_BWD})
set(HEAD_DIMENSIONS_FWD_SM80 ${HEAD_DIMENSIONS_BWD})

# Extra forward-only head-dimension tags; each tag is spliced verbatim into the
# generated "flash_fwd_hdim<tag>_..." instantiation file names.
# NOTE(review): presumably "<headdim>_<headdim_v>" pairs - confirm.
set(HEAD_DIMENSIONS_DIFF64_FWD "64_512")
set(HEAD_DIMENSIONS_DIFF192_FWD "192_128")

set(SPLIT "__EMPTY__")
if(NOT DISABLE_SPLIT)
list(APPEND SPLIT "_split")
Expand Down Expand Up @@ -321,6 +326,48 @@ if(NOT SKIP_BUILD_FA)
endforeach()
endforeach()

# HDIMDIFF64: append one forward SM90 instantiation file per combination of
# HEAD_DIMENSIONS_DIFF64_FWD x HALF_DTYPE_FWD_SM90 x SPLIT x PAGEDKV x SOFTCAP
# x PACKGQA to sources_fwd_sm90. Nothing is added when DISABLE_HDIMDIFF64 is ON.
if(NOT DISABLE_HDIMDIFF64)
  foreach(hdim ${HEAD_DIMENSIONS_DIFF64_FWD})
    foreach(dtype ${HALF_DTYPE_FWD_SM90})
      foreach(split ${SPLIT})
        foreach(paged ${PAGEDKV})
          foreach(softcap ${SOFTCAP})
            foreach(packgqa ${PACKGQA})
              # A non-empty packgqa suffix is only combined with the empty
              # paged and split suffixes; any other combination is skipped.
              if(NOT packgqa STREQUAL "__EMPTY__")
                if(NOT paged STREQUAL "__EMPTY__" OR NOT split STREQUAL "__EMPTY__")
                  continue()
                endif()
              endif()

              # "__EMPTY__" stands for "no suffix" and is stripped from the path.
              set(name "flash_attn_v3/instantiations/flash_fwd_hdim${hdim}_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu")
              string(REPLACE "__EMPTY__" "" refine_name "${name}")
              list(APPEND sources_fwd_sm90 "${refine_name}")
            endforeach()
          endforeach()
        endforeach()
      endforeach()
    endforeach()
  endforeach()
endif()

# HDIMDIFF192: append one forward SM90 instantiation file per combination of
# HEAD_DIMENSIONS_DIFF192_FWD x DTYPE_FWD_SM90 x SPLIT x PAGEDKV x SOFTCAP
# x PACKGQA to sources_fwd_sm90. Nothing is added when DISABLE_HDIMDIFF192 is ON.
if(NOT DISABLE_HDIMDIFF192)
  foreach(hdim ${HEAD_DIMENSIONS_DIFF192_FWD})
    foreach(dtype ${DTYPE_FWD_SM90})
      foreach(split ${SPLIT})
        foreach(paged ${PAGEDKV})
          foreach(softcap ${SOFTCAP})
            foreach(packgqa ${PACKGQA})
              # A non-empty packgqa suffix is only combined with the empty
              # paged and split suffixes; any other combination is skipped.
              if(NOT packgqa STREQUAL "__EMPTY__")
                if(NOT paged STREQUAL "__EMPTY__" OR NOT split STREQUAL "__EMPTY__")
                  continue()
                endif()
              endif()

              # "__EMPTY__" stands for "no suffix" and is stripped from the path.
              set(name "flash_attn_v3/instantiations/flash_fwd_hdim${hdim}_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu")
              string(REPLACE "__EMPTY__" "" refine_name "${name}")
              list(APPEND sources_fwd_sm90 "${refine_name}")
            endforeach()
          endforeach()
        endforeach()
      endforeach()
    endforeach()
  endforeach()
endif()

set(sources_bwd_sm80)
foreach(hdim ${HEAD_DIMENSIONS_BWD})
foreach(dtype ${DTYPE_BWD})
Expand Down Expand Up @@ -364,7 +411,10 @@ if(NOT SKIP_BUILD_FA)

# Append flash_prepare_scheduler.cu to the FA3 source list.
list(APPEND FA3_SOURCES_CU_SOURCES "flash_attn_v3/flash_prepare_scheduler.cu")

# Configure-time diagnostics: print the complete FA3 source list and the
# forward SM90 instantiation list.
message(STATUS "Auto generated CUDA source files: ${FA3_SOURCES_CU_SOURCES}")
message(STATUS "sources_fwd_sm90: ${sources_fwd_sm90}")

# message(STATUS "Auto generated CUDA source files: ${FA3_SOURCES_CU_SOURCES}")

# Build the flashattnv3 shared library from every source collected in
# FA3_SOURCES_CU_SOURCES.
add_library(flashattnv3 SHARED
${FA3_SOURCES_CU_SOURCES}
)
Expand Down Expand Up @@ -493,22 +543,13 @@ if(NOT SKIP_BUILD_FA)
list(APPEND FLASHMASKV2_DTYPE_BWD "fp16")
endif()

set(FLASHMASKV2_HEAD_DIMENSIONS_BWD)
if(NOT DISABLE_FLASHMASK_V2_HDIM64)
list(APPEND FLASHMASKV2_HEAD_DIMENSIONS_BWD 64)
endif()
if(NOT DISABLE_FLASHMASK_V2_HDIM96)
list(APPEND FLASHMASKV2_HEAD_DIMENSIONS_BWD 96)
endif()
if(NOT DISABLE_FLASHMASK_V2_HDIM128)
list(APPEND FLASHMASKV2_HEAD_DIMENSIONS_BWD 128)
endif()
if(NOT DISABLE_FLASHMASK_V2_HDIM192)
list(APPEND FLASHMASKV2_HEAD_DIMENSIONS_BWD 192)
endif()
if(NOT DISABLE_FLASHMASK_V2_HDIM256)
list(APPEND FLASHMASKV2_HEAD_DIMENSIONS_BWD 256)
endif()
# Collect every supported head dimension whose
# DISABLE_FLASHMASK_V2_HDIM<dim> variable is not set to a true value.
set(SUPPORTED_HEAD_DIMENSIONS 64 96 128 192 256)
set(FLASHMASKV2_HEAD_DIMENSIONS_BWD "")
foreach(dim IN LISTS SUPPORTED_HEAD_DIMENSIONS)
  # Skip dimensions that were explicitly switched off.
  if(DISABLE_FLASHMASK_V2_HDIM${dim})
    continue()
  endif()
  list(APPEND FLASHMASKV2_HEAD_DIMENSIONS_BWD ${dim})
endforeach()

# The "diff" head dimensions are disabled here: headdim != headdim_v is not supported.
set(FLASHMASKV2_HEAD_DIMENSIONS_FWD ${FLASHMASKV2_HEAD_DIMENSIONS_BWD})
Expand Down