diff --git a/csrc/flash_attn/CMakeLists.txt b/csrc/flash_attn/CMakeLists.txt index f6bb280cdc5..7267d923223 100644 --- a/csrc/flash_attn/CMakeLists.txt +++ b/csrc/flash_attn/CMakeLists.txt @@ -13,23 +13,8 @@ include_directories( ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ) -#file(GLOB SOURCES_CU "src/*.cu") -#file(GLOB SOURCES_CPP "src/*.cpp") -set(SOURCES_CU - src/fmha_fwd_hdim32.cu - src/fmha_fwd_hdim64.cu - src/fmha_fwd_hdim128.cu - src/fmha_bwd_hdim32.cu - src/fmha_bwd_hdim64.cu - src/fmha_bwd_hdim128.cu - src/fmha_fwd_with_mask_bias_hdim32.cu - src/fmha_fwd_with_mask_bias_hdim64.cu - src/fmha_fwd_with_mask_bias_hdim128.cu - src/fmha_bwd_with_mask_bias_hdim32.cu - src/fmha_bwd_with_mask_bias_hdim64.cu - src/fmha_bwd_with_mask_bias_hdim128.cu - src/utils.cu) -set(SOURCES_CPP src/cuda_utils.cpp) +file(GLOB SOURCES_CU "src/*.cu") +file(GLOB SOURCES_CPP "src/*.cpp") #add_library(flashattn OBJECT add_library(flashattn SHARED diff --git a/csrc/flash_attn/flash_attn.cpp b/csrc/flash_attn/flash_attn.cpp index 42f2644b41f..ebdd18900a8 100644 --- a/csrc/flash_attn/flash_attn.cpp +++ b/csrc/flash_attn/flash_attn.cpp @@ -497,6 +497,214 @@ bool flash_attn_bwd( FLASHATTNLIB_END_FUNC } +bool flash_attn_fwd_block( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence + const void *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence + const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16) + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + const bool is_bf16, + void *softmax_lse_ptr, // softmax log_sum_exp + void *softmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset +) { + // printf("forward seed %jd offset %jd\b", seed, offset); + FLASHATTNLIB_BEGIN_FUNC + + auto dprops = GetDeviceProperties(-1); + ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0); + bool is_dropout = p_dropout > 0.0; + + const bool return_softmax = (softmax_ptr != nullptr); + Launch_params launch_params(dprops, stream, is_dropout, return_softmax); + + ASSERT_CHECK(batch_size > 0); + ASSERT_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128); + + int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256; + if( max_seqlen_k <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > 256; + + void* o_tmp_ptr = workspace_ptr; + // nullptr out to calculate workspace size + if (out == nullptr) { + if (loop) { + *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float); + } else { + *workspace_size = 0; + } + return true; + } + + if (return_softmax) { + SetZero(softmax_ptr, 2, {batch_size, num_heads, max_seqlen_q, max_seqlen_k}, stream); // float16 + } + + set_params_fprop(launch_params.params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + const_cast(q), + const_cast(k), + const_cast(v), + const_cast(out), + const_cast(cu_seqlens_q), + 
const_cast(cu_seqlens_k), + loop ? o_tmp_ptr : nullptr, + return_softmax ? softmax_ptr : nullptr, + softmax_lse_ptr, + p_dropout, + softmax_scale, + is_causal, + is_bf16, + /*num_splits=*/1); + launch_params.params.blockmask = static_cast(const_cast(blockmask)); + + run_fmha_block_sm80(launch_params, /*configure=*/ true); + // number of times random will be generated per thread, to offset philox counter in thc random + // state + int64_t counter_offset = launch_params.elts_per_thread + offset; + + if( is_dropout ) { + launch_params.params.philox_args = PhiloxCudaState(seed, counter_offset); + } + + run_fmha_block_sm80(launch_params, /*configure=*/false); + + return true; + + FLASHATTNLIB_END_FUNC +} + +bool flash_attn_bwd_block( + const void *q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + const void *k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i + void *dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + void *dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *out, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i + const void *dout, // total_q x num_heads, x head_size + const void *cu_seqlens_q, // int32, batch_size+1 + const void *cu_seqlens_k, // int32, batch_size+1 + const void *blockmask, // int32, (seqlen_k / 256, seqlen_q / 16) + const int total_q, + const int total_k, + const int batch_size, + const int num_heads, + const int head_size, + const int max_seqlen_q_, + const int max_seqlen_k_, + const float p_dropout, + const float softmax_scale, + const bool is_causal, + const bool is_bf16, + void *softmax_lse_ptr, + void *dsoftmax_ptr, + void *workspace_ptr, + uint64_t *workspace_size, + cudaStream_t stream, + uint64_t seed, + uint64_t offset +) { + // printf("backward seed %jd offset %jd\b", seed, offset); + + FLASHATTNLIB_BEGIN_FUNC + + auto dprops = GetDeviceProperties(-1); + bool is_sm80 = dprops->major == 8 && dprops->minor == 0; + bool is_sm8x = dprops->major == 8 && dprops->minor >= 0; + ASSERT_CHECK(dprops->major == 8 && dprops->minor >= 0); + auto launch = &run_fmha_block_dgrad_sm80; + + bool is_dropout = p_dropout > 0.0; + + ASSERT_CHECK(batch_size > 0); + ASSERT_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128); + if (head_size == 128) { // TODO: eventually we should support SM86 and SM70 with d=128 as well + ASSERT_CHECK(is_sm80); + } + + int max_seqlen_k = ((max_seqlen_k_ + 256 - 1) / 256) * 256; + if( max_seqlen_k <= 256 ) { + max_seqlen_k = 256; + } + int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16; + bool loop = max_seqlen_k > 256; + + void *dq_tmp_ptr = workspace_ptr; + // nullptr out to calculate workspace size + if (out == nullptr) { + if (loop) { + *workspace_size = uint64_t(total_q) * num_heads * head_size * sizeof(float); + } else { + *workspace_size = 0; + } + return true; + } + + FMHA_dgrad_params params; + + set_params_dgrad(params, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + head_size, + const_cast(q), + const_cast(k), + const_cast(v), + const_cast(out), + dq, dk, dv, + const_cast(cu_seqlens_q), + const_cast(cu_seqlens_k), + loop ? 
dq_tmp_ptr : nullptr,
+                     const_cast<void*>(dout),
+                     softmax_lse_ptr,
+                     dsoftmax_ptr,
+                     p_dropout,
+                     softmax_scale,
+                     is_causal,
+                     is_bf16,
+                     /*num_splits=*/1);
+    params.blockmask = static_cast<int*>(const_cast<void*>(blockmask));
+
+    // We're gonna reset the rng state in Python after this kernel, so the counter offset
+    // here doesn't matter at all. We just choose an arbitrary number.
+    int64_t counter_offset = 4 + offset;
+
+    if( is_dropout ) {
+        params.philox_args = PhiloxCudaState(seed, counter_offset);
+    }
+
+    launch(params, stream);
+
+    return true;
+
+    FLASHATTNLIB_END_FUNC
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/csrc/flash_attn/flash_attn.h b/csrc/flash_attn/flash_attn.h
index 48dfacd1911..6ad8014d9e1 100644
--- a/csrc/flash_attn/flash_attn.h
+++ b/csrc/flash_attn/flash_attn.h
@@ -138,6 +138,66 @@ bool flash_attn_bwd_with_bias_and_mask(
     const int64_t* bias_dims
 );
 
+bool flash_attn_fwd_block(
+    const void *q,            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const void *k,            // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    const void *v,            // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    void *out,                // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const void *cu_seqlens_q, // int32, batch_size+1, starting offset of each sequence
+    const void *cu_seqlens_k, // int32, batch_size+1, starting offset of each sequence
+    const void *blockmask,    // int32, (seqlen_k / 256, seqlen_q / 16)
+    const int total_q,
+    const int total_k,
+    const int batch_size,
+    const int num_heads,
+    const int head_size,
+    const int max_seqlen_q_,
+    const int max_seqlen_k_,
+    const float p_dropout,
+    const float softmax_scale,
+    const bool is_causal,
+    const bool is_bf16,
+    void *softmax_lse_ptr,    // softmax log_sum_exp
+    void *softmax_ptr,
+    void *workspace_ptr,
+    uint64_t *workspace_size,
+    cudaStream_t stream,
+    uint64_t seed,
+    uint64_t offset
+);
+
+bool flash_attn_bwd_block(
+    const void *q,            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const void *k,            // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    const void *v,            // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    void *dq,                 // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    void *dk,                 // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    void *dv,                 // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
+    const void *out,          // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
+    const void *dout,         // total_q x num_heads x head_size
+    const void *cu_seqlens_q, // int32, batch_size+1
+    const void *cu_seqlens_k, // int32, batch_size+1
+    const void *blockmask,    // int32, (seqlen_k / 256, seqlen_q / 16)
+    const int total_q,
+    const int total_k,
+    const int batch_size,
+    const int num_heads,
+    const int head_size,
+    const int max_seqlen_q_,
+    const int max_seqlen_k_,
+    const float p_dropout,
+    const float softmax_scale,
+    const bool is_causal,
+    const bool is_bf16,
+    void *softmax_lse_ptr,
+    void *dsoftmax_ptr,
+    void *workspace_ptr,
+    uint64_t *workspace_size,
+    cudaStream_t stream,
+    uint64_t seed,
+    uint64_t offset
+);
+
 void flash_attn_set_error(const char *msg);
 
 const char *flash_attn_error();
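[Editorial usage note, not part of the patch.] Both new entry points follow the two-call workspace protocol implemented above: when out is nullptr the call only writes *workspace_size (non-zero only when max_seqlen_k, rounded up to a multiple of 256, exceeds 256) and returns without launching anything; the caller then allocates that many bytes and calls again with real pointers. A minimal host-side sketch for the forward pass follows; buffer names such as q_dev and out_dev are placeholders, and flash_attn_bwd_block is driven the same way.

    uint64_t workspace_size = 0;
    // First call: out == nullptr, so only the workspace size is computed.
    bool ok = flash_attn_fwd_block(
        q_dev, k_dev, v_dev, /*out=*/nullptr,
        cu_seqlens_q_dev, cu_seqlens_k_dev, blockmask_dev,
        total_q, total_k, batch_size, num_heads, head_size,
        max_seqlen_q, max_seqlen_k, p_dropout, softmax_scale,
        /*is_causal=*/false, /*is_bf16=*/true,
        softmax_lse_dev, /*softmax_ptr=*/nullptr,
        /*workspace_ptr=*/nullptr, &workspace_size, stream, seed, offset);

    void *workspace = nullptr;
    if (ok && workspace_size > 0) {
        cudaMalloc(&workspace, workspace_size);  // float o_tmp buffer, total_q x num_heads x head_size
    }

    // Second call: real output pointer, real workspace.
    ok = flash_attn_fwd_block(
        q_dev, k_dev, v_dev, out_dev,
        cu_seqlens_q_dev, cu_seqlens_k_dev, blockmask_dev,
        total_q, total_k, batch_size, num_heads, head_size,
        max_seqlen_q, max_seqlen_k, p_dropout, softmax_scale,
        /*is_causal=*/false, /*is_bf16=*/true,
        softmax_lse_dev, /*softmax_ptr=*/nullptr,
        workspace, &workspace_size, stream, seed, offset);
    if (!ok) {
        // flash_attn_error() can be queried for the failure message.
    }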
diff --git a/csrc/flash_attn/src/fmha.h b/csrc/flash_attn/src/fmha.h
index ae381603284..0d4c5c36d1b 100644
--- a/csrc/flash_attn/src/fmha.h
+++ b/csrc/flash_attn/src/fmha.h
@@ -214,6 +214,6 @@ bool run_fmha_bwd_with_mask_bias_hdim32(FMHA_dgrad_params &params, cudaStream_t
 bool run_fmha_bwd_with_mask_bias_hdim64(FMHA_dgrad_params &params, cudaStream_t stream);
 bool run_fmha_bwd_with_mask_bias_hdim128(FMHA_dgrad_params &params, cudaStream_t stream);
 
-void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
+void run_fmha_block_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
 
-void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
+void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params &params, cudaStream_t stream);
diff --git a/csrc/flash_attn/src/fmha/gemm.h b/csrc/flash_attn/src/fmha/gemm.h
index 2fff2b219f0..72d68b0ac33 100644
--- a/csrc/flash_attn/src/fmha/gemm.h
+++ b/csrc/flash_attn/src/fmha/gemm.h
@@ -135,10 +135,11 @@ struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
     }
 
     // Multiply by another fragment.
+    template<typename elem_type=__half>
     inline __device__ void hmul(const Fragment &other) {
         #pragma unroll
         for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
-            this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
+            this->reg(ii) = fmha::hmul2<elem_type>(this->reg(ii), other.reg(ii));
         }
     }
 
diff --git a/csrc/flash_attn/src/fmha/utils.h b/csrc/flash_attn/src/fmha/utils.h
index 110dda25f08..0494e4c0bc7 100644
--- a/csrc/flash_attn/src/fmha/utils.h
+++ b/csrc/flash_attn/src/fmha/utils.h
@@ -272,7 +272,11 @@ static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
+template<typename elem_type=__half>
+static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b);
+
+template<>
+inline __device__ uint32_t hmul2<__half>(const uint32_t a, const uint32_t b) {
     // uint32_t c;
     // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
     // return c;
@@ -281,6 +285,18 @@ static inline __device__ uint32_t hmul2(const uint32_t a, const uint32_t b) {
     return reinterpret_cast<uint32_t &>(result);
 }
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+template<>
+inline __device__ uint32_t hmul2<__nv_bfloat16>(const uint32_t a, const uint32_t b) {
+    // uint32_t c;
+    // asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+    // return c;
+    __nv_bfloat162 result = __hmul2(reinterpret_cast<const __nv_bfloat162 &>(a),
+                                    reinterpret_cast<const __nv_bfloat162 &>(b));
+    return reinterpret_cast<uint32_t &>(result);
+}
+#endif
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
@@ -292,23 +308,25 @@
+template<typename elem_type=__half>
 static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
     uint4 c;
-    c.x = hmul2(a.x, b.x);
-    c.y = hmul2(a.y, b.y);
-    c.z = hmul2(a.z, b.z);
-    c.w = hmul2(a.w, b.w);
+    c.x = hmul2<elem_type>(a.x, b.x);
+    c.y = hmul2<elem_type>(a.y, b.y);
+    c.z = hmul2<elem_type>(a.z, b.z);
+    c.w = hmul2<elem_type>(a.w, b.w);
     return c;
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
+template<typename elem_type=__half>
 static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
     uint4 c;
-    c.x = hmul2(a, b.x);
-    c.y = hmul2(a, b.y);
-    c.z = hmul2(a, b.z);
-    c.w = hmul2(a, b.w);
+    c.x = hmul2<elem_type>(a, b.x);
+    c.y = hmul2<elem_type>(a, b.y);
+    c.z = hmul2<elem_type>(a, b.z);
+    c.w = hmul2<elem_type>(a, b.w);
     return c;
 }
 
diff --git a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu
deleted file mode 100644
index c6c45177e44..00000000000
---
a/csrc/flash_attn/src/fmha_block_dgrad_fp16_kernel_loop.sm80.cu +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2022, Tri Dao. - */ -#include "fmha.h" -#include "fmha_block_dgrad_kernel_1xN_loop.h" - -template -__global__ void fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { - fmha::compute_block_dq_dk_dv_1xN(params); -} - -template -void run_fmha_block_dgrad_fp16_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { - constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); - constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; - constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; - constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; - constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; - - using Smem_tile_s = fmha::Smem_tile_mma_transposed; - constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; - static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); - static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); - static_assert(smem_size_dp_sum == 16 * 4 * 2); - - constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum; - - bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" - bool is_causal = params.is_causal; - auto kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; - if (params.seqlen_k == blocksize_c) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - } else if (params.seqlen_k == blocksize_c * 2) { - kernel = is_dropout - ? (is_causal ? &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel) - : (is_causal ? 
&fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_fp16_sm80_dq_dk_dv_loop_kernel); - } - - if( smem_size_dq_dk_dv >= 48 * 1024 ) { - FMHA_CHECK_CUDA(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - dim3 grid(params.b, params.h); - kernel<<>>(params); - FMHA_CHECK_CUDA(cudaPeekAtLastError()); -} - -void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { - if (params.d == 16) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } else if (params.d == 32) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } else if (params.d == 64) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u>; - run_fmha_block_dgrad_fp16_sm80_loop_(params, stream); - } -} diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h index 79e0a88c84c..2a7452d0dd9 100644 --- a/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h @@ -13,12 +13,12 @@ namespace fmha { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void dot_do_o(float (&sum)[M], const uint4 (&do_)[M], const uint4 (&o)[M], Smem_dp_sum smem, const int buffer_idx) { #pragma unroll for (int mi = 0; mi < M; ++mi) { - sum[mi] = smem.reduce_warp(fmha::hmulsum8<__half>(do_[mi], o[mi])); + sum[mi] = smem.reduce_warp(fmha::hmulsum8(do_[mi], o[mi])); } static_assert(M == 1); smem.store(sum[0], buffer_idx); @@ -30,6 +30,16 @@ template= 800 + constexpr bool is_fp16_type = std::is_same::value; + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif + + // The description of the CTA tile for the 1st batched GEMM. using Cta_tile_p = typename Kernel_traits::Cta_tile_p; // The description of the CTA tile for the 2nd batched GEMM. @@ -39,7 +49,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, fmha::Cta_tile_extd; static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128); - static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64); + static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128); static_assert(Cta_tile_dkv::K == 16); // The MMA tile for the 1st GEMM. 
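[Editorial illustration, not part of the patch.] The hunks below route an elem_type template parameter into the packed-pair math: dot_do_o now forwards elem_type into fmha::hmulsum8, and the dropout path rescales the prefetched V registers with fmha::hmul8<elem_type>(scale_dropout, ...), which calls the hmul2 specializations added to fmha/utils.h above. As a self-contained reference for that idiom (illustrative names only, compiles with nvcc -arch=sm_80, does not include the fmha headers):

    #include <cstdint>
    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Primary template: one packed pair (two fp16 or two bf16 values) per uint32_t.
    template<typename elem_type>
    static inline __device__ uint32_t hmul2_demo(uint32_t a, uint32_t b);

    template<>
    inline __device__ uint32_t hmul2_demo<__half>(uint32_t a, uint32_t b) {
        __half2 r = __hmul2(reinterpret_cast<const __half2 &>(a),
                            reinterpret_cast<const __half2 &>(b));
        return reinterpret_cast<uint32_t &>(r);
    }

    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    template<>
    inline __device__ uint32_t hmul2_demo<__nv_bfloat16>(uint32_t a, uint32_t b) {
        __nv_bfloat162 r = __hmul2(reinterpret_cast<const __nv_bfloat162 &>(a),
                                   reinterpret_cast<const __nv_bfloat162 &>(b));
        return reinterpret_cast<uint32_t &>(r);
    }
    #endif

    // The element type is chosen statically by the caller, the same way the kernels
    // below pick elem_type from Kernel_traits.
    template<typename elem_type>
    __global__ void scale_demo(uint32_t *data, uint32_t packed_scale, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) data[i] = hmul2_demo<elem_type>(packed_scale, data[i]);
    }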
@@ -103,7 +113,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, using Smem_dp_sum = typename Kernel_traits::Smem_dp_sum; // using Gemm1 = Gemm_Q_K; - using Gemm1 = Gemm_Q_K; + using Gemm1 = Gemm_Q_K; using Softmax = fmha::Softmax; @@ -242,7 +252,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // if (Is_first) { // if (true) { if (Is_first || mask_val % 2 == 1) { - dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0); + dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, 0); const int dp_sum_row = tidx / Smem_dp_sum::THREADS_PER_ROW; if ((dp_sum_row < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) { gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row); @@ -254,7 +264,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, const uint32_t scale_dropout = params.scale_dropout; #pragma unroll for(int it=0; it < Gmem_tile_v::LDGS; it++){ - gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); + gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]); } } @@ -365,7 +375,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack<__half>(frag_p); + softmax.template pack(frag_p); // Store s * dmask to smem for transpose smem_s.store(frag_p); @@ -414,9 +424,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, smem_do.load(frag_do[ki & 1], ki); if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[ki & 1], ki); - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]); } // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) { // float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1])); @@ -430,9 +440,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, { int ki = Mma_tile_p::MMAS_K; if (!Kernel_traits::V_IN_REGS) { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]); } else { - fmha::gemm_cl<__half>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); + fmha::gemm_cl(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]); } } @@ -470,17 +480,17 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, if (is_first_read) { softmax.subtract_dp_sum(dp_sum); } Frag_p frag_dp[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M]; - softmax.template pack<__half>(frag_dp); + softmax.template pack(frag_dp); if (!Is_dropout) { #pragma unroll for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) { #pragma unroll for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) { - frag_p[mi][ni].hmul(frag_dp[mi][ni]); + frag_p[mi][ni].template hmul(frag_dp[mi][ni]); } } - } else { + } else if (is_fp16_type) { __half2 dp_sum_half[Mma_tile_p::MMAS_M * 2]; for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { dp_sum_half[mi] = __float2half2_rn(dp_sum[mi]); @@ -503,6 +513,31 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, } } } + } else { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 
800
+                __nv_bfloat162 dp_sum_half[Mma_tile_p::MMAS_M * 2];
+                for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
+                    dp_sum_half[mi] = __float2bfloat162_rn(dp_sum[mi]);
+                }
+                const __nv_bfloat16 zero_h = __nv_bfloat16(0.f);
+                #pragma unroll
+                for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
+                    #pragma unroll
+                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
+                        #pragma unroll
+                        for (int ii = 0; ii < 4; ++ii) {
+                            const __nv_bfloat162 p = frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii);
+                            const __nv_bfloat162 pdp = __hmul2(p, frag_dp[mi][ni].template elt_as<__nv_bfloat162>(ii));
+                            // If this element is dropped, then frag_p stores -p instead of p.
+                            // So pd holds -p * dp_sum in that case.
+                            const __nv_bfloat162 pd = __hmul2(p, dp_sum_half[mi * 2 + (ii % 2)]);
+                            const __nv_bfloat16 low = __low2bfloat16(p) >= zero_h ? __low2bfloat16(pdp) : __low2bfloat16(pd);
+                            const __nv_bfloat16 high = __high2bfloat16(p) >= zero_h ? __high2bfloat16(pdp) : __high2bfloat16(pd);
+                            frag_p[mi][ni].template elt_as<__nv_bfloat162>(ii) = __halves2bfloat162(low, high);
+                        }
+                    }
+                }
+#endif
         }
 
         // Store dp to smem for transpose
@@ -521,13 +556,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_kt.load(frag_kt[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
             // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
         // Do the final stage of math.
         {
             int ki = Mma_tile_dq::MMAS_K;
-            fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
             // fmha::gemm_cl<__half>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
         }
 
@@ -551,7 +586,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
                 #pragma unroll
                 for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
-                    frag_s[ki][mi].template hrelu_<__half>();
+                    frag_s[ki][mi].template hrelu_<elem_type>();
                 }
             }
         }
@@ -561,13 +596,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params &params,
             // Trigger the load from shared memory for the next series of Q values.
             smem_dot.load(frag_dot[ki & 1], ki);
             // Do the math for the values already in registers.
-            fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
+            fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
         }
         // Do the final stage of math.
{ int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl<__half>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); + fmha::gemm_cl(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]); } // __syncthreads(); @@ -590,7 +625,7 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, if (Is_first || mask_val_next % 2 == 1) { // dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum); // smem_dp_sum.move_to_next_write_buffer(); - dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2); + dot_do_o(dp_sum_regs, gmem_do.fetch_, gmem_o.fetch_, smem_dp_sum, (l + 1) % 2); const int dp_sum_row_1 = tidx / Smem_dp_sum::THREADS_PER_ROW; if ((dp_sum_row_1 < Smem_dp_sum::ROWS) && (tidx % Smem_dp_sum::THREADS_PER_ROW == 0)) { gmem_softmax_d.store_row(reinterpret_cast(dp_sum_regs), dp_sum_row_1); @@ -619,13 +654,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // Trigger the load from shared memory for the next series of Q values. smem_qt.load(frag_qt[ki & 1], ki); // Do the math for the values already in registers. - fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Do the final stage of math. { int ki = Mma_tile_dkv::MMAS_K; - fmha::gemm_cl<__half>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); + fmha::gemm_cl(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]); } // Make sure dQ is in shared memory. @@ -643,9 +678,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // if (Is_dropout) { // dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout); // } - dq_out[0] = fmha::fmul4(dq_out[0], params.scale_bmm1f); + for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) { + dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f); + } // Output the values. - gmem_dq.template store<__half>(dq_out, 0); + gmem_dq.template store(dq_out, 0); } else { // Output the values. gmem_dq_tmp.store(dq_out, 0); @@ -700,11 +737,11 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms, // the total amount of shared mem? // Epilogue swizzle for dV Smem_tile_dv smem_dv(&smem_[0], tidx); - smem_dv.template store<__half>(acc_dv); + smem_dv.template store(acc_dv); // Epilogue swizzle for dK Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx); - smem_dk.template store<__half>(acc_dk); + smem_dk.template store(acc_dk); __syncthreads(); uint4 dv_out[Smem_tile_dv::NUM_LDS]; diff --git a/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu new file mode 100644 index 00000000000..410b944091f --- /dev/null +++ b/csrc/flash_attn/src/fmha_block_dgrad_kernel_loop.sm80.cu @@ -0,0 +1,69 @@ +/* Copyright (c) 2022, Tri Dao. 
+ */ +#include "fmha.h" +#include "static_switch.h" +#include "fmha_block_dgrad_kernel_1xN_loop.h" + +template +__global__ void fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) { + fmha::compute_block_dq_dk_dv_1xN(params); +} + +template +void run_fmha_block_dgrad_sm80_loop_(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { + constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float); + constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE; + constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE; + constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE; + constexpr int smem_size_dp_sum = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE; + + using Smem_tile_s = fmha::Smem_tile_mma_transposed; + constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE; + static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2); + static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N); + static_assert(smem_size_dp_sum == 16 * 4 * 2); + + constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2 + smem_size_dp_sum; + + bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping" + bool is_causal = params.is_causal; + auto kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; + if (params.seqlen_k == blocksize_c) { + kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + } else if (params.seqlen_k == blocksize_c * 2) { + kernel = is_dropout + ? (is_causal ? &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel) + : (is_causal ? 
&fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel : &fmha_block_dgrad_sm80_dq_dk_dv_loop_kernel); + } + + if( smem_size_dq_dk_dv >= 48 * 1024 ) { + FMHA_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + dim3 grid(params.b, params.h); + kernel<<>>(params); + FMHA_CHECK_CUDA(cudaPeekAtLastError()); +} + +void run_fmha_block_dgrad_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream) { + FP16_SWITCH(params.is_bf16, ([&] { + if (params.d == 16) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 8, 0x08u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } else if (params.d == 32) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } else if (params.d == 64) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } else if (params.d == 128) { + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 8, 0x08u, elem_type>; + run_fmha_block_dgrad_sm80_loop_(params, stream); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu b/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu similarity index 64% rename from csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu rename to csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu index d1a90633e4f..0adc617aad5 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_fp16_kernel.sm80.cu +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel.sm80.cu @@ -26,25 +26,26 @@ ******************************************************************************/ #include "fmha.h" +#include "static_switch.h" #include "fmha_block_fprop_kernel_1xN.h" template -__global__ void fmha_block_fprop_fp16_sm80_loop_kernel(FMHA_fprop_params params) { +__global__ void fmha_block_fprop_sm80_loop_kernel(FMHA_fprop_params params) { fmha::device_block_1xN_loop(params); } template -void run_fmha_block_fp16_sm80_loop_(Launch_params &launch_params, +void run_fmha_block_sm80_loop_(Launch_params &launch_params, const bool configure) { bool is_causal = launch_params.params.is_causal; // TD [2022-04-27]: This case work is pretty ugly, maybe there's a better way? auto kernel = launch_params.is_dropout ? (is_causal - ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel)) + ? (launch_params.return_softmax ? &fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel) + : (launch_params.return_softmax ? &fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel)) : (is_causal - ? (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel) - : (launch_params.return_softmax ? &fmha_block_fprop_fp16_sm80_loop_kernel : &fmha_block_fprop_fp16_sm80_loop_kernel)); + ? (launch_params.return_softmax ? &fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel) + : (launch_params.return_softmax ? 
&fmha_block_fprop_sm80_loop_kernel : &fmha_block_fprop_sm80_loop_kernel)); constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N; const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c; @@ -75,16 +76,21 @@ void run_fmha_block_fp16_sm80_loop_(Launch_params &launch_par FMHA_CHECK_CUDA(cudaPeekAtLastError()); } -void run_fmha_block_fp16_sm80(Launch_params &launch_params, +void run_fmha_block_sm80(Launch_params &launch_params, const bool configure) { - if (launch_params.params.d == 16) { - using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } else if (launch_params.params.d == 32) { - using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } else if (launch_params.params.d == 64) { - using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>; - run_fmha_block_fp16_sm80_loop_(launch_params, configure); - } -} \ No newline at end of file + FP16_SWITCH(launch_params.params.is_bf16, ([&] { + if (launch_params.params.d == 16) { + using Kernel_traits = FMHA_kernel_traits<256, 16, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 32) { + using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 64) { + using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } else if (launch_params.params.d == 128) { + using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x08u, elem_type>; + run_fmha_block_sm80_loop_(launch_params, configure); + } + })); +} diff --git a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h index daa8c186e9f..50980db06c9 100644 --- a/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h +++ b/csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h @@ -39,6 +39,13 @@ namespace fmha { template inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using elem_type = typename Kernel_traits::elem_type; +#else + constexpr bool is_fp16_type = std::is_same::value; + assert(is_fp16_type); + using elem_type = __half; +#endif // The description of the CTA tile for the 1st batched GEMM. 
using Cta_tile_p = typename Kernel_traits::Cta_tile_p; @@ -73,7 +80,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum; - using Gemm1 = Gemm_Q_K; + using Gemm1 = Gemm_Q_K; using Softmax = fmha::Softmax; @@ -340,7 +347,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]; static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M); static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N); - softmax.template pack<__half>(frag_p); + softmax.template pack(frag_p); if (Return_softmax) { gmem_s.store(frag_p, mask); if (not_last_iter) { @@ -358,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) { #pragma unroll for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) { - frag_p[ki][mi].template hrelu_<__half>(); + frag_p[ki][mi].template hrelu_(); } } } @@ -370,7 +377,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c // Do this part of O = P^T * V^T. #pragma unroll for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) { - fmha::gemm_cl<__half>(acc_o, frag_p[ki], frag_v[ki]); + fmha::gemm_cl(acc_o, frag_p[ki], frag_v[ki]); } // The mapping from tidx to rows changes between the softmax and the O-reduction. @@ -471,7 +478,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c // Output the values. if (is_final_write) { - gmem_o.template store<__half>(out, 0); + gmem_o.template store(out, 0); } else { gmem_o_tmp.store(out, 0); }
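[Editorial note, not part of the patch.] The renamed launchers (run_fmha_block_sm80, run_fmha_block_dgrad_sm80) select the element type at runtime through FP16_SWITCH(params.is_bf16, ([&] { ... })) from static_switch.h, which this diff does not include. A minimal macro consistent with those call sites could look like the sketch below; the real header may differ. The lambda argument is wrapped in parentheses so the commas inside FMHA_kernel_traits<...> are not taken as macro-argument separators, and elem_type is visible inside the lambda because the lambda body is expanded textually inside the branch that defines the alias.

    #include <cuda_fp16.h>
    #include <cuda_bf16.h>

    // Hypothetical FP16_SWITCH; an assumption, not the actual contents of static_switch.h.
    #define FP16_SWITCH(BF16_COND, F)              \
        do {                                       \
            if (BF16_COND) {                       \
                using elem_type = __nv_bfloat16;   \
                F();                               \
            } else {                               \
                using elem_type = __half;          \
                F();                               \
            }                                      \
        } while (0)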