Skip to content

Commit 86dd59c

Browse files
Jeff-Huang, poyenc, and DDEle
authored
[CK_TILE] Add sequence padding and variable length support in fmha (a… (#2851)
* [CK_TILE] Add sequence padding and variable length support in fmha (and v3) - Group Mode Padding: Introduces the `-s_qpad` argument to support physically padded layouts. Kernels now use padded start pointers (`seqstart_padded_*_ptr`) for memory addressing. - Batch Mode Variable Length: Adds `-q_eff_lens` and `-kv_eff_lens` arguments for efficient processing of variable-length sequences by passing cumulative effective lengths (`cu_seqlen_*_ptr`) to the kernel. - FMHA examples: Support padding and variable length both in group and batch mode. Dispatcher is updated as well (dispatch to kPadSeqLenK enabled pipeline). - New padding test cases: Add padding test cases to `smoke_test_fwd.sh`, and add benchmarks to `benchmark_fwd.sh` and `benchmark_fwd_v3.sh` as well. These test cases and benchmarks that specifically validate/benchmark the new padding and variable-length functionalities in both group and batch modes. * [CK_TILE] Fix build error in fmha unit tests --------- Co-authored-by: Po Yen Chen <[email protected]> Co-authored-by: Yi DING <[email protected]>
1 parent 2aec38f commit 86dd59c

File tree

13 files changed

+1034
-62
lines changed

13 files changed

+1034
-62
lines changed

example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,11 @@ def seqtune(self) -> str:
259259
def skcheck(self) -> str:
260260
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
261261
if self.pipeline_tag == 'qr_async':
262-
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
263-
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
262+
if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)'
263+
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
264264
elif self.pipeline_tag in ['qr', 'qs']:
265265
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
266-
else : return f'a.seqlen_k % {self.bn0} == 0'
266+
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
267267
elif self.pipeline_tag == 'qr_async_trload':
268268
if self.skpad == 't' : return 'true'
269269
else: return 'true'

example/ck_tile/01_fmha/example_fmha_fwd.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ auto create_args(int argc, char* argv[])
3333
"0",
3434
"seqlen_k for new key/value, 0 means not to use this at all; "
3535
"-1 to choose s_knew in [1, s] randomly.")
36+
.insert("s_qpad",
37+
"-1",
38+
"seqlen_q stride between 2 batches (group-mode optional).\n"
39+
"Provide positive strides per-batch to simulate physical padding on Q.")
3640
.insert("s_kpad",
3741
"-1",
3842
"seqlen_k stride between 2 batches, currently used in group-mode only\n"
@@ -107,7 +111,15 @@ auto create_args(int argc, char* argv[])
107111
.insert("warmup", "5", "number of iterations before benchmark the kernel")
108112
.insert("repeat", "20", "number of iterations to benchmark the kernel")
109113
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
110-
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results");
114+
.insert("jsonfile", "fmha_fwd.json", "json file name to dump results")
115+
.insert("q_eff_lens",
116+
"",
117+
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
118+
"Comma-separated list of length 'b'. If empty, no override.")
119+
.insert("kv_eff_lens",
120+
"",
121+
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
122+
"Comma-separated list of length 'b'. If empty, no override.");
111123

112124
bool result = arg_parser.parse(argc, argv);
113125
return std::make_tuple(result, arg_parser);
@@ -127,6 +139,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
127139
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
128140
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
129141
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
142+
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
143+
auto q_eff_lens_per_batch = arg_parser.get_int_vec("q_eff_lens");
144+
auto kv_eff_lens_per_batch = arg_parser.get_int_vec("kv_eff_lens");
130145
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
131146
bool i_perm = arg_parser.get_bool("iperm");
132147
bool o_perm = arg_parser.get_bool("operm");
@@ -174,7 +189,10 @@ auto run(const ck_tile::ArgParser& arg_parser)
174189
hdim_q,
175190
hdim_v,
176191
seqlen_knew,
192+
seqlen_qpads,
177193
seqlen_kpads,
194+
q_eff_lens_per_batch,
195+
kv_eff_lens_per_batch,
178196
rotary_dim,
179197
i_perm,
180198
o_perm,

example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp

Lines changed: 137 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
5252
"random seed used for initializing input tensors. 0 for "
5353
"non-deterministic seed")
5454
.insert("warmup", "5", "number of iterations before benchmark the kernel")
55-
.insert("repeat", "30", "number of iterations to benchmark the kernel");
55+
.insert("repeat", "30", "number of iterations to benchmark the kernel")
56+
// Optional effective seqlen override (exclude PAD) for batch mode
57+
.insert("q_eff_lens",
58+
"",
59+
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
60+
"Comma-separated list of length 'b'. If empty, no override.")
61+
.insert("kv_eff_lens",
62+
"",
63+
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
64+
"Comma-separated list of length 'b'. If empty, no override.");
5665

