[Misc] Add attention sinks #3515

Open: wants to merge 78 commits into main

Conversation

@felixzhu555 (Contributor) commented Mar 19, 2024

Overview

This PR adds experimental support for attention sinks (#1304), based on this paper and repo. Support is currently limited to RoPE and ALiBi models (e.g. Llama, Mistral/Mixtral, Falcon, Bloom, MPT).

Usage

Set use_attention_sinks=True when instantiating LLM or LLMEngine, or pass the --use-attention-sinks CLI argument. Also set enforce_eager=True (attention sinks currently do not work with CUDA graphs), and ensure the attention backend in use is either FlashAttention or XFormers.
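
For reference, a minimal offline-inference sketch using these settings (the model name and sampling parameters are only illustrative):

```python
from vllm import LLM, SamplingParams

# use_attention_sinks is the flag added by this PR; enforce_eager is required
# because attention sinks currently do not work with CUDA graphs.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # any supported RoPE/ALiBi model
    use_attention_sinks=True,
    enforce_eager=True,
)

# With attention sinks, generation can continue past the model's context length.
params = SamplingParams(temperature=0.8, max_tokens=4096, ignore_eos=True)
outputs = llm.generate(["Tell me a very long story."], params)
print(outputs[0].outputs[0].text)
```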

Background

Experiments show that the attention mechanism heavily attends to the first few tokens of the sequence being completed, regardless of what those tokens are. Once the sequence length exceeds the model's context length and we start evicting tokens from the beginning of the KV cache (in a sliding-window fashion), the model generates garbage (high perplexity).

This is where attention sinks come in. By always preserving the KVs for the first few tokens of the sequence while using a sliding-window approach for the rest of the KV cache, the model can continue to generate sensible output (low perplexity). In theory, the model can stream indefinitely as long as cache eviction is handled properly. Note that the sliding window length here is the model's context length.
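
At the token level, the cache policy amounts to keeping the sink tokens plus the most recent window of tokens. A rough sketch (the number of sink tokens here is an illustrative choice, not a value from this PR, which works at block granularity):

```python
def cached_token_positions(seq_len: int, context_len: int, num_sink_tokens: int = 4) -> list[int]:
    """Positions whose KVs remain cached: the first few "sink" tokens plus a
    sliding window of the most recent tokens, capped at context_len total."""
    if seq_len <= context_len:
        return list(range(seq_len))
    window = context_len - num_sink_tokens
    return list(range(num_sink_tokens)) + list(range(seq_len - window, seq_len))
```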

Example

Suppose our model's context length is 2048 tokens, i.e. 128 blocks of 16 tokens. We pass in a prompt of 2000 tokens. For the next 48 generated tokens, nothing changes; at that point all 128 blocks are filled.

Normally, vLLM forces generation to stop here since the model's context length has been reached. With attention sinks, however, we bypass this stopping condition and keep generating.

At the next decode, we write the 2049th token to the cache and compute the 2050th token (1-based indexing). Here we edit the block table to be [block_table[0]] + block_table[2:], effectively ignoring the 2nd block while retaining the 1st block, which is our attention sink. Note that the block table still has length 128, because a 129th block was just allocated for token 2049. This modified block table is then used in the attention kernel.

Every 16th decode that follows will ignore an additional block, but always retain the 1st block as the sink.
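
A minimal sketch of the block-table edit described in this example (the helper name is hypothetical; in this PR the sink is the first block):

```python
BLOCK_SIZE = 16

def sink_block_table(block_table: list[int], max_model_len: int) -> list[int]:
    """Build the block table handed to the attention kernel: keep block 0 (the
    sink), skip the oldest non-sink blocks, and cap the table at
    max_model_len // BLOCK_SIZE entries."""
    max_blocks = max_model_len // BLOCK_SIZE
    if len(block_table) <= max_blocks:
        return block_table
    num_ignored = len(block_table) - max_blocks
    return [block_table[0]] + block_table[1 + num_ignored:]

# e.g. max_model_len = 2048 (128 blocks): once the 129th block is allocated,
# the kernel sees [block_table[0]] + block_table[2:], which has length 128.
```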

Modifications

This PR adds a StreamingAttentionSink layer that computes attention using modified block tables, in which the "sink" block is concatenated with the remaining sliding-window blocks. In the RoPE case, we always store unrotated keys in the cache, and extra work must be done at every decode to re-rotate all keys for a sequence based on their new positions in the cache. Note: due to this extra work, using attention sinks incurs a significant drop in tokens/s for RoPE models (around 50-70% for Llama).
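
Conceptually, the re-rotation applies RoPE to the cached keys using their positions inside the sink + sliding-window cache rather than their original token positions. A simplified sketch, assuming rotate-half (GPT-NeoX-style) RoPE and illustrative tensor shapes; this is not the actual implementation:

```python
import torch

def rerotate_cached_keys(unrotated_keys: torch.Tensor, inv_freq: torch.Tensor) -> torch.Tensor:
    """Re-apply RoPE to cached (unrotated) keys based on their current cache positions.

    unrotated_keys: [num_cached_tokens, num_kv_heads, head_dim]
    inv_freq:       [head_dim // 2] rotary inverse frequencies
    """
    num_cached, _, head_dim = unrotated_keys.shape
    # After eviction, the kept tokens occupy cache positions 0..num_cached-1.
    positions = torch.arange(num_cached, device=unrotated_keys.device, dtype=inv_freq.dtype)
    freqs = torch.einsum("p,f->pf", positions, inv_freq)  # [num_cached, head_dim // 2]
    cos = freqs.cos()[:, None, :]  # broadcast over KV heads
    sin = freqs.sin()[:, None, :]
    k1, k2 = unrotated_keys[..., : head_dim // 2], unrotated_keys[..., head_dim // 2:]
    # standard rotate-half RoPE application
    return torch.cat([k1 * cos - k2 * sin, k2 * cos + k1 * sin], dim=-1)
```

This is also why the slowdown is specific to RoPE models: ALiBi applies position-dependent biases at attention time rather than rotating the cached keys, so no per-decode re-rotation is needed.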

use_attention_sinks is now an argument to LLMEngine; it is passed through ModelConfig and down through the model's layers to the attention layer, where a StreamingAttentionSink is instantiated. On every forward call of the model's attention layer, the normal attention logic is replaced by StreamingAttentionSink logic.

The scheduler evicts (frees) a block (the "ignored" block) whenever a new block is allocated past the model's context length, such that the total number of used blocks is capped at max_model_len // block_size.

Future Work

  • Beam search: currently does not work with attention sinks.
  • Other attention backends: ROCMFlashAttention, torch SDPA
  • Support LoRA: LoRA requests with attention sinks are currently untested.
  • Integrate with speculative decoding: StreamingAttentionSink assumes only 1 token is generated every decode.
  • Integrate with prefix caching: StreamingAttentionSink directly edits the block table for every decode (past the context length), so the hash table for prefix caching cannot be used currently.

@rkooo567 (Collaborator)

Hi @felixzhu555, it is https://arxiv.org/abs/2309.17453, right?

@felixzhu555 (Contributor, Author)

Yep, trying to implement the logic from that paper. Their repo is https://github.com/mit-han-lab/streaming-llm.

@jqueguiner

We need to add @rlouf, the person in charge of Outlines, to this PR; it seems your PR is failing on the guided decoding part.
I'll try to bring him in to help.

@DarkLight1337 (Collaborator)

To speed up the CI queue for #5905, I've cancelled the distributed tests for the latest CI run in this PR since they won't pass anyway until #5905 has been merged. Please merge main into your branch after that happens so that the CI can pass once again.
