Skip to content

Commit

Permalink
Merge pull request hiyouga#4706 from T-Atlas/main
Browse files Browse the repository at this point in the history
chore: Update vllm_engine.py to support vllm version >= 0.5.1
  • Loading branch information
hiyouga committed Jul 7, 2024
2 parents a15782c + f84b007 commit 563a27d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/llamafactory/chat/vllm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
Expand All @@ -29,7 +29,9 @@
from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest

if is_vllm_version_greater_than_0_5():
if is_vllm_version_greater_than_0_5_1():
pass
elif is_vllm_version_greater_than_0_5():
from vllm.multimodal.image import ImagePixelData
else:
from vllm.sequence import MultiModalData
Expand Down Expand Up @@ -130,8 +132,10 @@ async def _generate(
if self.processor is not None and image is not None: # add image features
image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
if is_vllm_version_greater_than_0_5():
multi_modal_data = ImagePixelData(image=pixel_values)
if is_vllm_version_greater_than_0_5_1():
multi_modal_data = {"image": pixel_values}
elif is_vllm_version_greater_than_0_5():
multi_modal_data = ImagePixelData(image=pixel_values)
else: # TODO: remove vllm 0.4.3 support
multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
else:
Expand Down
5 changes: 5 additions & 0 deletions src/llamafactory/extras/packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,8 @@ def is_vllm_available():
@lru_cache
def is_vllm_version_greater_than_0_5():
return _get_package_version("vllm") >= version.parse("0.5.0")


@lru_cache
def is_vllm_version_greater_than_0_5_1():
return _get_package_version("vllm") >= version.parse("0.5.1")

0 comments on commit 563a27d

Please sign in to comment.