5766
bool result = arg_parser.parse(argc, argv);
5867
return std::make_pair(result, arg_parser);
@@ -111,6 +120,8 @@ struct Problem
111120

112121
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
113122
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
123+
q_eff_lens = args.get_int_vec("q_eff_lens");
124+
kv_eff_lens = args.get_int_vec("kv_eff_lens");
114125
}
115126

116127
std::vector<ck_tile::index_t> get_query_shape() const
@@ -172,6 +183,8 @@ struct Problem
172183
mask_info mask;
173184
TensorLayout input_layout;
174185
TensorLayout output_layout;
186+
std::vector<int> q_eff_lens;
187+
std::vector<int> kv_eff_lens;
175188
};
176189

177190
struct RunConfig
@@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
326339
q_buf.ToDevice(q.data());
327340
k_buf.ToDevice(k.data());
328341
v_buf.ToDevice(v.data());
342+
// Ensure output buffer is zero-initialized so padded regions compare cleanly
343+
o_buf.SetZero();
329344

330-
ck_tile::fmha_fwd_v3_args args;
345+
ck_tile::fmha_fwd_v3_args args{};
331346

332347
args.data_type = problem.data_type;
333348
args.batch = problem.batch;
@@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
380395
: problem.seqlen_q * problem.hdim;
381396
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
382397

398+
// Optional cumulative seqlen overrides (exclude PAD)
399+
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
400+
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
401+
402+
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
403+
std::vector<ck_tile::index_t> eff;
404+
if(!opt_vec.empty() && opt_vec[0] != -1)
405+
{
406+
eff.assign(opt_vec.begin(), opt_vec.end());
407+
if(eff.size() < static_cast<size_t>(problem.batch))
408+
{
409+
eff.resize(problem.batch, eff.back());
410+
}
411+
}
412+
else
413+
{
414+
eff.assign(problem.batch, fallback);
415+
}
416+
return eff;
417+
};
418+
419+
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
420+
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
421+
422+
// Calculate cumulative sums for kernel arguments if varlen is used
423+
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
424+
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
425+
std::vector<ck_tile::index_t>& cum_vec) {
426+
cum_vec.resize(per_batch_vec.size() + 1);
427+
cum_vec[0] = 0;
428+
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
429+
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
430+
};
431+
432+
if(has_varlen_q)
433+
{
434+
calculate_cumulative(eff_q_vec, cuq_cum);
435+
}
436+
if(has_varlen_k)
437+
{
438+
calculate_cumulative(eff_kv_vec, cukv_cum);
439+
}
440+
441+
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
442+
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
443+
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
444+
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
445+
args.cu_seqlen_q_ptr =
446+
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
447+
: nullptr;
448+
args.cu_seqlen_kv_ptr =
449+
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
450+
: nullptr;
451+
383452
ck_tile::stream_config stream_config{nullptr,
384453
true,
385454
/*log_level=*/0,
@@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
442511
o_ref = o_ref.transpose({0, 2, 1, 3});
443512
}
444513

