
Commit ab10839

fix fill_oob_KV for v_transpose
1 parent: c759027

11 files changed: +1735 −23 lines

csrc/sm90/flash_api.cpp

Lines changed: 32 additions & 5 deletions
@@ -68,16 +68,19 @@ mha_fwd_kvcache_mla(
         const float softmax_scale,
         bool is_causal,
         const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
-        const at::Tensor &num_splits // batch_size + 1
+        const at::Tensor &num_splits, // batch_size + 1
+        c10::optional<const at::Tensor> &descale_q, // batch_size
+        c10::optional<const at::Tensor> &descale_k // batch_size
 ) {
     // Check the architecture
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
     TORCH_CHECK(is_sm90);

     // Check data types
-    auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf);
+    auto q_dtype = q.scalar_type();
+    TORCH_CHECK(q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf ||
+                q_dtype == torch::kFloat8_e4m3fn, "Unsupported dtype for query tensor");
     TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
     TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
@@ -106,7 +109,7 @@ mha_fwd_kvcache_mla(
     const int num_heads_q = sizes[2];
     const int head_size_k = sizes[3];
     TORCH_CHECK(head_size_k == 576, "Only head_size_k == 576 is supported");
-    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 576 is supported");
+    TORCH_CHECK(head_size_v == 512, "Only head_size_v == 512 is supported");

     const int max_num_blocks_per_seq = block_table.size(1);
     const int num_blocks = kcache.size(0);
@@ -115,6 +118,20 @@ mha_fwd_kvcache_mla(
     TORCH_CHECK(batch_size > 0, "batch size must be postive");
     TORCH_CHECK(num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        TORCH_CHECK(descale_q.has_value() && descale_k.has_value(), "descale is required when input dtype is fp8");
+        auto descale_q_value = descale_q.value();
+        auto descale_k_value = descale_k.value();
+        CHECK_DEVICE(descale_q_value);
+        CHECK_DEVICE(descale_k_value);
+        TORCH_CHECK(descale_q_value.stride(-1) == 1);
+        TORCH_CHECK(descale_k_value.stride(-1) == 1);
+        TORCH_CHECK(descale_q_value.dtype() == torch::kFloat);
+        TORCH_CHECK(descale_k_value.dtype() == torch::kFloat);
+        CHECK_SHAPE(descale_q_value, 1);
+        CHECK_SHAPE(descale_k_value, 1);
+    }
+
     if (seqlen_q_ori == 1) { is_causal = false; }

     const int num_q_heads_per_hk = num_heads_q / num_heads_k;
@@ -133,7 +150,8 @@ mha_fwd_kvcache_mla(
     at::cuda::CUDAGuard device_guard{(char)q.get_device()};

     auto opts = q.options();
-    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
+    auto out_type = (q_dtype == torch::kFloat8_e4m3fn) ? torch::kBFloat16 : q_dtype; // The kernel already supports half output, but the Python API would need a change to expose the output dtype
+    at::Tensor out = torch::empty({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype(out_type));
     at::Tensor softmax_lse = torch::empty({batch_size, num_heads, q_seq_per_hk}, opts.dtype(at::kFloat));
     CHECK_CONTIGUOUS(softmax_lse);

@@ -152,6 +170,12 @@ mha_fwd_kvcache_mla(
     params.d_v = head_size_v;
     params.scale_softmax = softmax_scale;
     params.scale_softmax_log2 = float(softmax_scale * M_LOG2E);
+    if (q_dtype == torch::kFloat8_e4m3fn) {
+        // params.descale_q = get_scalar_f32_cpu_only(descale_q); // A CPU scalar would be faster, but would require changing the sglang API
+        // params.descale_k = get_scalar_f32_cpu_only(descale_k); // A CPU scalar would be faster, but would require changing the sglang API
+        params.descale_q_ptr = reinterpret_cast<float*>(descale_q.value().data_ptr());
+        params.descale_k_ptr = reinterpret_cast<float*>(descale_k.value().data_ptr());
+    }
     // Set the pointers and strides.
     params.q_ptr = q.data_ptr();
     params.k_ptr = kcache.data_ptr();
@@ -197,6 +221,9 @@ mha_fwd_kvcache_mla(
         run_flash_splitkv_mla_kernel<cutlass::half_t>(params, stream);
         run_flash_mla_combine_kernel<cutlass::half_t>(params, stream);
 #endif
+    } else if (q_dtype == torch::kFloat8_e4m3fn) { // The output dtype defaults to bfloat16_t; half is also supported.
+        run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>(params, stream);
+        run_flash_mla_combine_kernel<cutlass::bfloat16_t>(params, stream);
     } else {
         TORCH_CHECK(false, "Unsupported tensor dtype for query");
     }
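
The new checks require descale_q and descale_k to be shape-(1,), float32, contiguous CUDA tensors. Below is a minimal caller-side sketch of producing an fp8 tensor plus a matching descale tensor, assuming per-tensor amax scaling; the quantization scheme and the helper name quantize_per_tensor_fp8 are illustrative, not part of this commit.

#include <torch/torch.h>
#include <tuple>

// Quantize x to float8_e4m3fn with a single per-tensor scale; return the fp8
// tensor and a descale factor shaped to satisfy the checks added above:
// shape (1,), dtype float32, same CUDA device as x, innermost stride 1.
std::tuple<at::Tensor, at::Tensor> quantize_per_tensor_fp8(const at::Tensor &x) {
    at::Tensor amax = x.abs().max().to(torch::kFloat);
    at::Tensor scale = (amax / 448.0f).clamp_min(1e-12);   // 448 = max finite magnitude of e4m3fn
    at::Tensor x_fp8 = (x / scale).to(torch::kFloat8_e4m3fn);
    at::Tensor descale = scale.reshape({1}).contiguous();  // passes CHECK_SHAPE(descale_q_value, 1)
    return std::make_tuple(x_fp8, descale);
}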
Lines changed: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
+/**
+ * ref to Fa3's SmemTranspose64x64:
+ * https://github.com/Dao-AILab/flash-attention/blob/0823cf7b5d96499c1c79a4f64b1e256a035ba4b4/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp#L26
+ */
+
+#pragma once
+using namespace cute;
+
+template <int kBlockN, int kHeadDim>
+struct SmemTransposeFp8_64x64 {
+  static_assert((kBlockN % 64 == 0) && (kHeadDim % 64 == 0));
+
+  using Element = cutlass::float_e4m3_t;
+  using TransposeShapeAtomV = Shape<_64, _64>;
+  using SmemLayoutAtomV = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+  using SmemLayoutV =
+      decltype(tile_to_shape(SmemLayoutAtomV{},
+                             Shape<Int<kBlockN>, Int<kHeadDim>>{}));
+
+  // for fp8 in-kernel transpose -- src layout
+  using SmemLayoutDivideV = decltype(tiled_divide(SmemLayoutV{}, TransposeShapeAtomV{}));
+  using SmemShapeLDSM = Shape<Shape<_8, _8>, Shape<_16, _4>>;
+  using FactoringShapeV = decltype(make_shape(SmemShapeLDSM{}, shape<1>(SmemLayoutDivideV{}), shape<2>(SmemLayoutDivideV{})));
+  using SmemLayoutTransposeV = decltype(composition(SmemLayoutDivideV{}, make_layout(FactoringShapeV{})));
+
+  // For fp8, this is the memory transpose.
+  using SmemLayoutAtomVt = decltype(tile_to_shape(GMMA::Layout_K_SW64_Atom<Element>{}, TransposeShapeAtomV{}));
+  using SmemLayoutVt =
+      decltype(tile_to_shape(SmemLayoutAtomVt{},
+                             Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+
+  // for fp8 in-kernel transpose -- dst layout
+  using SmemLayoutVtTrans = decltype(composition(
+      SmemLayoutVt{}, make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1>{})));
+  using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
+  using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
+  using FactoringShapeVt = decltype(make_shape(SmemShapeSTSM{}, shape<1>(SmemLayoutDivideVt{}), shape<2>(SmemLayoutDivideVt{})));
+  using SmemLayoutTransposeVt = decltype(composition(SmemLayoutDivideVt{}, make_layout(FactoringShapeVt{})));
+
+
+  using ldsm_thread_shape = Shape<_4, _1, _8, _4>;
+  using ldsm_value_shape = Shape<_2, _8, _2, _1>;
+  using ldsm_value_stride = Stride<_2, _4, _1, _0>;
+  using TiledCopyLDSM = decltype(make_tiled_copy(Copy_Atom<SM75_U16x8_LDSM_T, Element>{}, Layout<ldsm_thread_shape>{},
+                                                 Layout<ldsm_value_shape, ldsm_value_stride>{}));
+  TiledCopyLDSM tiled_copy_ldsm;
+
+  using stsm_thread_shape = Shape<_4, _1, _8, _4>;
+  // using stsm_thread_stride = Stride<_1, _0, _4, _32>;
+  using stsm_value_shape = Shape<_4, _4, _2, _1>;
+  using stsm_value_stride = Stride<_1, _8, _4, _0>;
+
+  using TiledCopySTSM = decltype(make_tiled_copy(Copy_Atom<SM90_U32x4_STSM_N, Element>{}, Layout<stsm_thread_shape>{},
+                                                 Layout<stsm_value_shape, stsm_value_stride>{}));
+  TiledCopySTSM tiled_copy_stsm;
+
+  template <class SmemTensor, class SmemTensorOut>
+  CUTLASS_DEVICE void transpose(SmemTensor &&s_in, SmemTensorOut &&s_out) {
+    using namespace cute;
+
+    auto tid = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
+    auto thr_copy_ldsm = tiled_copy_ldsm.get_thread_slice(tid);
+    auto thr_copy_stsm = tiled_copy_stsm.get_thread_slice(tid);
+
+    auto tXsX = thr_copy_ldsm.partition_S(s_in);
+    auto tXrX = make_tensor<Element>(shape(tXsX));
+    auto tXsX_out = thr_copy_stsm.partition_D(s_out);
+
+    cute::copy(tiled_copy_ldsm, tXsX, tXrX);
+
+    auto data = tXrX.data();
+    CUTLASS_PRAGMA_UNROLL
+    for (int n = 0; n < size(tXrX); n += 8) {
+      uint32_t *data_32bit = reinterpret_cast<uint32_t *>(&data[n]);
+      auto upper = data_32bit[0];
+      auto lower = data_32bit[1];
+      data_32bit[0] = __byte_perm(upper, lower, 0x6420);
+      data_32bit[1] = __byte_perm(upper, lower, 0x7531);
+    }
+
+    cute::copy(tiled_copy_stsm, tXrX, tXsX_out);
+  }
+};
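
A usage sketch for the struct above, modeled on the FA3 loop it references: each 64x64 tile of V in shared memory is loaded with LDSM, byte-permuted in registers, and stored back transposed with STSM. The function name, pointer arguments, and tile parameters below are illustrative, not taken from the commit.

// Hypothetical helper: drive the in-smem transpose over all 64x64 tiles of V.
// The kBlockN / kHeadDim values are placeholders; the real kernel picks its own.
using Transposer = SmemTransposeFp8_64x64</*kBlockN=*/64, /*kHeadDim=*/512>;

CUTLASS_DEVICE void transpose_V_tiles(cutlass::float_e4m3_t *smem_v, cutlass::float_e4m3_t *smem_vt) {
  Transposer transposer;
  // View the raw shared memory through the source / destination tile layouts.
  auto sV  = make_tensor(make_smem_ptr(smem_v),  typename Transposer::SmemLayoutTransposeV{});
  auto sVt = make_tensor(make_smem_ptr(smem_vt), typename Transposer::SmemLayoutTransposeVt{});
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < size<2>(sV); ++j) {
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < size<1>(sV); ++i) {
      // One 64x64 tile per (i, j); the whole warp group cooperates in each call.
      transposer.transpose(flatten(sV(_, i, j)), flatten(sVt(_, i, j)));
    }
  }
}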

csrc/sm90/kernels/params.h

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,9 @@ struct Flash_fwd_mla_params {
     int q_head_per_hk;  // The number of q_head(s) per KV head, = h_q / h_k
     bool is_causal;
     float scale_softmax, scale_softmax_log2;
+    // float descale_q, descale_k; // A CPU scalar would be faster, but would require changing the sglang API
+    float* __restrict__ descale_q_ptr = nullptr;
+    float* __restrict__ descale_k_ptr = nullptr;

     void *__restrict__ q_ptr;
     void *__restrict__ k_ptr;
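
One plausible way for a kernel to consume these pointers, assuming per-tensor quantization of Q and K: fold both descale factors into the softmax scale before exponentiation. This is a hedged sketch of the idea, not necessarily how splitkv_mla.cu uses them.

// Hypothetical device helper: rescale fp8 QK^T products back to real scale.
__device__ inline float effective_softmax_scale(const Flash_fwd_mla_params &params) {
    float descale_q = params.descale_q_ptr ? *params.descale_q_ptr : 1.0f;
    float descale_k = params.descale_k_ptr ? *params.descale_k_ptr : 1.0f;
    // S = (descale_q * descale_k) * (Q_fp8 . K_fp8^T), so both factors fold into scale_softmax.
    return params.scale_softmax * descale_q * descale_k;
}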

csrc/sm90/kernels/splitkv_mla.cu

Lines changed: 3 additions & 3 deletions
@@ -1270,7 +1270,7 @@ flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params
 }


-template<typename InputT>
+template<typename InputT, typename OutputT = InputT>
 void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream) {
     using T = Traits<InputT>;
     auto shape_Q = make_shape(params.q_seq_per_hk, params.d, params.h_k, params.b);
@@ -1347,8 +1347,8 @@ void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t str
     CHECK_CUDA_KERNEL_LAUNCH();
 }

-template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_flash_splitkv_mla_kernel<cutlass::bfloat16_t, cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);

 #ifndef FLASH_MLA_DISABLE_FP16
-template void run_flash_splitkv_mla_kernel<cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
+template void run_flash_splitkv_mla_kernel<cutlass::half_t, cutlass::half_t>(Flash_fwd_mla_params &params, cudaStream_t stream);
 #endif
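
The fp8 call site added in flash_api.cpp (run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>) needs a matching explicit instantiation in this translation unit; it is not visible in the hunks shown here, so presumably another part of the commit adds something along these lines:

template void run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>(Flash_fwd_mla_params &params, cudaStream_t stream);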

csrc/sm90/kernels/splitkv_mla.h

Lines changed: 1 addition & 1 deletion
@@ -2,5 +2,5 @@

 #include "params.h"

-template<typename InputT>
+template<typename InputT, typename OutputT = InputT>
 void run_flash_splitkv_mla_kernel(Flash_fwd_mla_params &params, cudaStream_t stream);
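
Because OutputT defaults to InputT, the pre-existing bf16 and fp16 call sites compile unchanged; only the fp8 path names the output type explicitly. Both forms already appear in flash_api.cpp above:

run_flash_splitkv_mla_kernel<cutlass::bfloat16_t>(params, stream);                         // OutputT = bfloat16_t
run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t, cutlass::bfloat16_t>(params, stream);  // fp8 in, bf16 out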
