
[Bugfix] CUDAGraph Compatibility in AppendPagedKVCacheKernel for Variable-Length Inputs #919

Open
SungBalance opened this issue Mar 7, 2025 · 1 comment

SungBalance commented Mar 7, 2025

Hello flashinfer team,

The support for CUDAGraph compatibility with variable-length inputs has been updated, but one issue remains:

When running chunked prefill or speculative decoding with variable-length tokens under a fixed batch size, key_states and value_states must be appended to the KV cache tensor. However, the current AppendPagedKVCacheKernel does not fully support variable-length tokens.

To run append_paged_kv_cache within a CUDA Graph, I tried the following workaround:
• Among the arguments of append_paged_kv_cache, key_states, value_states, batch_indices, and positions depend on the number of tokens.
• Since the model's input size remains fixed, I filled the unused token slots in key_states and value_states with zeros.
• batch_indices and positions were filled with arbitrary values.

However, for this approach to work correctly, the kernel must consult kv_indptr and skip any token whose position lies beyond the last valid entry of its sequence. The current kernel implementation lacks this check.
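To make the workaround concrete, here is a hypothetical host-side sketch (plain NumPy, not flashinfer's actual API) of how the fixed-size input buffers would be prepared for a CUDA Graph replay: only the first num_valid slots carry real data, the rest are zero-filled, and their positions entries get an out-of-range sentinel so the proposed kernel check would skip them. All names and sizes here are illustrative.

```python
import numpy as np

# Fixed shapes captured by the CUDA Graph (illustrative values).
max_tokens, num_heads, head_dim = 8, 2, 4
num_valid = 5  # actual token count this step; varies between graph replays

# Zero-fill the unused token slots in the key/value buffers.
key_states = np.zeros((max_tokens, num_heads, head_dim), dtype=np.float32)
value_states = np.zeros_like(key_states)
key_states[:num_valid] = 1.0    # stand-in for real key data
value_states[:num_valid] = 1.0  # stand-in for real value data

# batch_indices for padded slots can be arbitrary (but must be in range);
# positions for padded slots use a large sentinel so that, with the proposed
# bounds check, the kernel skips them instead of writing into the KV cache.
batch_indices = np.zeros(max_tokens, dtype=np.int32)
positions = np.full(max_tokens, 10**6, dtype=np.int32)
positions[:num_valid] = np.arange(num_valid, dtype=np.int32)
```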

Therefore, I propose the following modification to the kernel:

Original Code:

// ...(omitted)
paged_kv.page_size.divmod(
    paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i],
    page_iter, entry_idx);
// ...(omitted)

Proposed Fix:

// ...(omitted)
// Valid KV length of this sequence: full pages plus the partially
// filled last page.
uint32_t start_page = paged_kv.indptr[batch_indices[i]];
uint32_t end_page   = paged_kv.indptr[batch_indices[i] + 1];
uint32_t length = (end_page - start_page - 1) * paged_kv.page_size
                  + paged_kv.last_page_len[batch_indices[i]];

// Skip padded token slots whose position falls beyond the valid length.
if (positions[i] >= length) {
  continue;
}

paged_kv.page_size.divmod(
    (start_page * paged_kv.page_size) + positions[i],
    page_iter, entry_idx);
// ...(omitted)
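The proposed check can be sketched as a pure-Python reference model (hypothetical, for illustration only; the function name and return convention are my own). It mirrors the paged_kv_t fields indptr, last_page_len, and page_size, computes each sequence's valid length, and skips any token whose position exceeds it, which is exactly the behavior the fix adds:

```python
def append_with_bounds_check(indptr, last_page_len, page_size,
                             batch_indices, positions):
    """Return (page_iter, entry_idx) per token, or None for skipped padded slots."""
    out = []
    for b, pos in zip(batch_indices, positions):
        start_page = indptr[b]
        end_page = indptr[b + 1]
        # Valid KV length: full pages plus the partially filled last page.
        length = (end_page - start_page - 1) * page_size + last_page_len[b]
        if pos >= length:
            out.append(None)  # padded slot: skip, as the proposed `continue` does
            continue
        # Mirrors paged_kv.page_size.divmod(start_page * page_size + pos, ...)
        page_iter, entry_idx = divmod(start_page * page_size + pos, page_size)
        out.append((page_iter, entry_idx))
    return out

# Example: 2 sequences, page_size = 4.
# Sequence 0 owns pages [0, 2) with 3 entries in its last page -> length 7.
# Sequence 1 owns pages [2, 3) with 2 entries in its last page -> length 2.
result = append_with_bounds_check(
    indptr=[0, 2, 3], last_page_len=[3, 2], page_size=4,
    batch_indices=[0, 0, 1, 1], positions=[6, 7, 1, 3])
# -> [(1, 2), None, (2, 1), None]: positions 7 and 3 are out of range and skipped
```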
SungBalance (Author) commented:

If there is a better way to handle variable-length inputs with append_paged_kv_cache using CUDA Graphs, please let me know! I’m open to alternative suggestions.