445-
host::fmha_fwd<float, DataType>(q,
446-
k,
447-
v,
448-
problem.mask,
449-
o_ref,
450-
ck_tile::identity{},
451-
ck_tile::identity{},
452-
ck_tile::identity{},
453-
ck_tile::scales{problem.softmax_scale});
514+
// If variable lengths are provided, compute per-batch references
515+
// with the effective lengths; else compute a single full reference.
516+
if(has_varlen_q || has_varlen_k)
517+
{
518+
// Variable-length aware verification: zero-fill padded region and only compute valid part.
519+
o_ref.SetZero();
520+
521+
for(int b = 0; b < problem.batch; ++b)
522+
{
523+
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
524+
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
525+
526+
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
527+
continue;
528+
529+
// Slice current batch from inputs (bshd) and build single-batch tensors
530+
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
531+
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
532+
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
533+
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
534+
535+
// Copy effective region
536+
q_b.ForEach([&](auto& self, auto idx) {
537+
// idx: [0, s, h, d]
538+
self(idx) = q(b, idx[1], idx[2], idx[3]);
539+
});
540+
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
541+
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
542+
543+
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
544+
host::fmha_fwd<float, DataType>(q_b,
545+
k_b,
546+
v_b,
547+
problem.mask,
548+
o_b,
549+
ck_tile::identity{},
550+
ck_tile::identity{},
551+
ck_tile::identity{},
552+
ck_tile::scales{problem.softmax_scale});
553+
554+
// Scatter into o_ref's bshd descriptor memory
555+
for(int s = 0; s < seqlen_q_eff; ++s)
556+
{
557+
for(int h = 0; h < problem.nhead_q; ++h)
558+
{
559+
for(int d = 0; d < problem.hdim; ++d)
560+
{
561+
o_ref(b, s, h, d) = o_b(0, s, h, d);
562+
}
563+
}
564+
}
565+
}
566+
}
567+
else
568+
{
569+
// No varlen override: compute the full reference once
570+
host::fmha_fwd<float, DataType>(q,
571+
k,
572+
v,
573+
problem.mask,
574+
o_ref,
575+
ck_tile::identity{},
576+
ck_tile::identity{},
577+
ck_tile::identity{},
578+
ck_tile::scales{problem.softmax_scale});
579+
}
454580

455581
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
456582
o_buf.FromDevice(o.data());

example/ck_tile/01_fmha/fmha_fwd.hpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,20 @@ struct fmha_fwd_args
162162
void* lse_ptr;
163163
void* o_ptr;
164164

165+
// Optional cumulative sequence length arrays
166+
// Batch mode: cu_seqlen_* override effective per-batch lengths (exclude PAD)
167+
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
168+
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
169+
165170
const void* seqstart_q_ptr;
166171
const void* seqstart_k_ptr;
167172
const void*
168173
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
169174

175+
// Group mode: seqstart_padded_* provide physical starts including PAD (optional)
176+
const void* seqstart_padded_q_ptr = nullptr; // [batch+1]
177+
const void* seqstart_padded_k_ptr = nullptr; // [batch+1]
178+
170179
ck_tile::index_t seqlen_q;
171180
ck_tile::index_t seqlen_k;
172181
ck_tile::index_t batch;
@@ -554,7 +563,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
554563
args.min_seqlen_q,
555564
args.p_drop,
556565
args.s_randval,
557-
args.drop_seed_offset);
566+
args.drop_seed_offset,
567+
args.seqstart_padded_q_ptr,
568+
args.seqstart_padded_k_ptr);
558569
}
559570
else
560571
{ // create batch mode kernel arguments
@@ -600,7 +611,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
600611
args.mask_type,
601612
args.p_drop,
602613
args.s_randval,
603-
args.drop_seed_offset);
614+
args.drop_seed_offset,
615+
args.cu_seqlen_q_ptr,
616+
args.cu_seqlen_kv_ptr);
604617
}
605618
}();
606619

0 commit comments

Comments (0)