@@ -68,16 +68,19 @@ mha_fwd_kvcache_mla(
6868 const float softmax_scale,
6969 bool is_causal,
7070 const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize
71- const at::Tensor &num_splits // batch_size + 1
71+ const at::Tensor &num_splits, // batch_size + 1
72+ c10::optional<const at::Tensor> &descale_q, // batch_size
73+ c10::optional<const at::Tensor> &descale_k // batch_size
7274) {
7375 // Check the architecture
7476 auto dprops = at::cuda::getCurrentDeviceProperties ();
7577 bool is_sm90 = dprops->major == 9 && dprops->minor == 0 ;
7678 TORCH_CHECK (is_sm90);
7779
7880 // Check data types
79- auto q_dtype = q.dtype ();
80- TORCH_CHECK (q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf );
81+ auto q_dtype = q.scalar_type ();
82+ TORCH_CHECK (q_dtype == torch::kBFloat16 || q_dtype == torch::kHalf ||
83+ q_dtype == torch::kFloat8_e4m3fn , " Unsupported dtype for query tensor" );
8184 TORCH_CHECK (kcache.dtype () == q_dtype, " query and key must have the same dtype" );
8285 TORCH_CHECK (seqlens_k.dtype () == torch::kInt32 , " seqlens_k must have dtype int32" );
8386 TORCH_CHECK (block_table.dtype () == torch::kInt32 , " block_table must have dtype torch.int32" );
@@ -106,7 +109,7 @@ mha_fwd_kvcache_mla(
106109 const int num_heads_q = sizes[2 ];
107110 const int head_size_k = sizes[3 ];
108111 TORCH_CHECK (head_size_k == 576 , " Only head_size_k == 576 is supported" );
109- TORCH_CHECK (head_size_v == 512 , " Only head_size_v == 576 is supported" );
112+ TORCH_CHECK (head_size_v == 512 , " Only head_size_v == 512 is supported" );
110113
111114 const int max_num_blocks_per_seq = block_table.size (1 );
112115 const int num_blocks = kcache.size (0 );
@@ -115,6 +118,20 @@ mha_fwd_kvcache_mla(
115118 TORCH_CHECK (batch_size > 0 , " batch size must be positive" );
116119 TORCH_CHECK (num_heads_q % num_heads_k == 0 , " Number of heads in key/value must divide number of heads in query" );
117120
121+ if (q_dtype == torch::kFloat8_e4m3fn ) {
122+ TORCH_CHECK (descale_q.has_value () && descale_k.has_value (), " descale is required when input dtype is fp8" );
123+ auto descale_q_value = descale_q.value ();
124+ auto descale_k_value = descale_k.value ();
125+ CHECK_DEVICE (descale_q_value);
126+ CHECK_DEVICE (descale_k_value);
127+ TORCH_CHECK (descale_q_value.stride (-1 ) == 1 );
128+ TORCH_CHECK (descale_k_value.stride (-1 ) == 1 );
129+ TORCH_CHECK (descale_q_value.dtype () == torch::kFloat );
130+ TORCH_CHECK (descale_k_value.dtype () == torch::kFloat );
131+ CHECK_SHAPE (descale_q_value, 1 );
132+ CHECK_SHAPE (descale_k_value, 1 );
133+ }
134+
118135 if (seqlen_q_ori == 1 ) { is_causal = false ; }
119136
120137 const int num_q_heads_per_hk = num_heads_q / num_heads_k;
@@ -133,7 +150,8 @@ mha_fwd_kvcache_mla(
133150 at::cuda::CUDAGuard device_guard{(char )q.get_device ()};
134151
135152 auto opts = q.options ();
136- at::Tensor out = torch::empty ({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts);
153+ auto out_type = (q_dtype == torch::kFloat8_e4m3fn ) ? torch::kBFloat16 : q_dtype; // The kernel already supports half output, but the Python API would need changes to select the output dtype
154+ at::Tensor out = torch::empty ({batch_size, q_seq_per_hk, num_heads, head_size_v}, opts.dtype (out_type));
137155 at::Tensor softmax_lse = torch::empty ({batch_size, num_heads, q_seq_per_hk}, opts.dtype (at::kFloat ));
138156 CHECK_CONTIGUOUS (softmax_lse);
139157
@@ -152,6 +170,12 @@ mha_fwd_kvcache_mla(
152170 params.d_v = head_size_v;
153171 params.scale_softmax = softmax_scale;
154172 params.scale_softmax_log2 = float (softmax_scale * M_LOG2E);
173+ if (q_dtype == torch::kFloat8_e4m3fn ) {
174+ // params.descale_q = get_scalar_f32_cpu_only(descale_q); // A CPU scalar would be faster, but requires changing the sglang API
175+ // params.descale_k = get_scalar_f32_cpu_only(descale_k); // A CPU scalar would be faster, but requires changing the sglang API
176+ params.descale_q_ptr = reinterpret_cast <float *>(descale_q.value ().data_ptr ());
177+ params.descale_k_ptr = reinterpret_cast <float *>(descale_k.value ().data_ptr ());
178+ }
155179 // Set the pointers and strides.
156180 params.q_ptr = q.data_ptr ();
157181 params.k_ptr = kcache.data_ptr ();
@@ -197,6 +221,9 @@ mha_fwd_kvcache_mla(
197221 run_flash_splitkv_mla_kernel<cutlass::half_t >(params, stream);
198222 run_flash_mla_combine_kernel<cutlass::half_t >(params, stream);
199223#endif
224+ } else if (q_dtype == torch::kFloat8_e4m3fn ) { // Default output dtype is bfloat16_t; the kernel can also support half.
225+ run_flash_splitkv_mla_kernel<cutlass::float_e4m3_t , cutlass::bfloat16_t >(params, stream);
226+ run_flash_mla_combine_kernel<cutlass::bfloat16_t >(params, stream);
200227 } else {
201228 TORCH_CHECK (false , " Unsupported tensor dtype for query" );
202229 }
0 commit comments