
@juuso-oskari commented Oct 30, 2025

Authors: @Chi-Chu319 @juuso-oskari

This PR implements a unified attention kernel written in CK Tile. It builds on top of fmha_v3 (composable_kernel/example/ck_tile/01_fmha), with the pipeline remaining largely the same. The PR implements the following features introduced in the Triton unified attention kernel:

Reduced launch grid size, implemented in composable_kernel/example/ck_tile/01_unified_attention/unified_attention_impl.hpp:

// args.num_tokens is the cumulative amount of tokens from all sequences
index_t total_num_q_blocks = args.num_tokens / BLOCK_Q + args.num_seqs;
dim3 grids            = Kernel::GridSize2D(args.num_kv_heads, total_num_q_blocks);
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

This launches significantly fewer programs than the previous grid = (num_seqs, max_seqlen // BLOCK_M, num_q_heads), which contained many empty programs (not all sequences are of length max_seqlen).
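
For intuition, here is a small host-side sketch comparing the number of programs launched by the old per-sequence grid and the new flattened grid. The sequence lengths, head counts, and block sizes are purely illustrative (not taken from the PR); only the total_num_q_blocks formula matches the code above.

#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical setup: 4 variable-length sequences, GQA with 32 query heads
    // sharing 8 KV heads (num_queries_per_kv = 4).
    std::vector<int> seq_lens = {1, 7, 300, 4096};
    const int BLOCK_M      = 64; // query-tile rows in the old grid
    const int BLOCK_Q      = 16; // query tokens per q-block in the new grid
    const int num_q_heads  = 32;
    const int num_kv_heads = 8;

    const int num_seqs = static_cast<int>(seq_lens.size());
    int num_tokens     = 0;
    for(int len : seq_lens)
        num_tokens += len; // 4404 in this example
    const int max_seqlen = *std::max_element(seq_lens.begin(), seq_lens.end());

    // Old launch: grid = (num_seqs, ceil(max_seqlen / BLOCK_M), num_q_heads);
    // short sequences still get ceil(max_seqlen / BLOCK_M) mostly-empty programs.
    const int old_programs = num_seqs * ((max_seqlen + BLOCK_M - 1) / BLOCK_M) * num_q_heads;

    // New launch: grid = (num_kv_heads, total_num_q_blocks) with
    // total_num_q_blocks = num_tokens / BLOCK_Q + num_seqs (one extra block per
    // sequence covers the remainder of that sequence's last partial block).
    const int total_num_q_blocks = num_tokens / BLOCK_Q + num_seqs;
    const int new_programs       = num_kv_heads * total_num_q_blocks;

    std::printf("old grid: %d programs, new grid: %d programs\n",
                old_programs, new_programs); // 8192 vs 2232 for this example
    return 0;
}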

However, since the current sequence index can no longer be read off the program id, we perform a binary search at the beginning of the kernel to find our sequence index (used to look up the sequence length, which determines the inner-loop length).

This is implemented at composable_kernel/include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp:

// Binary search to find the sequence index for a given global index
CK_TILE_DEVICE static constexpr ck_tile::index_t
find_seq_idx(const int32_t* query_start_len_ptr,
             ck_tile::index_t target_idx,
             ck_tile::index_t num_seqs,
             ck_tile::index_t block_q,
             bool use_q_block_mode)
{
    ck_tile::index_t left = 0;
    ck_tile::index_t right = num_seqs;
    while (left < right)
    {
        ck_tile::index_t mid = (left + right) / 2;
        ck_tile::index_t val = query_start_len_ptr[mid];
        ck_tile::index_t mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        
        if (mid_val <= target_idx)
        {
            left = mid + 1;
        }
        else
        {
            right = mid;
        }
    }
    return left - 1;
}
// usage inside the kernel
const auto [kv_head_idx, q_block_global_idx] = GetTileIndex(pid, kargs);
// grid size is (num_kv_heads, total_num_q_blocks)
// total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
// q.shape[0] is total number of query tokens across all batches
const index_t seq_idx = find_seq_idx(
    kargs.query_start_len_ptr, q_block_global_idx, kargs.num_seqs, BLOCK_Q, true
); // which seq am I
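
As a sanity check of the mapping, here is a host-side replica of find_seq_idx run on hypothetical values, assuming query_start_len_ptr holds the cumulative query start offset of each sequence (as the comments above suggest). Under that assumption each sequence i owns the q-block range beginning at query_start_len[i] / BLOCK_Q + i, and the search returns the last sequence whose range starts at or before the target block.

#include <cstdint>
#include <cstdio>
#include <vector>

// Host-side replica of the device binary search, for illustration only.
static int find_seq_idx(const int32_t* query_start_len_ptr,
                        int target_idx,
                        int num_seqs,
                        int block_q,
                        bool use_q_block_mode)
{
    int left = 0, right = num_seqs;
    while(left < right)
    {
        int mid     = (left + right) / 2;
        int val     = query_start_len_ptr[mid];
        int mid_val = use_q_block_mode ? (val / block_q + mid) : val;
        if(mid_val <= target_idx)
            left = mid + 1;
        else
            right = mid;
    }
    return left - 1;
}

int main()
{
    const int BLOCK_Q = 16;
    // Hypothetical cumulative query offsets for 3 sequences of lengths 1, 20, 40.
    std::vector<int32_t> query_start_len = {0, 1, 21, 61};
    const int num_seqs   = 3;
    const int num_tokens = 61;
    // Sequence 0 owns q-block 0, sequence 1 owns blocks 1-2, sequence 2 owns blocks 3-5;
    // total_num_q_blocks = 61 / 16 + 3 = 6.
    for(int q_block = 0; q_block < num_tokens / BLOCK_Q + num_seqs; ++q_block)
    {
        int seq = find_seq_idx(query_start_len.data(), q_block, num_seqs, BLOCK_Q, true);
        std::printf("q_block %d -> seq %d\n", q_block, seq);
    }
    return 0;
}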

To process more query rows per load in decode settings (where the query length per sequence is small, often just 1), we group query tokens along the query-head dimension: up to num_queries_per_kv query heads share the same key/value head (the GQA setting), so their rows can be packed into one tile. The total number of grouped rows per tile load is BLOCK_M = BLOCK_Q * num_queries_per_kv.

We do this in the kernel implementation by transforming the tensor view for Q in DRAM:

const auto q_dram = [&]() {
    const auto q_dram_base = make_naive_tensor_view<address_space_enum::global>(
        q_ptr,
        make_tuple(cur_batch_query_len, num_queries_per_kv, HEAD_SIZE),
        make_tuple(kargs.query_stride_0, kargs.query_stride_1, 1),
        number<UnifiedAttentionPipeline::kAlignmentQ>{},
        number<1>{});

    const auto q_dram_pad = pad_tensor_view( // align seqlen to BLOCK_Q and head dim to HEAD_SIZE_PADDED
        q_dram_base,
        // block sizes
        make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
        sequence<true, false, kPadHeadDimQ>{}
    ); // pads to (query_len_padded, num_queries_per_kv, HEAD_SIZE_PADDED)

    const auto q_dram_merged = transform_tensor_view(
                q_dram_pad,
                make_tuple(
                    make_merge_transform(
                        make_tuple(query_len_padded, num_queries_per_kv)
                    ),
                    make_pass_through_transform(HEAD_SIZE_PADDED)
                ),
                make_tuple(sequence<0, 1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1>{})
    ); // flattens the first two dims, head idx is the fastest changing dim in the merged dim
    return q_dram_merged;
}();
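
Since the merge puts the grouped query-head index in the fastest-changing position, row r of the resulting BLOCK_M-tall view corresponds to query token r / num_queries_per_kv and grouped head r % num_queries_per_kv. A tiny standalone sketch of that mapping, with illustrative sizes only:

#include <cstdio>

int main()
{
    // Illustrative sizes (not from the PR): BLOCK_Q = 4 query tokens grouped over
    // num_queries_per_kv = 4 query heads that share one KV head, so
    // BLOCK_M = BLOCK_Q * num_queries_per_kv = 16.
    const int num_queries_per_kv = 4;
    const int BLOCK_M            = 16;

    // Row r of the merged Q view corresponds to (query token, grouped head):
    for(int r = 0; r < BLOCK_M; ++r)
    {
        const int token = r / num_queries_per_kv; // slower-changing dim
        const int head  = r % num_queries_per_kv; // fastest-changing dim
        std::printf("merged row %2d -> token %d, head %d\n", r, token, head);
    }
    return 0;
}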

This way, the pipeline can remain untouched and use BLOCK_M as its tile size.

build

# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh  ../ <arch>
make tile_example_unified_attention -j1
