Commit
Hotfix(MInference): fix vllm>=0.4.1 (#44)
Co-authored-by: Yucheng Li <[email protected]>
Co-authored-by: Chengruidong Zhang <[email protected]>
3 people authored Jul 16, 2024
1 parent 50d17d9 commit 0b9c81b
Showing 4 changed files with 22 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -83,7 +83,7 @@ pipe(prompt, max_length=10)
```

for vLLM,
- > For now, please use vllm==0.4.x
+ > For now, please use vllm>=0.4.1
```diff
from vllm import LLM, SamplingParams
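The README hunk above is truncated after the first line of the vLLM example. As context, here is a minimal sketch of how such an example typically looks end to end; the `MInference("vllm", ...)` patching call and the model name are assumptions for illustration, not something this commit shows.

```python
# Illustrative sketch only: the MInference patching call and the model name are
# assumptions for this example, not shown in this commit.
from vllm import LLM, SamplingParams
from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-262k"  # hypothetical long-context model
llm = LLM(model=model_name, enforce_eager=True, max_model_len=128_000)

# Patch the vLLM engine so attention runs through MInference's sparse kernels (assumed API).
minference_patch = MInference("vllm", model_name)
llm = minference_patch(llm)

sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
outputs = llm.generate(["Summarize the following document: ..."], sampling_params)
print(outputs[0].outputs[0].text)
```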
17 changes: 10 additions & 7 deletions minference/modules/minference_forward.py
@@ -12,9 +12,14 @@

if _is_package_available("vllm"):
try:
from vllm import _custom_ops as vllm_ops
from vllm.attention.ops.paged_attn import PagedAttention
from vllm_flash_attn import flash_attn_with_kvcache
except:
warnings.warn("Only support 'vllm>=0.4.0'. Please update your vllm version.")
import vllm
vllm_version = vllm.__version__
if vllm_version < "0.4.1":
warnings.warn("Only support 'vllm>=0.4.1'. Please update your vllm version.")

from ..ops.block_sparse_flash_attention import block_sparse_attention
from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
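The hunk above adds a runtime check of `vllm.__version__` next to the optional-import guard, warning when the installed vllm is older than 0.4.1. A minimal sketch of that gating pattern, assuming a plain string comparison is intended (as the diff itself uses):

```python
# Sketch of the runtime version gate added above; illustrative, not the module itself.
import warnings


def vllm_is_supported(min_version: str = "0.4.1") -> bool:
    """Return True when vllm is installed and at least `min_version` (by string comparison)."""
    try:
        import vllm
    except ImportError:
        return False
    if vllm.__version__ < min_version:  # plain string comparison, mirroring the patch
        warnings.warn(f"Only support 'vllm>={min_version}'. Please update your vllm version.")
        return False
    return True
```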
@@ -1186,7 +1191,7 @@ def minference_prefill_func(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
- cache_ops.reshape_and_cache_flash(
+ vllm_ops.reshape_and_cache_flash(
key,
value,
key_cache,
@@ -1268,12 +1273,10 @@ def minference_prefill_func(
# Reshape the output tensor.
return output.view(num_tokens, hidden_size)


- if vllm_version == "0.4.1":
+ if vllm_version in "0.4.1":
return forward
elif vllm_version == "0.4.2":
return forward_vllm_042
- elif vllm_version == "0.4.3":
+ elif vllm_version >= "0.4.3":
return forward_vllm_043
else:
- return forward_vllm_042
+ assert False, "Only support 'vllm>=0.4.1'. Please update your vllm version."
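Note that `vllm_version in "0.4.1"` is a substring test rather than an equality check, and comparisons such as `vllm_version >= "0.4.3"` are plain lexicographic string comparisons; both work for exact 0.4.x strings but would misorder a hypothetical "0.4.10". A hedged sketch of a more robust dispatch using `packaging.version`, an assumed extra dependency and not what this commit does:

```python
# Illustrative alternative using packaging.version; the returned strings are
# placeholders for the real forward functions in minference_forward.py.
from packaging.version import Version


def pick_forward(vllm_version: str) -> str:
    v = Version(vllm_version)
    if v == Version("0.4.1"):
        return "forward"            # 0.4.1 code path
    if v == Version("0.4.2"):
        return "forward_vllm_042"   # 0.4.2 code path
    if v >= Version("0.4.3"):
        return "forward_vllm_043"   # 0.4.3 and newer
    raise AssertionError("Only support 'vllm>=0.4.1'. Please update your vllm version.")


print(pick_forward("0.4.3"))   # forward_vllm_043
print(pick_forward("0.4.10"))  # forward_vllm_043, unlike a lexicographic string comparison
```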
34 changes: 10 additions & 24 deletions minference/patch.py
@@ -1030,7 +1030,7 @@ def llama_layer_forward_vllm(
def llama_attn_forward_vllm(
vllm_version: str = "0.4.2",
):
- def llama_attn_forward_vllm_042(
+ def llama_attn_forward_vllm(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
@@ -1041,33 +1041,19 @@ def llama_attn_forward_vllm_042(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(
- q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx
- )
output, _ = self.o_proj(attn_output)
return output
if "0.4.1" <= vllm_version <= "0.4.2":
attn_output = self.attn(
q, k, v, kv_cache, attn_metadata, self.kv_scale, layer_idx
)
elif vllm_version >= "0.4.3":
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, layer_idx)
else:
assert False, "Only support 'vllm>=0.4.1'. Please update your vllm version."

- def llama_attn_forward_vllm_043(
- self,
- positions: torch.Tensor,
- hidden_states: torch.Tensor,
- kv_cache: torch.Tensor,
- attn_metadata,
- layer_idx: int,
- ) -> torch.Tensor:
- qkv, _ = self.qkv_proj(hidden_states)
- q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
- q, k = self.rotary_emb(positions, q, k)
- attn_output = self.attn(q, k, v, kv_cache, attn_metadata, layer_idx)
output, _ = self.o_proj(attn_output)
return output

- if vllm_version == "0.4.2":
- return llama_attn_forward_vllm_042
- elif vllm_version == "0.4.3":
- return llama_attn_forward_vllm_043
- else:
- return llama_attn_forward_vllm_042
+ return llama_attn_forward_vllm


def vllm_attn_forward(
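This patch collapses the separate 0.4.2 and 0.4.3 attention forwards into a single `llama_attn_forward_vllm` closure that branches on `vllm_version` when called, and the outer factory simply returns it. The toy sketch below illustrates that closure-factory plus monkey-patch pattern; the class and method names are stand-ins, not vLLM's.

```python
# Minimal, self-contained illustration of the closure-factory + monkey-patch pattern;
# FakeAttention is a stand-in, not vLLM's attention class.
class FakeAttention:
    def forward(self, x):
        return f"dense({x})"


def make_forward(vllm_version: str):
    def forward(self, x):
        if "0.4.1" <= vllm_version <= "0.4.2":
            return f"sparse_042({x})"   # older call signature (passes kv_scale)
        elif vllm_version >= "0.4.3":
            return f"sparse_043({x})"   # newer call signature (no kv_scale)
        raise AssertionError("Only support 'vllm>=0.4.1'. Please update your vllm version.")
    return forward


# Replace the original method with the version-appropriate closure.
FakeAttention.forward = make_forward("0.4.3")
print(FakeAttention().forward("q,k,v"))  # -> sparse_043(q,k,v)
```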
2 changes: 1 addition & 1 deletion minference/version.py
@@ -8,7 +8,7 @@
_PATCH = "4"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ".post3"
_SUFFIX = ".post4"

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
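The version bump changes the post-release suffix from `.post3` to `.post4`. For reference, a sketch of how the final version strings are assembled from these fields; the `_MAJOR` and `_MINOR` values are placeholders, since they are not shown in this hunk.

```python
# _MAJOR/_MINOR are placeholder values; this hunk only shows _PATCH and _SUFFIX.
_MAJOR, _MINOR, _PATCH, _SUFFIX = "1", "1", "4", ".post4"

VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
print(VERSION_SHORT, VERSION)  # e.g. "1.1 1.1.4.post4" with these placeholder values
```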
