diff --git a/csrc/flash_attn_v3/epilogue_bwd.hpp b/csrc/flash_attn_v3/epilogue_bwd.hpp index 6d9b5f4f596..1ab95b8ca31 100644 --- a/csrc/flash_attn_v3/epilogue_bwd.hpp +++ b/csrc/flash_attn_v3/epilogue_bwd.hpp @@ -358,7 +358,7 @@ struct CollectiveEpilogueBwdGQA { }; using TensorStorage = std::conditional_t; - using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch) + using ShapedKV = cute::Shape<int64_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch) using StridedKV = cute::Stride<_1, int64_t, int64_t>; // Host side kernel arguments @@ -429,9 +429,10 @@ struct CollectiveEpilogueBwdGQA { flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) - Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdKaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdVaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); diff --git a/csrc/flash_attn_v3/flash_bwd_launch_template.h b/csrc/flash_attn_v3/flash_bwd_launch_template.h index 830b9f4d0f2..2256e89a98a 100644 --- a/csrc/flash_attn_v3/flash_bwd_launch_template.h +++ b/csrc/flash_attn_v3/flash_bwd_launch_template.h @@ -61,8 +61,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { static_cast(params.softmax_lse_log2_ptr), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum - {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum + {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum + {_1{}, int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, !is_varlen_q ? int64_t{params.d_rounded} * int64_t{seqlen_q_rounded} * int64_t{params.h} : 0}, // stride_dQaccum params.b, params.dq_semaphore, params.cu_seqlens_q, @@ -114,7 +114,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { {seqlen_q, params.dv, params.h, batch_q}, // shape_dO {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ?
params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum static_cast(params.softmax_lse_log2_ptr), {seqlen_q_rounded, params.h, batch_q}, // shape_LSE @@ -136,14 +136,14 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { if constexpr (!GQA) { return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK } else { - return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum + return typename CollectiveEpilogue::ShapedKV {int64_t{seqlen_k_rounded} * int64_t{params.d_rounded}, params.h_k, batch_k}; // shape_dKaccum } }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum + return typename CollectiveEpilogue::StridedKV {_1{}, int64_t{params.d_rounded} * int64_t{seqlen_k_rounded}, !is_varlen_k ? int64_t{params.h_k} * int64_t{params.d_rounded} * int64_t{params.seqlen_k_rounded} : 0}; // stride_dKaccum } }(), static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), @@ -151,14 +151,14 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { if constexpr (!GQA) { return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV } else { - return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum + return typename CollectiveEpilogue::ShapedKV {int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, params.h_k, batch_k}; // shape_dVaccum } }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + return typename CollectiveEpilogue::StridedKV {_1{}, int64_t{params.dv_rounded} * int64_t{seqlen_k_rounded}, !is_varlen_k ? int64_t{params.h_k} * int64_t{params.dv_rounded} * int64_t{params.seqlen_k_rounded} : 0}; // stride_dVaccum } }(), params.h, @@ -225,7 +225,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { >; typename PostprocessKernel::Arguments postprocess_args { static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum static_cast(params.dq_ptr), {seqlen_q, params.d, params.h, batch_q}, // shape_dQ @@ -254,7 +254,7 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { typename PostprocessKerneldKV::Arguments postprocess_dK_args { static_cast(params.dk_accum_ptr), {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum + {_1{}, int64_t{seqlen_k_rounded} * int64_t{params.d_rounded}, !is_varlen_k ?
int64_t{params.d_rounded} * int64_t{params.seqlen_k_rounded} * int64_t{params.h_k} : 0}, // stride_dKaccum static_cast(params.dk_ptr), {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK @@ -265,8 +265,8 @@ void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) { typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); typename PostprocessKerneldKV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + {int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, params.h_k, batch_k}, // shape_dVaccum + {_1{}, int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, !is_varlen_k ? int64_t{params.dv_rounded} * int64_t{params.seqlen_k_rounded} * int64_t{params.h_k} : 0}, // stride_dVaccum static_cast(params.dv_ptr), {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV diff --git a/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h b/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h index c91e261507d..9fa19f04ed1 100644 --- a/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h +++ b/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h @@ -104,7 +104,7 @@ class FlashAttnBwdPostprocessConvertdQ { using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch) using StridedQ = cute::Stride; - using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch) + using ShapedQaccum = cute::Shape<int64_t, int32_t, int32_t>; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; // Device side arguments @@ -174,7 +174,7 @@ class FlashAttnBwdPostprocessConvertdQ { // Step 1: load dQaccum from gmem to smem Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ?
bidb : 0); - Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) + Tensor gdQaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) if constexpr (IsSm90) { // Use BulkCopy static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v / 8); auto bulk_copy = Copy_Traits{}; diff --git a/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h b/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h index 85e877f9d4f..668357e7719 100644 --- a/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h @@ -63,7 +63,7 @@ class FlashAttnBwdPreprocess { using StrideO = cute::Stride; using ShapedPsum = cute::Shape; // (seqlen_q, head, batch) using StridedPsum = cute::Stride<_1, int64_t, int64_t>; - using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch) + using ShapedQaccum = cute::Shape<int64_t, int32_t, int32_t>; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; // Device side arguments @@ -230,7 +230,7 @@ class FlashAttnBwdPreprocess { if constexpr (Clear_dQaccum) { Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); + Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(m_block)); GmemTiledCopyAccum gmem_tiled_copy_dQaccum; auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx); Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); diff --git a/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp b/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp index c8fc5671b03..5c6aea98abc 100644 --- a/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -215,7 +215,7 @@ struct CollectiveMainloopBwdSm90 { using StrideQKV = cute::Stride; using ShapeLSE = cute::Shape; // (seqlen, head, batch) using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) - using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch) + using ShapedQaccum = cute::Shape<int64_t, int32_t, int32_t>; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; using TMA_QdO = decltype(make_tma_copy_A_sm90( @@ -613,7 +613,7 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ?
bidb : 0); - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_q_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) int const num_batch = params.num_batch; @@ -790,7 +790,7 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_q_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) // We can reuse r2s_thr_copy_dQaccum for this partitioning Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum);
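Taken together, the hunks above make one change throughout the backward pass: every product of a rounded sequence length with the rounded head dimension (and, for batch strides, the head count) is now evaluated in 64-bit arithmetic via int64_t{...} operands instead of plain int multiplication, and the dKaccum/dVaccum/dQaccum shape types widen their first extent to int64_t accordingly. The standalone sketch below is not part of the patch; it only illustrates, with hypothetical sizes, the wrap-around that plain int arithmetic produces once such a product exceeds INT32_MAX, and how widening an operand avoids it.

// Minimal sketch of the overflow this patch guards against (hypothetical sizes).
#include <cstdint>
#include <cstdio>

int main() {
    int seqlen_q_rounded = 1 << 20;  // hypothetical: ~1M tokens after rounding
    int d_rounded        = 128;      // hypothetical rounded head dimension
    int h                = 32;       // hypothetical number of query heads

    // Pre-patch pattern: all operands are int, so the product is evaluated in
    // 32-bit arithmetic and wraps before it is ever stored in a wider type.
    // (Emulated with uint32_t here to keep the demonstration well-defined;
    // with signed int the overflow is undefined behavior.)
    uint32_t wrapped = uint32_t(seqlen_q_rounded) * uint32_t(d_rounded) * uint32_t(h);

    // Patched pattern: int64_t{...} on the operands forces the whole product
    // into 64-bit arithmetic, as in the dQaccum batch stride above.
    int64_t correct = int64_t{seqlen_q_rounded} * int64_t{d_rounded} * int64_t{h};

    std::printf("32-bit product wraps to: %u\n", static_cast<unsigned>(wrapped));
    std::printf("64-bit product:          %lld\n", static_cast<long long>(correct));
    return 0;
}

With these sizes the 32-bit product wraps to 0 while the 64-bit product is 4294967296, which is why a dQaccum batch stride computed in int could silently collapse distinct batches onto the same memory for long sequences.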