FEAT: support qwen2-vl-instruct (#2205)
Minamiyama authored Sep 6, 2024
1 parent bcfedf8 commit c60e8fd
Showing 11 changed files with 313 additions and 7 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yaml
@@ -135,6 +135,7 @@ jobs:
pip install tensorizer
pip install eva-decord
pip install jj-pytorchvideo
pip install qwen-vl-utils
working-directory: .

- name: Test with pytest
2 changes: 2 additions & 0 deletions setup.cfg
@@ -127,6 +127,7 @@ all =
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
qwen-vl-utils # For qwen2-vl
intel =
torch==2.1.0a0
intel_extension_for_pytorch==2.1.10+xpu
@@ -151,6 +152,7 @@ transformers =
peft
eva-decord # For video in VL
jj-pytorchvideo # For CogVLM2-video
qwen-vl-utils # For qwen2-vl
vllm =
vllm>=0.2.6
sglang =
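Note: qwen-vl-utils is added to both the all and transformers extras, so installing xinference with either extra (for example, pip install "xinference[transformers]") pulls it in automatically; a bare pip install xinference still requires installing it separately.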
12 changes: 6 additions & 6 deletions xinference/core/worker.py
@@ -73,15 +73,15 @@ def __init__(
self._supervisor_ref: Optional[xo.ActorRefType] = None
self._main_pool = main_pool
self._main_pool.recover_sub_pool = self.recover_sub_pool
self._status_guard_ref: xo.ActorRefType["StatusGuardActor"] = ( # type: ignore
None
)
self._status_guard_ref: xo.ActorRefType[
"StatusGuardActor"
] = None # type: ignore
self._event_collector_ref: xo.ActorRefType[ # type: ignore
EventCollectorActor
] = None
self._cache_tracker_ref: xo.ActorRefType[CacheTrackerActor] = ( # type: ignore
None
)
self._cache_tracker_ref: xo.ActorRefType[
CacheTrackerActor
] = None # type: ignore

# internal states.
# temporary placeholder during model launch process:
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements.txt
@@ -70,6 +70,7 @@ jj-pytorchvideo # For CogVLM2-video
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
qwen-vl-utils # For qwen2-vl

# sglang
outlines>=0.0.44
1 change: 1 addition & 0 deletions xinference/deploy/docker/requirements_cpu.txt
@@ -65,3 +65,4 @@ jj-pytorchvideo # For CogVLM2-video
loguru # For Fish Speech
natsort # For Fish Speech
loralib # For Fish Speech
qwen-vl-utils # For qwen2-vl
2 changes: 2 additions & 0 deletions xinference/model/llm/__init__.py
@@ -142,6 +142,7 @@ def _install():
from .transformers.internlm2 import Internlm2PytorchChatModel
from .transformers.minicpmv25 import MiniCPMV25Model
from .transformers.minicpmv26 import MiniCPMV26Model
from .transformers.qwen2_vl import Qwen2VLChatModel
from .transformers.qwen_vl import QwenVLChatModel
from .transformers.yi_vl import YiVLChatModel
from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel
@@ -171,6 +172,7 @@ def _install():
PytorchChatModel,
Internlm2PytorchChatModel,
QwenVLChatModel,
Qwen2VLChatModel,
YiVLChatModel,
DeepSeekVLChatModel,
InternVLChatModel,
46 changes: 46 additions & 0 deletions xinference/model/llm/llm_family.json
@@ -6805,5 +6805,51 @@
"stop": [
"</s>"
]
},
{
"version":1,
"context_length":32768,
"model_name":"qwen2-vl-instruct",
"model_lang":[
"en",
"zh"
],
"model_ability":[
"chat",
"vision"
],
"model_description":"Qwen2-VL: To See the World More Clearly.Qwen2-VL is the latest version of the vision language models in the Qwen model familities.",
"model_specs":[
{
"model_format":"pytorch",
"model_size_in_billions":2,
"quantizations":[
"none"
],
"model_id":"Qwen/Qwen2-VL-2B-Instruct",
"model_revision":"096da3b96240e3d66d35be0e5ccbe282eea8d6b1"
},
{
"model_format":"pytorch",
"model_size_in_billions":7,
"quantizations":[
"none"
],
"model_id":"Qwen/Qwen2-VL-7B-Instruct",
"model_revision":"6010982c1010c3b222fa98afc81575f124aa9bd6"
}
],
"prompt_style":{
"style_name":"QWEN",
"system_prompt":"You are a helpful assistant",
"roles":[
"user",
"assistant"
],
"stop": [
"<|im_end|>",
"<|endoftext|>"
]
}
}
]
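With this registration in place, the new family can be launched and queried like any other chat model. The sketch below is illustrative only: it assumes a locally running xinference supervisor at the default port, and the exact client signatures (for example, whether launch_model needs a model_engine argument) may vary between xinference versions.

from xinference.client import Client

client = Client("http://127.0.0.1:9997")  # assumed local endpoint
model_uid = client.launch_model(
    model_name="qwen2-vl-instruct",
    model_format="pytorch",
    model_size_in_billions=2,
    quantization="none",
)
model = client.get_model(model_uid)
# OpenAI-style multimodal message; _transform_messages in qwen2_vl.py (below)
# converts image_url/video_url items into the layout qwen-vl-utils expects.
result = model.chat(
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    generate_config={"max_tokens": 256},
)
print(result["choices"][0]["message"]["content"])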
44 changes: 44 additions & 0 deletions xinference/model/llm/llm_family_modelscope.json
@@ -4508,5 +4508,49 @@
160133,
160132
]
},
{
"version": 1,
"context_length": 32768,
"model_name": "qwen2-vl-instruct",
"model_lang": [
"en",
"zh"
],
"model_ability": [
"chat",
"vision"
],
"model_description": "Qwen2-VL: To See the World More Clearly.Qwen2-VL is the latest version of the vision language models in the Qwen model familities.",
"model_specs": [
{
"model_format": "pytorch",
"model_size_in_billions": 2,
"quantizations": [
"none"
],
"model_hub": "modelscope",
"model_id": "qwen/Qwen2-VL-2B-Instruct",
"model_revision": "master"
},
{
"model_format": "pytorch",
"model_size_in_billions": 7,
"quantizations": [
"none"
],
"model_hub": "modelscope",
"model_id": "qwen/Qwen2-VL-7B-Instruct",
"model_revision": "master"
}
],
"prompt_style": {
"style_name": "QWEN",
"system_prompt": "You are a helpful assistant",
"roles": [
"user",
"assistant"
]
}
}
]
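The ModelScope entries mirror the Hugging Face ones but pin model_revision to master and omit the stop tokens from prompt_style; they are used when xinference is configured to download from ModelScope (commonly switched via the XINFERENCE_MODEL_SRC=modelscope environment variable; treat the exact mechanism as deployment-dependent).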
1 change: 1 addition & 0 deletions xinference/model/llm/transformers/core.py
@@ -64,6 +64,7 @@
"MiniCPM-Llama3-V-2_5",
"MiniCPM-V-2.6",
"glm-4v",
"qwen2-vl-instruct",
]
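Adding qwen2-vl-instruct to this list keeps the family out of the generic transformers chat path, so the dedicated Qwen2VLChatModel registered in __init__.py is used for loading and generation instead of the default PytorchChatModel.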


208 changes: 208 additions & 0 deletions xinference/model/llm/transformers/qwen2_vl.py
@@ -0,0 +1,208 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import uuid
from typing import Iterator, List, Optional, Union

from ....model.utils import select_device
from ....types import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
CompletionChunk,
)
from ..llm_family import LLMFamilyV1, LLMSpecV1
from ..utils import generate_chat_completion, generate_completion_chunk
from .core import PytorchChatModel, PytorchGenerateConfig

logger = logging.getLogger(__name__)


class Qwen2VLChatModel(PytorchChatModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._tokenizer = None
self._model = None
self._device = None
self._processor = None

@classmethod
def match(
cls, model_family: "LLMFamilyV1", model_spec: "LLMSpecV1", quantization: str
) -> bool:
llm_family = model_family.model_family or model_family.model_name
if "qwen2-vl-instruct".lower() in llm_family.lower():
return True
return False

def load(self):
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

device = self._pytorch_model_config.get("device", "auto")
device = select_device(device)
self._device = device
# for multiple GPU, set back to auto to make multiple devices work
device = "auto" if device == "cuda" else device

self._processor = AutoProcessor.from_pretrained(
self.model_path, trust_remote_code=True
)
self._tokenizer = self._processor.tokenizer
self._model = Qwen2VLForConditionalGeneration.from_pretrained(
self.model_path, device_map=device, trust_remote_code=True
).eval()

def _transform_messages(
self,
messages: List[ChatCompletionMessage],
):
transformed_messages = []
for msg in messages:
new_content = []
role = msg["role"]
content = msg["content"]
if isinstance(content, str):
new_content.append({"type": "text", "text": content})
elif isinstance(content, List):
for item in content: # type: ignore
if "text" in item:
new_content.append({"type": "text", "text": item["text"]})
elif "image_url" in item:
new_content.append(
{"type": "image", "image": item["image_url"]["url"]}
)
elif "video_url" in item:
new_content.append(
{"type": "video", "video": item["video_url"]["url"]}
)
new_message = {"role": role, "content": new_content}
transformed_messages.append(new_message)

return transformed_messages
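# Example of the transformation above (illustrative values): an OpenAI-style item
#   {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
# becomes the qwen-vl-utils layout consumed by process_vision_info:
#   {"type": "image", "image": "https://example.com/cat.png"}
# and plain string content is wrapped as a single {"type": "text", ...} item.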

def chat(
self,
messages: List[ChatCompletionMessage], # type: ignore
generate_config: Optional[PytorchGenerateConfig] = None,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
messages = self._transform_messages(messages)

generate_config = generate_config if generate_config else {}

stream = generate_config.get("stream", False) if generate_config else False

if stream:
it = self._generate_stream(messages, generate_config)
return self._to_chat_completion_chunks(it)
else:
c = self._generate(messages, generate_config)
return c

def _generate(
self, messages: List, config: PytorchGenerateConfig = {}
) -> ChatCompletion:
from qwen_vl_utils import process_vision_info

# Preparation for inference
text = self._processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self._processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self._model.device)

# Inference: Generation of the output
generated_ids = self._model.generate(
**inputs,
max_new_tokens=config.get("max_tokens", 512),
temperature=config.get("temperature", 1),
)
generated_ids_trimmed = [
out_ids[len(in_ids) :]
for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = self._processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0]
return generate_chat_completion(self.model_uid, output_text)

def _generate_stream(
self, messages: List, config: PytorchGenerateConfig = {}
) -> Iterator[CompletionChunk]:
from threading import Thread

from qwen_vl_utils import process_vision_info
from transformers import TextIteratorStreamer

text = self._processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = self._processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self._model.device)

tokenizer = self._tokenizer
streamer = TextIteratorStreamer(
tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
)

gen_kwargs = {
"max_new_tokens": config.get("max_tokens", 512),
"temperature": config.get("temperature", 1),
"streamer": streamer,
**inputs,
}

thread = Thread(target=self._model.generate, kwargs=gen_kwargs)
thread.start()

completion_id = str(uuid.uuid1())
for new_text in streamer:
yield generate_completion_chunk(
chunk_text=new_text,
finish_reason=None,
chunk_id=completion_id,
model_uid=self.model_uid,
prompt_tokens=-1,
completion_tokens=-1,
total_tokens=-1,
has_choice=True,
has_content=True,
)

yield generate_completion_chunk(
chunk_text=None,
finish_reason="stop",
chunk_id=completion_id,
model_uid=self.model_uid,
prompt_tokens=-1,
completion_tokens=-1,
total_tokens=-1,
has_choice=True,
has_content=False,
)
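For completeness, a minimal sketch of consuming the streaming path above through xinference's OpenAI-compatible HTTP API. The endpoint URL, model name, and image URL are placeholders, and whether a deployment exposes this route depends on how it is run; the openai client usage itself is standard.

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9997/v1", api_key="not-used")  # assumed endpoint
stream = client.chat.completions.create(
    model="qwen2-vl-instruct",  # or the uid returned by launch_model
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ],
    stream=True,
    max_tokens=256,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="", flush=True)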
(Diff for one additional changed file not shown.)