@@ -52,7 +52,16 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
                 "random seed used for initializing input tensors. 0 for "
                 "non-deterministic seed")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
-        .insert("repeat", "30", "number of iterations to benchmark the kernel");
+        .insert("repeat", "30", "number of iterations to benchmark the kernel")
+        // Optional effective seqlen override (exclude PAD) for batch mode
+        .insert("q_eff_lens",
+                "",
+                "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
+                "Comma-separated list of length 'b'. If empty, no override.")
+        .insert("kv_eff_lens",
+                "",
+                "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
+                "Comma-separated list of length 'b'. If empty, no override.");

     bool result = arg_parser.parse(argc, argv);
     return std::make_pair(result, arg_parser);
@@ -111,6 +120,8 @@ struct Problem
         input_layout  = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
         output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
+        q_eff_lens  = args.get_int_vec("q_eff_lens");
+        kv_eff_lens = args.get_int_vec("kv_eff_lens");
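+        // (Reviewer note, inferred from the checks further down, not stated in the
+        // patch): get_int_vec() appears to yield a leading -1 when the flag is
+        // unset, and the varlen path below treats that sentinel as "no override".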
     }

     std::vector<ck_tile::index_t> get_query_shape() const
@@ -172,6 +183,8 @@ struct Problem
     mask_info mask;
     TensorLayout input_layout;
     TensorLayout output_layout;
+    std::vector<int> q_eff_lens;
+    std::vector<int> kv_eff_lens;
 };

struct RunConfig
@@ -326,8 +339,10 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
     q_buf.ToDevice(q.data());
     k_buf.ToDevice(k.data());
     v_buf.ToDevice(v.data());
+    // Ensure output buffer is zero-initialized so padded regions compare cleanly
+    o_buf.SetZero();

-    ck_tile::fmha_fwd_v3_args args;
+    ck_tile::fmha_fwd_v3_args args{};
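+    // (Reviewer note, inferred rather than stated in the patch): value-initializing
+    // args zeroes every member, so any field this example never assigns starts out
+    // as 0/nullptr instead of indeterminate garbage.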

     args.data_type = problem.data_type;
     args.batch     = problem.batch;
@@ -380,6 +395,60 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
                                      : problem.seqlen_q * problem.hdim;
     args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;

+    // Optional cumulative seqlen overrides (exclude PAD)
+    const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
+    const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
+
+    auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
+        std::vector<ck_tile::index_t> eff;
+        if(!opt_vec.empty() && opt_vec[0] != -1)
+        {
+            eff.assign(opt_vec.begin(), opt_vec.end());
+            if(eff.size() < static_cast<size_t>(problem.batch))
+            {
+                eff.resize(problem.batch, eff.back());
+            }
+        }
+        else
+        {
+            eff.assign(problem.batch, fallback);
+        }
+        return eff;
+    };
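+    // Worked example (illustrative, not from the patch): with -b=4 and
+    // -q_eff_lens=100,120 the list is padded with its last entry to
+    // {100, 120, 120, 120}; with the flag unset it falls back to seqlen_q for all.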
+
+    const auto eff_q_vec  = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
+    const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
+
+    // Calculate cumulative sums for kernel arguments if varlen is used
+    std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
+    auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
+                                    std::vector<ck_tile::index_t>& cum_vec) {
+        cum_vec.resize(per_batch_vec.size() + 1);
+        cum_vec[0] = 0;
+        for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
+            cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
+    };
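+    // E.g. (illustrative): per-batch lengths {100, 120, 120, 120} yield the
+    // exclusive prefix sum {0, 100, 220, 340, 460}, the usual cu_seqlens layout
+    // in which batch b presumably spans [cum[b], cum[b+1]) for the kernel.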
+
+    if(has_varlen_q)
+    {
+        calculate_cumulative(eff_q_vec, cuq_cum);
+    }
+    if(has_varlen_k)
+    {
+        calculate_cumulative(eff_kv_vec, cukv_cum);
+    }
+
+    ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
+    ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
+    cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
+    cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
+    args.cu_seqlen_q_ptr =
+        !cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
+                         : nullptr;
+    args.cu_seqlen_kv_ptr =
+        !cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
+                          : nullptr;
+
     ck_tile::stream_config stream_config{nullptr,
                                          true,
                                          /*log_level=*/0,
@@ -442,15 +511,72 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
         o_ref = o_ref.transpose({0, 2, 1, 3});
     }

-    host::fmha_fwd<float, DataType>(q,
-                                    k,
-                                    v,
-                                    problem.mask,
-                                    o_ref,
-                                    ck_tile::identity{},
-                                    ck_tile::identity{},
-                                    ck_tile::identity{},
-                                    ck_tile::scales{problem.softmax_scale});
+    // If variable lengths are provided, compute per-batch references
+    // with the effective lengths; else compute a single full reference.
+    if(has_varlen_q || has_varlen_k)
+    {
+        // Variable-length aware verification: zero-fill padded region and only compute valid part.
+        o_ref.SetZero();
+
+        for(int b = 0; b < problem.batch; ++b)
+        {
+            const ck_tile::index_t seqlen_q_eff  = eff_q_vec[b];
+            const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
+
+            if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
+                continue;
+
+            // Slice current batch from inputs (bshd) and build single-batch tensors
+            ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
+            ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
+            ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
+            ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
+
+            // Copy effective region
+            q_b.ForEach([&](auto& self, auto idx) {
+                // idx: [0, s, h, d]
+                self(idx) = q(b, idx[1], idx[2], idx[3]);
+            });
+            k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
+            v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
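+            // (Reviewer note, inferred): ForEach iterates over each tensor's own
+            // extents, so only the first seqlen_*_eff rows of this batch are read
+            // and the padded tail of q/k/v never enters the reference computation.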
+
+            // Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
+            host::fmha_fwd<float, DataType>(q_b,
+                                            k_b,
+                                            v_b,
+                                            problem.mask,
+                                            o_b,
+                                            ck_tile::identity{},
+                                            ck_tile::identity{},
+                                            ck_tile::identity{},
+                                            ck_tile::scales{problem.softmax_scale});
+
+            // Scatter into o_ref's bshd descriptor memory
+            for(int s = 0; s < seqlen_q_eff; ++s)
+            {
+                for(int h = 0; h < problem.nhead_q; ++h)
+                {
+                    for(int d = 0; d < problem.hdim; ++d)
+                    {
+                        o_ref(b, s, h, d) = o_b(0, s, h, d);
+                    }
+                }
+            }
+        }
+    }
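+    // (Reviewer note, inferred): rows past seqlen_q_eff stay zero in o_ref, and
+    // o_buf.SetZero() above zeroed the device output, so, assuming the kernel
+    // leaves padded rows untouched, the padded regions compare equal by
+    // construction and the check effectively covers only the valid rows.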
+    else
+    {
+        // No varlen override: compute the full reference once
+        host::fmha_fwd<float, DataType>(q,
+                                        k,
+                                        v,
+                                        problem.mask,
+                                        o_ref,
+                                        ck_tile::identity{},
+                                        ck_tile::identity{},
+                                        ck_tile::identity{},
+                                        ck_tile::scales{problem.softmax_scale});
+    }

     ck_tile::HostTensor<DataType> o(problem.get_output_shape());
     o_buf.FromDevice(o.data());