From 0b9c81bc1d5428ad8c11d1dadd598992a259329b Mon Sep 17 00:00:00 2001
From: Huiqiang Jiang
Date: Tue, 16 Jul 2024 15:40:15 +0800
Subject: [PATCH] Hotfix(MInference): fix vllm>=0.4.1 (#44)

Co-authored-by: Yucheng Li
Co-authored-by: Chengruidong Zhang
---
 README.md                                |  2 +-
 minference/modules/minference_forward.py | 17 +++++++-----
 minference/patch.py                      | 34 +++++++-----------------
 minference/version.py                    |  2 +-
 4 files changed, 22 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index 52bb906..8eb3b77 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,7 @@ pipe(prompt, max_length=10)
 ```
 
 for vLLM,
-> For now, please use vllm==0.4.x
+> For now, please use vllm>=0.4.1
 
 ```diff
 from vllm import LLM, SamplingParams
diff --git a/minference/modules/minference_forward.py b/minference/modules/minference_forward.py
index 6db50cb..b7ea866 100644
--- a/minference/modules/minference_forward.py
+++ b/minference/modules/minference_forward.py
@@ -12,9 +12,14 @@
 
 if _is_package_available("vllm"):
     try:
+        from vllm import _custom_ops as vllm_ops
         from vllm.attention.ops.paged_attn import PagedAttention
+        from vllm_flash_attn import flash_attn_with_kvcache
     except:
-        warnings.warn("Only support 'vllm>=0.4.0'. Please update your vllm version.")
+        import vllm
+        vllm_version = vllm.__version__
+        if vllm_version < "0.4.1":
+            warnings.warn("Only support 'vllm>=0.4.1'. Please update your vllm version.")
 
 from ..ops.block_sparse_flash_attention import block_sparse_attention
 from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
@@ -1186,7 +1191,7 @@ def minference_prefill_func(
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
-        cache_ops.reshape_and_cache_flash(
+        vllm_ops.reshape_and_cache_flash(
            key,
            value,
            key_cache,
@@ -1268,12 +1273,10 @@ def minference_prefill_func(
 
         # Reshape the output tensor.
         return output.view(num_tokens, hidden_size)
-
-    if vllm_version == "0.4.1":
+    if vllm_version == "0.4.1":
         return forward
     elif vllm_version == "0.4.2":
         return forward_vllm_042
-    elif vllm_version == "0.4.3":
+    elif vllm_version >= "0.4.3":
         return forward_vllm_043
-    else:
-        return forward_vllm_042
+    assert False, "Only support 'vllm>=0.4.1'. Please update your vllm version."
diff --git a/minference/patch.py b/minference/patch.py
index 568a332..69873aa 100644
--- a/minference/patch.py
+++ b/minference/patch.py
@@ -1030,7 +1030,7 @@ def llama_layer_forward_vllm(
 def llama_attn_forward_vllm(
     vllm_version: str = "0.4.2",
 ):
-    def llama_attn_forward_vllm_042(
+    def llama_attn_forward_vllm(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
@@ -1041,33 +1041,19 @@ def llama_attn_forward_vllm_042(
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(
-            q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx
-        )
-        output, _ = self.o_proj(attn_output)
-        return output
+        if "0.4.1" <= vllm_version <= "0.4.2":
+            attn_output = self.attn(
+                q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx
+            )
+        elif vllm_version >= "0.4.3":
+            attn_output = self.attn(q, k, v, kv_cache, attn_metadata, layer_idx)
+        else:
+            assert False, "Only support 'vllm>=0.4.1'. Please update your vllm version."
 
-    def llama_attn_forward_vllm_043(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata,
-        layer_idx: int,
-    ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata, layer_idx)
         output, _ = self.o_proj(attn_output)
         return output
 
-    if vllm_version == "0.4.2":
-        return llama_attn_forward_vllm_042
-    elif vllm_version == "0.4.3":
-        return llama_attn_forward_vllm_043
-    else:
-        return llama_attn_forward_vllm_042
+    return llama_attn_forward_vllm
 
 
 def vllm_attn_forward(
diff --git a/minference/version.py b/minference/version.py
index b96d89f..c295776 100644
--- a/minference/version.py
+++ b/minference/version.py
@@ -8,7 +8,7 @@
 _PATCH = "4"
 # This is mainly for nightly builds which have the suffix ".dev$DATE". See
 # https://semver.org/#is-v123-a-semantic-version for the semantics.
-_SUFFIX = ".post3"
+_SUFFIX = ".post4"
 
 VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
 VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
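A note on the version gates in this patch: both `minference_forward.py` and `patch.py` compare `vllm.__version__` as a plain string (e.g. `"0.4.1" <= vllm_version <= "0.4.2"`). Lexicographic comparison is safe within the `0.4.x` range targeted here, but it would misorder a hypothetical `"0.4.10"` against `"0.4.2"`. Below is a minimal sketch of the same dispatch pattern using `packaging.version` for robust ordering; `select_forward` and the `forward_*` stubs are illustrative placeholders, not MInference APIs.

```python
# Sketch (not MInference code): version-gated dispatch as in this patch,
# with packaging.version handling multi-digit releases like "0.4.10".
import vllm
from packaging.version import Version


def forward_vllm_042(*args):
    # vllm 0.4.1/0.4.2 path: kv_scale is still passed to self.attn(...)
    ...


def forward_vllm_043(*args):
    # vllm >= 0.4.3 path: kv_scale is dropped from the attention call
    ...


def select_forward():
    v = Version(vllm.__version__)
    if Version("0.4.1") <= v < Version("0.4.3"):
        return forward_vllm_042
    if v >= Version("0.4.3"):
        return forward_vllm_043
    raise RuntimeError("Only support 'vllm>=0.4.1'. Please update your vllm version.")
```

This mirrors the design choice in `patch.py` above: instead of keeping one forward per vllm release and selecting among them, a single closure branches on the captured `vllm_version`, so supporting a new release only needs one more branch rather than another near-duplicate function.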