diff --git a/MANIFEST.in b/MANIFEST.in
index eb5ecfc..05461ae 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,2 +1,4 @@
 recursive-include csrc *.cu
 recursive-include csrc *.cpp
+
+recursive-include minference *.json
diff --git a/README.md b/README.md
index bf6c99d..cf01a0b 100644
--- a/README.md
+++ b/README.md
@@ -80,6 +80,7 @@ pipe(prompt, max_length=10)
 ```
 
 for vLLM,
+> For now, please use vllm==0.4.1
 
 ```diff
 from vllm import LLM, SamplingParams
diff --git a/minference/modules/minference_forward.py b/minference/modules/minference_forward.py
index eb0d612..4abda6c 100644
--- a/minference/modules/minference_forward.py
+++ b/minference/modules/minference_forward.py
@@ -4,13 +4,17 @@
 import inspect
 import json
 import os
+import warnings
 from importlib import import_module
 
 from transformers.models.llama.modeling_llama import *
 from transformers.utils.import_utils import _is_package_available
 
 if _is_package_available("vllm"):
-    from vllm.attention.backends.flash_attn import *
+    try:
+        from vllm.attention.backends.flash_attn import *
+    except Exception:
+        warnings.warn("Only 'vllm==0.4.1' is supported. Please check your vllm version.")
 
 from ..ops.block_sparse_flash_attention import block_sparse_attention
 from ..ops.pit_sparse_flash_attention_v2 import vertical_slash_sparse_attention
@@ -768,7 +772,7 @@ def forward(
     key: torch.Tensor,
     value: torch.Tensor,
     kv_cache: torch.Tensor,
-    attn_metadata: AttentionMetadata[FlashAttentionMetadata],
+    attn_metadata,
     kv_scale: float,
     layer_idx: int,
 ) -> torch.Tensor:
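
Below is a minimal sketch of the version-guarded optional import pattern that the `minference_forward.py` hunk introduces: the import of a backend module is wrapped in `try`/`except` so that an incompatible release of the optional dependency degrades to a warning instead of breaking import of the whole package. The package and symbol names here (`somepkg`, `fast_backend`, `FastBackend`, `HAS_FAST_BACKEND`) are hypothetical placeholders and not part of the patch.

```python
import warnings

try:
    # Only present in the pinned/compatible release of the optional dependency
    # (hypothetical module path, for illustration only).
    from somepkg.backends.fast_backend import FastBackend
    HAS_FAST_BACKEND = True
except Exception:
    # Incompatible or missing release: warn once and let callers fall back.
    warnings.warn("Compatible 'somepkg' backend not found; falling back to the default path.")
    HAS_FAST_BACKEND = False
```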