From dbe9029f0773e3a2161ef54f38e3a29063dbad78 Mon Sep 17 00:00:00 2001
From: liyucheng09
We found that traditional methods perform poorly on retrieval tasks, with difficulty ordered as follows: KV retrieval > Needle in a Haystack > Retrieval.Number > Retrieval.PassKey. The main challenge is the semantic gap between the needle and the haystack: traditional methods excel when this gap is large, as in passkey tasks. KV retrieval demands stronger retrieval capability, since any key can be the target, and multi-needle tasks are more complex still.
We will continue to update our results with more models and datasets in future versions.
-## Does this dynamic sparse attention pattern only exist in long-context LLMs that are not fully trained?
+### 2. Does this dynamic sparse attention pattern only exist in long-context LLMs that are not fully trained?
Firstly, attention is dynamically sparse; this characteristic is inherent to the mechanism. We selected the state-of-the-art long-context LLMs GLM-4-9B-1M and LLaMA-3-8B-Instruct-1M, with effective context windows of 64K and 16K, respectively. With MInference, these extend to 64K and 32K. We will continue adapting our method to other advanced long-context LLMs and updating our results, as well as exploring the theoretical basis for this dynamic sparse attention pattern.
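As a concrete reference, the sketch below shows how MInference is typically applied to such a model. It assumes the `MInference(attn_type, model_name)` patching API described in the project README; the model name, dtype, and device settings are placeholder choices, not a prescribed setup.

```python
# Minimal sketch (not part of this patch): patch a long-context HF model so that
# prefill uses MInference's dynamic sparse attention. Assumes the
# MInference(attn_type, model_name) API from the README; the model name is an example.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

model_name = "THUDM/glm-4-9b-chat-1m"  # example long-context model
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)

# Apply the patch; decoding is left unchanged, only prefill attention is sparsified.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)
```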
-## Does this dynamic sparse attention pattern only exist in Auto-regressive LMs or RoPE based LLMs?
+### 3. Does this dynamic sparse attention pattern only exist in Auto-regressive LMs or RoPE based LLMs?
Similar vertical and slash sparse patterns have been discovered in BERT [1] and multi-modal LLMs [2]. Our analysis of T5's attention patterns (Figure 1) reveals that these patterns persist across different heads, even with bidirectional attention.
[1] SparseBERT: Rethinking the Importance Analysis in Self-Attention, ICML 2021.
@@ -46,6 +48,20 @@ Similar vertical and slash line sparse patterns have been discovered in BERT[1]
Figure 1. The sparse pattern in T5 Encoder.
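To make the vertical/slash observation concrete, the sketch below estimates how much attention mass a handful of dominant columns (vertical lines) and diagonals (slash lines) capture in a single attention map. The top-k sizes and the random causal map are illustrative choices only; on a real model, `attn` would be one head's attention probabilities (e.g., obtained with `output_attentions=True`).

```python
# Illustrative sketch: estimate the attention mass covered by a few dominant
# vertical lines (columns) and slash lines (diagonals). top_v/top_s are arbitrary.
import torch

def vertical_slash_coverage(attn: torch.Tensor, top_v: int = 16, top_s: int = 16) -> float:
    """attn: [seq_len, seq_len] row-normalized (causal) attention probabilities."""
    n = attn.size(0)
    col_mass = attn.sum(dim=0)                      # heavy columns -> vertical lines
    top_cols = col_mass.topk(top_v).indices
    diag_mass = torch.stack([attn.diagonal(-d).sum() for d in range(n)])
    top_diags = diag_mass.topk(top_s).indices       # heavy diagonals -> slash lines
    mask = torch.zeros_like(attn)
    mask[:, top_cols] = 1.0
    for d in top_diags.tolist():
        mask += torch.diag(torch.ones(n - d), diagonal=-d)
    mask = mask.clamp(max=1.0)
    return (attn * mask).sum().item() / attn.sum().item()

# Toy causal map; replace with a real head's attention probabilities in practice.
n = 256
scores = torch.randn(n, n)
scores = scores.masked_fill(torch.triu(torch.ones(n, n, dtype=torch.bool), 1), float("-inf"))
attn = torch.softmax(scores, dim=-1)
print(f"mass covered by vertical + slash pattern: {vertical_slash_coverage(attn):.2f}")
```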
-## What is the relationship between MInference, SSM, Linear Attention, and Sparse Attention?
+### 4. What is the relationship between MInference, SSM, Linear Attention, and Sparse Attention?
All four approaches (MInference, SSM, Linear Attention, and Sparse Attention) efficiently optimize attention complexity in Transformers, each introducing inductive bias in a different way. The latter three require training from scratch. Recent works such as Mamba-2 and Unified Implicit Attention Representation unify SSM and Linear Attention as static sparse attention, with Mamba-2 itself being a block-wise sparse method. While these approaches show promise thanks to the sparse redundancy in attention, static sparse attention may struggle with the dynamic semantic associations that arise in complex tasks. In contrast, dynamic sparse attention is better suited to capturing these relationships.
+
+## Questions about Usage
+
+### 1. Error: "RuntimeError: tensor does not have a device"
+
+This error occurs when the installed MInference build is incompatible with your torch and CUDA versions. Please reinstall MInference:
+```bash
+pip uninstall minference
+pip install minference
+```
+
+### 2. How to use MInference with torch 2.3.x?
+
+MInference supports a range of torch versions. However, due to known issues between flash-attn and torch 2.3.x, please use flash-attn version <= 2.4.x.

diff --git a/minference/modules/minference_forward.py b/minference/modules/minference_forward.py
index 4abda6c..db4b668 100644
--- a/minference/modules/minference_forward.py
+++ b/minference/modules/minference_forward.py
@@ -764,7 +764,8 @@ def dense(q, k, v, vertical_size=None, slash_size=None):
     return fc(q, k, v, vertical_size, slash_size)
 
 def minference_vllm_forward(
-    pattern_config
+    pattern_config,
+    vllm_version = "0.4.1"
 ):
     def forward(
         self,
@@ -919,4 +920,321 @@ def minference_prefill_func(
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
 
-    return forward
+    def forward_vllm_042(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata,
+        kv_scale: float,
+        layer_idx: int,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention and PagedAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        self.best_pattern = {int(ii): jj for ii, jj in pattern_config[layer_idx].items()}
+        def repeat_kv(hidden_states, n_rep):
+            sqlen, num_head, head_dim = hidden_states.shape
+            if n_rep == 1:
+                return hidden_states
+            hidden_states = hidden_states[:, :, None, :].expand(sqlen, num_head, n_rep, head_dim)
+            return hidden_states.reshape(sqlen, num_head * n_rep, head_dim)
+
+        def minference_prefill_func(
+            q, k, v,
+        ):
+            # (seq_len, num_heads, head_size)
+            if q.size(-2) != k.size(-2):
+                k = repeat_kv(k, q.size(-2) // k.size(-2))
+                v = repeat_kv(v, q.size(-2) // v.size(-2))
+
+            output = torch.empty_like(q)
+            for head in range(q.size(-2)):
+                q_head = q[:, head, :].unsqueeze(1)
+                k_head = k[:, head, :].unsqueeze(1)
+                v_head = v[:, head, :].unsqueeze(1)
+
+                # (1, seq_len, num_heads, head_size)
+                q_head = q_head[None, ...]
+                k_head = k_head[None, ...]
+                v_head = v_head[None, ...]
+
+                q_head = q_head.transpose(1, 2)
+                k_head = k_head.transpose(1, 2)
+                v_head = v_head.transpose(1, 2)
+
+                out = self.gather_last_q_vertical_slash_topk_vllm(q_head, k_head, v_head, head)
+
+                out = out.transpose(1, 2).squeeze(0).contiguous()
+                output[:, head:head+1, :] = out
+            return output
+
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_kv_heads, self.head_size)
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache,
+                                                attn_metadata.slot_mapping,
+                                                attn_metadata.kv_cache_dtype,
+                                                kv_scale)
+
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        output = torch.empty_like(query)
+        # Query for decode. KV is not needed because it is already cached.
+        decode_query = query[num_prefill_tokens:]
+        # QKV for prefill.
+        query = query[:num_prefill_tokens]
+        key = key[:num_prefill_tokens]
+        value = value[:num_prefill_tokens]
+
+        assert query.shape[0] == num_prefill_tokens
+        assert decode_query.shape[0] == num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
+            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+                # normal attention
+                # When block_tables are not filled, it means q and k are the
+                # prompt, and they have the same length.
+                # (seq_len, num_heads, head_size)
+                # out = flash_attn_varlen_func(
+                #     q=query,
+                #     k=key,
+                #     v=value,
+                #     cu_seqlens_q=prefill_meta.seq_start_loc,
+                #     cu_seqlens_k=prefill_meta.seq_start_loc,
+                #     max_seqlen_q=prefill_meta.max_prompt_len,
+                #     max_seqlen_k=prefill_meta.max_prompt_len,
+                #     softmax_scale=self.scale,
+                #     causal=True,
+                #     window_size=self.sliding_window,
+                #     alibi_slopes=self.alibi_slopes,
+                # )
+                out = minference_prefill_func(query, key, value)
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
+            else:
+                # prefix-enabled attention
+                # TODO(Hai) this triton kernel has regression issue (broke) to
+                # deal with different data types between KV and FP8 KV cache,
+                # to be addressed separately.
+                output[:num_prefill_tokens] = PagedAttention.forward_prefix(
+                    query,
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    prefill_meta.block_tables,
+                    prefill_meta.subquery_start_loc,
+                    prefill_meta.prompt_lens_tensor,
+                    prefill_meta.context_lens,
+                    prefill_meta.max_subquery_len,
+                    self.alibi_slopes,
+                )
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                decode_query,
+                key_cache,
+                value_cache,
+                decode_meta.block_tables,
+                decode_meta.seq_lens_tensor,
+                decode_meta.max_seq_len,
+                attn_metadata.kv_cache_dtype,
+                self.num_kv_heads,
+                self.scale,
+                self.alibi_slopes,
+                kv_scale,
+            )
+
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)
+
+    def forward_vllm_043(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        kv_scale: float,
+        layer_idx: int,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads * head_size]
+            key: shape = [num_tokens, num_kv_heads * head_size]
+            value: shape = [num_tokens, num_kv_heads * head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
+        self.best_pattern = {int(ii): jj for ii, jj in pattern_config[layer_idx].items()}
+        assert kv_scale == 1.0, "kv_scale is not supported in FlashAttention."
+
+        def repeat_kv(hidden_states, n_rep):
+            sqlen, num_head, head_dim = hidden_states.shape
+            if n_rep == 1:
+                return hidden_states
+            hidden_states = hidden_states[:, :, None, :].expand(sqlen, num_head, n_rep, head_dim)
+            return hidden_states.reshape(sqlen, num_head * n_rep, head_dim)
+
+        def minference_prefill_func(
+            q, k, v,
+        ):
+            # (seq_len, num_heads, head_size)
+            if q.size(-2) != k.size(-2):
+                k = repeat_kv(k, q.size(-2) // k.size(-2))
+                v = repeat_kv(v, q.size(-2) // v.size(-2))
+
+            output = torch.empty_like(q)
+            for head in range(q.size(-2)):
+                q_head = q[:, head, :].unsqueeze(1)
+                k_head = k[:, head, :].unsqueeze(1)
+                v_head = v[:, head, :].unsqueeze(1)
+
+                # (1, seq_len, num_heads, head_size)
+                q_head = q_head[None, ...]
+                k_head = k_head[None, ...]
+                v_head = v_head[None, ...]
+
+                q_head = q_head.transpose(1, 2)
+                k_head = k_head.transpose(1, 2)
+                v_head = v_head.transpose(1, 2)
+
+                out = self.gather_last_q_vertical_slash_topk_vllm(q_head, k_head, v_head, head)
+
+                out = out.transpose(1, 2).squeeze(0).contiguous()
+                output[:, head:head+1, :] = out
+            return output
+
+        num_tokens, hidden_size = query.shape
+        # Reshape the query, key, and value tensors.
+        query = query.view(-1, self.num_heads, self.head_size)
+        key = key.view(-1, self.num_kv_heads, self.head_size)
+        value = value.view(-1, self.num_kv_heads, self.head_size)
+
+        if kv_cache is not None:
+            key_cache = kv_cache[0]
+            value_cache = kv_cache[1]
+
+            # Reshape the input keys and values and store them in the cache.
+            # If kv_cache is not provided, the new key and value tensors are
+            # not cached. This happens during the initial memory profiling run.
+            cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping.flatten(),
+                self.kv_cache_dtype,
+            )
+
+        num_prefill_tokens = attn_metadata.num_prefill_tokens
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
+        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
+
+        output = torch.empty_like(query)
+        # Query for decode. KV is not needed because it is already cached.
+        decode_query = query[num_prefill_tokens:]
+        # QKV for prefill.
+        query = query[:num_prefill_tokens]
+        key = key[:num_prefill_tokens]
+        value = value[:num_prefill_tokens]
+
+        assert query.shape[0] == num_prefill_tokens
+        assert decode_query.shape[0] == num_decode_tokens
+
+        if prefill_meta := attn_metadata.prefill_metadata:
+            # Prompt run.
+            if (kv_cache is None or prefill_meta.block_tables is None
+                    or prefill_meta.block_tables.numel() == 0):
+                # normal attention
+                # When block_tables are not filled, it means q and k are the
+                # prompt, and they have the same length.
+                # out = flash_attn_varlen_func(
+                #     q=query,
+                #     k=key,
+                #     v=value,
+                #     cu_seqlens_q=prefill_meta.seq_start_loc,
+                #     cu_seqlens_k=prefill_meta.seq_start_loc,
+                #     max_seqlen_q=prefill_meta.max_prefill_seq_len,
+                #     max_seqlen_k=prefill_meta.max_prefill_seq_len,
+                #     softmax_scale=self.scale,
+                #     causal=True,
+                #     window_size=self.sliding_window,
+                #     alibi_slopes=self.alibi_slopes,
+                # )
+                out = minference_prefill_func(query, key, value)
+                assert output[:num_prefill_tokens].shape == out.shape
+                output[:num_prefill_tokens] = out
+            else:
+                # prefix-enabled attention
+                assert prefill_meta.seq_lens is not None
+                max_seq_len = max(prefill_meta.seq_lens)
+                output[:num_prefill_tokens] = flash_attn_varlen_func(
+                    q=query,
+                    k=key_cache,
+                    v=value_cache,
+                    cu_seqlens_q=prefill_meta.query_start_loc,
+                    max_seqlen_q=prefill_meta.max_query_len,
+                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    max_seqlen_k=max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    block_table=prefill_meta.block_tables,
+                )
+
+        if decode_meta := attn_metadata.decode_metadata:
+            # Decoding run.
+            output[num_prefill_tokens:] = flash_attn_with_kvcache(
+                decode_query.unsqueeze(1),
+                key_cache,
+                value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=self.scale,
+                causal=True,
+                alibi_slopes=self.alibi_slopes,
+            ).squeeze(1)
+
+        # Reshape the output tensor.
+        return output.view(num_tokens, hidden_size)
+
+
+    if vllm_version == "0.4.1":
+        return forward
+    elif vllm_version == "0.4.2":
+        return forward_vllm_042
+    elif vllm_version == "0.4.3":
+        return forward_vllm_043
+    else:
+        return forward_vllm_042
diff --git a/minference/patch.py b/minference/patch.py
index 211b166..b0d25ee 100644
--- a/minference/patch.py
+++ b/minference/patch.py
@@ -1026,19 +1026,46 @@ def llama_layer_forward_vllm(
 
 
 def llama_attn_forward_vllm(
-    self,
-    positions: torch.Tensor,
-    hidden_states: torch.Tensor,
-    kv_cache: torch.Tensor,
-    attn_metadata,
-    layer_idx: int,
-) -> torch.Tensor:
-    qkv, _ = self.qkv_proj(hidden_states)
-    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-    q, k = self.rotary_emb(positions, q, k)
-    attn_output = self.attn(q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx)
-    output, _ = self.o_proj(attn_output)
-    return output
+    vllm_version: str = "0.4.2",
+):
+    def llama_attn_forward_vllm_042(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata,
+        layer_idx: int,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(
+            q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx
+        )
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def llama_attn_forward_vllm_043(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata,
+        layer_idx: int,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata, layer_idx)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    if vllm_version == "0.4.2":
+        return llama_attn_forward_vllm_042
+    elif vllm_version == "0.4.3":
+        return llama_attn_forward_vllm_043
+    else:
+        return llama_attn_forward_vllm_042
 
 
 def vllm_attn_forward(
@@ -1051,6 +1078,8 @@ def vllm_attn_forward(
     kv_scale: float = 1.0,
     layer_idx: int = 0,
 ) -> torch.Tensor:
+    # check self._kv_scale
+    kv_scale = getattr(self, "_kv_scale", kv_scale)
     return self.impl.forward(
         query, key, value, kv_cache, attn_metadata, kv_scale, layer_idx
     )
@@ -1060,6 +1089,7 @@ def minference_patch_vllm(
     llm,
     config_file,
 ):
+    import vllm
     from vllm.attention import Attention
     from vllm.model_executor.models.llama import (
         LlamaAttention,
@@ -1068,8 +1098,10 @@
         LlamaModel,
     )
 
+    vllm_version = vllm.__version__
+
     config = json.load(open(config_file))
-    attn_forward = minference_vllm_forward(config)
+    attn_forward = minference_vllm_forward(config, vllm_version=vllm_version)
 
     def update_module(m):
         if isinstance(m, Attention):
@@ -1086,7 +1118,7 @@ def update_module(m):
         if isinstance(m, LlamaModel):
             m.forward = llama_model_forward_vllm.__get__(m, LlamaModel)
         if isinstance(m, LlamaAttention):
-            m.forward = llama_attn_forward_vllm.__get__(m, LlamaAttention)
+            m.forward = llama_attn_forward_vllm(vllm_version).__get__(m, LlamaAttention)
 
     llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)
 
diff --git a/minference/version.py b/minference/version.py
index 9bd5554..24807fc 100644
--- a/minference/version.py
+++ b/minference/version.py
@@ -8,7 +8,7 @@ _PATCH = "4"
 
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
-_SUFFIX = ""
+_SUFFIX = ".post1"
 
 VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
 VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
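For context on how the version dispatch introduced above is exercised end to end, here is a rough usage sketch with vLLM. It assumes the `MInference("vllm", model_name)` entry point from the README, which applies `minference_patch_vllm` and thus selects `forward`, `forward_vllm_042`, or `forward_vllm_043` according to the detected `vllm.__version__` (0.4.1, 0.4.2, or 0.4.3); the model name and generation settings are placeholders.

```python
# Rough sketch, not part of this patch: serve a long-context model with vLLM and
# apply the MInference patch; the matching forward is chosen from vllm.__version__.
from vllm import LLM, SamplingParams
from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-262k"  # placeholder long-context model
llm = LLM(model=model_name, max_num_seqs=1, enforce_eager=True, max_model_len=128000)

# Patch the vLLM engine in place (see minference_patch_vllm in minference/patch.py).
minference_patch = MInference("vllm", model_name)
llm = minference_patch(llm)

outputs = llm.generate(["<a very long prompt>"], SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```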