-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add deepseek-vl2 * Fix Lint * Make distributed timeout configurable --------- Co-authored-by: Xingchao Liu <[email protected]> Co-authored-by: kennymckormick <[email protected]>
- Loading branch information
1 parent
f055c2e
commit cf4d61b
Showing
4 changed files
with
174 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import sys | ||
import torch | ||
from transformers import AutoModelForCausalLM | ||
import warnings | ||
from .base import BaseModel | ||
from ..smp import * | ||
from PIL import Image | ||
|
||
|
||
class DeepSeekVL2(BaseModel): | ||
|
||
INSTALL_REQ = True | ||
INTERLEAVE = True | ||
|
||
def check_install(self): | ||
try: | ||
import deepseek_vl2 | ||
except Exception as e: | ||
logging.critical( | ||
'Please first install deepseek_vl2 from source codes in: https://github.com/deepseek-ai/DeepSeek-VL2') | ||
raise e | ||
|
||
def __init__(self, model_path='deepseek-ai/deepseek-vl2-tiny', **kwargs): | ||
self.check_install() | ||
assert model_path is not None | ||
self.model_path = model_path | ||
from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM | ||
|
||
self.vl_chat_processor = DeepseekVLV2Processor.from_pretrained(model_path) | ||
self.tokenizer = self.vl_chat_processor.tokenizer | ||
|
||
model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(model_path, | ||
trust_remote_code=True, | ||
torch_dtype=torch.bfloat16) | ||
self.model = model.cuda().eval() | ||
|
||
torch.cuda.empty_cache() | ||
default_kwargs = dict(max_new_tokens=512, do_sample=False, use_cache=True) | ||
default_kwargs.update(kwargs) | ||
self.kwargs = default_kwargs | ||
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ') | ||
|
||
def prepare_inputs(self, message, dataset=None): | ||
|
||
if dataset == 'MMMU_DEV_VAL': | ||
|
||
def prepare_itlist(msgs): | ||
content, images = '', [] | ||
image_idx = 1 | ||
for s in msgs: | ||
if s['type'] == 'image': | ||
images.append(s['value']) | ||
content += f'<image {image_idx}>' | ||
image_idx += 1 | ||
elif s['type'] == 'text': | ||
content += s['value'] | ||
# content = '<image>' * (image_idx-1) + '\n' + content | ||
content = '<image>' * (image_idx - 1) + '\n' + content | ||
return content, images | ||
|
||
conversation = [] | ||
if 'role' not in message[0]: | ||
content, images = prepare_itlist(message) | ||
content = content.replace( | ||
'Please select the correct answer from the options above.', | ||
"Answer with the option's letter from the given choices directly. Answer the question using a single word or phrase.\n" # noqa | ||
) | ||
content = content.replace('Question:', "") | ||
content = content.replace('Options:\n', "") | ||
conversation.append(dict(role='<|User|>', content=content, images=images)) | ||
else: | ||
role_map = {'user': '<|User|>', 'assistant': '<|Assistant|>'} | ||
for msgs in message: | ||
role = role_map[msgs['role']] | ||
content, images = prepare_itlist(msgs['content']) | ||
content = content.replace( | ||
'Please select the correct answer from the options above.', | ||
"Answer with the option's letter from the given choices directly. Answer the question using a single word or phrase.\n" # noqa | ||
) | ||
content = content.replace('Question:', "") | ||
content = content.replace('Options:\n', "") | ||
conversation.append(dict(role=role, content=content, images=images)) | ||
conversation.append(dict(role='<|Assistant|>', content='')) | ||
|
||
else: | ||
|
||
def prepare_itlist(msgs): | ||
content, images = '', [] | ||
for s in msgs: | ||
if s['type'] == 'image': | ||
images.append(s['value']) | ||
content += '<image>\n' | ||
elif s['type'] == 'text': | ||
content += s['value'] | ||
return content, images | ||
|
||
conversation = [] | ||
if 'role' not in message[0]: | ||
content, images = prepare_itlist(message) | ||
conversation.append(dict(role='<|User|>', content=content, images=images)) | ||
else: | ||
role_map = {'user': '<|User|>', 'assistant': '<|Assistant|>'} | ||
for msgs in message: | ||
role = role_map[msgs['role']] | ||
content, images = prepare_itlist(msgs['content']) | ||
conversation.append(dict(role=role, content=content, images=images)) | ||
conversation.append(dict(role='<|Assistant|>', content='')) | ||
|
||
return conversation | ||
|
||
def generate_inner(self, message, dataset=None): | ||
conversation = self.prepare_inputs(message, dataset) | ||
from deepseek_vl2.utils.io import load_pil_images | ||
pil_images = load_pil_images(conversation) | ||
|
||
if dataset == 'MMMU_DEV_VAL': | ||
if len(pil_images): | ||
h, w = pil_images[0].size | ||
pil_images[0] = pil_images[0].resize((2 * h, 2 * w), Image.BILINEAR) | ||
|
||
prepare_inputs = self.vl_chat_processor( | ||
conversations=conversation, | ||
images=pil_images, | ||
force_batchify=True, | ||
system_prompt="" | ||
) | ||
prepare_inputs = prepare_inputs.to(self.model.device) | ||
inputs_embeds = self.model.prepare_inputs_embeds(**prepare_inputs) | ||
|
||
inputs_embeds, past_key_values = self.model.incremental_prefilling( | ||
input_ids=prepare_inputs.input_ids, | ||
images=prepare_inputs.images, | ||
images_seq_mask=prepare_inputs.images_seq_mask, | ||
images_spatial_crop=prepare_inputs.images_spatial_crop, | ||
attention_mask=prepare_inputs.attention_mask, | ||
chunk_size=512 | ||
) | ||
|
||
# run the model to get the response | ||
outputs = self.model.generate( | ||
inputs_embeds=inputs_embeds, | ||
input_ids=prepare_inputs.input_ids, | ||
images=prepare_inputs.images, | ||
images_seq_mask=prepare_inputs.images_seq_mask, | ||
images_spatial_crop=prepare_inputs.images_spatial_crop, | ||
attention_mask=prepare_inputs.attention_mask, | ||
past_key_values=past_key_values, | ||
pad_token_id=self.tokenizer.eos_token_id, | ||
bos_token_id=self.tokenizer.bos_token_id, | ||
eos_token_id=self.tokenizer.eos_token_id, | ||
**self.kwargs | ||
) | ||
|
||
answer = self.tokenizer.decode( | ||
outputs[0][len(prepare_inputs.input_ids[0]):].cpu().tolist(), | ||
skip_special_tokens=True | ||
) | ||
answer = answer.rstrip('.') | ||
|
||
return answer | ||
|
||
def chat_inner(self, message, dataset=None): | ||
return self.generate_inner(message, dataset=dataset) |