-
Notifications
You must be signed in to change notification settings - Fork 894
/
chat.py
293 lines (238 loc) · 11 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import os
import torch
import json
from PIL import Image
import base64
import io
from accelerate import load_checkpoint_and_dispatch, init_empty_weights
from transformers import AutoTokenizer, AutoModel
from omnilmm.utils import disable_torch_init
from omnilmm.model.omnilmm import OmniLMMForCausalLM
from omnilmm.model.utils import build_transform
from omnilmm.train.train_utils import omni_preprocess
# Special-token strings used when building OmniLMM prompts.
#
# "<image>" is the same text as the placeholder literal that
# expand_question_into_multimodal() searches for in the first message; this
# constant itself is not referenced by the code visible in this file.
DEFAULT_IMAGE_TOKEN = "<image>"
# Per-patch placeholder; repeated image_token_len times between the start and
# end markers when a question is expanded (see wrap_question_for_omni_lmm()).
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
# Markers that open and close the image span. Together with the patch token
# they are registered as special tokens on the tokenizer in init_omni_lmm().
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def init_omni_lmm(model_path, multi_gpus=False):
    """Load the OmniLMM model, tokenizer and image transform from a checkpoint.

    Args:
        model_path: Checkpoint location; ``~`` is expanded before loading.
        multi_gpus: When true, build the model with empty weights and let
            ``accelerate`` dispatch the checkpoint with ``device_map="auto"``
            (intended for GPUs with little memory, e.g. two 24G cards).
            Previously this path sat behind ``if False:`` and could only be
            enabled by editing the source. The default (``False``) keeps the
            original behaviour: load everything onto a single CUDA device.

    Returns:
        A tuple ``(model, image_processor, image_token_len, tokenizer)``.
        ``image_token_len`` is read from ``model.model.config.num_query``.

    Raises:
        AssertionError: If the model config does not enable
            ``mm_use_im_start_end``.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    disable_torch_init()
    model_name = os.path.expanduser(model_path)
    print(f'Load omni_lmm model and tokenizer from {model_name}')
    tokenizer = AutoTokenizer.from_pretrained(
        model_name, model_max_length=2048)

    if multi_gpus:
        # Model on multiple devices for small size gpu memory
        # (Nvidia 3090 24G x2).
        with init_empty_weights():
            model = OmniLMMForCausalLM.from_pretrained(
                model_name, tune_clip=True, torch_dtype=torch.bfloat16)
        model = load_checkpoint_and_dispatch(
            model, model_name, dtype=torch.bfloat16,
            device_map="auto",
            no_split_module_classes=[
                'Eva', 'MistralDecoderLayer', 'ModuleList', 'Resampler'],
        )
    else:
        model = OmniLMMForCausalLM.from_pretrained(
            model_name, tune_clip=True, torch_dtype=torch.bfloat16
        ).to(device='cuda', dtype=torch.bfloat16)

    image_processor = build_transform(
        is_train=False, input_size=model.model.config.image_size,
        std_mode='OPENAI_CLIP')

    # The prompt builder always wraps the image span in start/end markers, so
    # a checkpoint that was configured without them is rejected here. Kept as
    # an assert so the raised exception type stays the same for callers.
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    assert mm_use_im_start_end, \
        "model config must enable mm_use_im_start_end"

    tokenizer.add_tokens(
        [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN,
         DEFAULT_IM_END_TOKEN],
        special_tokens=True)

    # Publish the ids of the image special tokens on the vision config.
    vision_config = model.model.vision_config
    vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IMAGE_PATCH_TOKEN])[0]
    vision_config.use_im_start_end = mm_use_im_start_end
    vision_config.im_start_token, vision_config.im_end_token = \
        tokenizer.convert_tokens_to_ids(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

    image_token_len = model.model.config.num_query
    return model, image_processor, image_token_len, tokenizer
def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
    """Insert the image token span into the first message of a conversation.

    The span is ``im_st_token`` + ``im_patch_token`` repeated
    ``image_token_len`` times + ``im_ed_token``. Every ``'<image>'``
    placeholder in the first message's ``'content'`` is replaced by it; when
    there is no placeholder, the span and a newline are prepended instead.

    The list is modified in place and the same list object is returned.
    """
    first_turn = question_text[0]
    prompt = first_turn['content']
    image_span = im_st_token + im_patch_token * image_token_len + im_ed_token

    if '<image>' not in prompt:
        first_turn['content'] = image_span + '\n' + prompt
    else:
        first_turn['content'] = prompt.replace('<image>', image_span)
    return question_text
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
    """Expand the image placeholder and run the conversation through omni_preprocess.

    Returns a dict holding the first ``input_ids`` and first ``labels`` entry
    produced by ``omni_preprocess`` (called with ``generation=True``).
    """
    conversation = expand_question_into_multimodal(
        question,
        image_token_len,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IMAGE_PATCH_TOKEN,
    )
    processed = omni_preprocess(
        sources=[conversation], tokenizer=tokenizer, generation=True)
    # Only one conversation was submitted, so keep just its entry.
    return {
        "input_ids": processed["input_ids"][0],
        "labels": processed["labels"][0],
    }
class OmniLMM12B:
    """Chat wrapper around the OmniLMM model loaded by ``init_omni_lmm``."""

    def __init__(self, model_path) -> None:
        """Load model, image transform, image token count and tokenizer."""
        (self.model,
         self.image_transform,
         self.image_token_len,
         self.tokenizer) = init_omni_lmm(model_path)
        self.model.eval()

    def decode(self, image, input_ids):
        """Sample a reply for one transformed image and one token sequence.

        Both inputs get a leading batch dimension and are moved to CUDA; the
        image is additionally cast to half precision. Returns the decoded
        text with special tokens skipped and surrounding whitespace stripped.
        """
        sampling = dict(
            temperature=0.6,
            max_new_tokens=1024,
            do_sample=True,
            output_scores=True,
            return_dict_in_generate=True,
            repetition_penalty=1.1,
            top_k=30,
            top_p=0.9,
        )
        with torch.inference_mode():
            output = self.model.generate_vllm(
                input_ids=input_ids.unsqueeze(0).cuda(),
                images=image.unsqueeze(0).half().cuda(),
                **sampling,
            )
            text = self.tokenizer.decode(
                output.sequences[0], skip_special_tokens=True)
            return text.strip()

    def chat(self, input):
        """Answer a request dict with a base64 ``'image'`` and JSON ``'question'``.

        Returns the string ``"Image decode error"`` when the image cannot be
        decoded or opened; otherwise returns the generated reply.
        (The parameter name shadows the builtin; kept for API compatibility.)
        """
        try:
            raw_image = base64.b64decode(input['image'])
            picture = Image.open(io.BytesIO(raw_image)).convert('RGB')
        except Exception:
            return "Image decode error"

        msgs = json.loads(input['question'])
        wrapped = wrap_question_for_omni_lmm(
            msgs, self.image_token_len, self.tokenizer)
        token_ids = torch.as_tensor(wrapped['input_ids'])
        pixels = self.image_transform(picture)
        return self.decode(pixels, token_ids)
def img2base64(file_name):
    """Read ``file_name`` in binary mode and return its base64 encoding as bytes."""
    with open(file_name, 'rb') as stream:
        raw = stream.read()
    return base64.b64encode(raw)
class MiniCPMV:
    """Chat wrapper for a checkpoint loaded via ``AutoModel`` in bfloat16."""

    def __init__(self, model_path) -> None:
        """Load model (cast to bfloat16, eval mode, on CUDA) and tokenizer."""
        loaded = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        self.model = loaded.to(dtype=torch.bfloat16)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        """Answer a request dict with a base64 ``'image'`` and JSON ``'question'``.

        Returns ``"Image decode error"`` when the image cannot be decoded or
        opened; otherwise the first element of the model's three-part reply.
        (The parameter name shadows the builtin; kept for API compatibility.)
        """
        try:
            raw_image = base64.b64decode(input['image'])
            picture = Image.open(io.BytesIO(raw_image)).convert('RGB')
        except Exception:
            return "Image decode error"

        history = json.loads(input['question'])
        reply = self.model.chat(
            image=picture,
            msgs=history,
            context=None,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7,
        )
        # The model call yields three values; only the answer is surfaced.
        answer, _context, _ = reply
        return answer
class MiniCPMV2_5:
    """Chat wrapper for a checkpoint loaded via ``AutoModel`` in float16."""

    def __init__(self, model_path) -> None:
        """Load model (cast to float16, eval mode, on CUDA) and tokenizer."""
        loaded = AutoModel.from_pretrained(model_path, trust_remote_code=True)
        self.model = loaded.to(dtype=torch.float16)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)
        self.model.eval().cuda()

    def chat(self, input):
        """Answer a request dict with a base64 ``'image'`` and JSON ``'question'``.

        Returns ``"Image decode error"`` when the image cannot be decoded or
        opened; otherwise whatever the underlying model's ``chat`` returns.
        (The parameter name shadows the builtin; kept for API compatibility.)
        """
        try:
            raw_image = base64.b64decode(input['image'])
            picture = Image.open(io.BytesIO(raw_image)).convert('RGB')
        except Exception:
            return "Image decode error"

        history = json.loads(input['question'])
        return self.model.chat(
            image=picture,
            msgs=history,
            tokenizer=self.tokenizer,
            sampling=True,
            temperature=0.7,
        )
class MiniCPMV2_6:
    """Chat wrapper for a bfloat16 ``AutoModel`` checkpoint, optionally sharded."""

    def __init__(self, model_path, multi_gpus=False) -> None:
        """Load the model and tokenizer.

        With ``multi_gpus`` set, a device map is inferred with a 10GB budget
        on devices 0 and 1, adjusted by hand, and used to dispatch the
        checkpoint. Otherwise the whole model is moved to CUDA.
        """
        print('torch_version:', torch.__version__)
        if not multi_gpus:
            self.model = AutoModel.from_pretrained(
                model_path, trust_remote_code=True,
                attn_implementation='sdpa', torch_dtype=torch.bfloat16)
            self.model.eval().cuda()
        else:  # inference on multi-gpus
            from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map
            with init_empty_weights():
                model = AutoModel.from_pretrained(
                    model_path, trust_remote_code=True,
                    attn_implementation='sdpa', torch_dtype=torch.bfloat16)
            device_map = infer_auto_device_map(
                model,
                max_memory={0: "10GB", 1: "10GB"},
                no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer'])

            # First and last layer of the llm should be on the same device;
            # the vision module and resampler are placed there as well.
            device_id = device_map["llm.model.embed_tokens"]
            for module_name in ("llm.lm_head", "vpm", "resampler"):
                device_map[module_name] = device_id

            # Pin decoder layers 8..16 to whichever device holds layer 26.
            device_id2 = device_map["llm.model.layers.26"]
            for layer_idx in range(8, 17):
                device_map[f"llm.model.layers.{layer_idx}"] = device_id2
            print(device_map)

            self.model = load_checkpoint_and_dispatch(
                model, model_path, dtype=torch.bfloat16, device_map=device_map)
            self.model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, trust_remote_code=True)

    @staticmethod
    def _materialize(part):
        """Turn one content entry into what the model consumes.

        Non-dict entries pass through unchanged. A dict of type ``'text'``
        yields its ``'pairs'`` value; type ``'image'`` yields an RGB image
        decoded from the base64 ``'pairs'`` value; any other type raises
        ``ValueError``.
        """
        if not isinstance(part, dict):
            return part
        kind = part['type']
        if kind == 'text':
            return part['pairs']
        if kind == 'image':
            return Image.open(io.BytesIO(base64.b64decode(part["pairs"]))).convert('RGB')
        raise ValueError("content type only support text and image.")

    def chat(self, input):
        """Answer a request dict holding a JSON ``'question'`` message list.

        An optional top-level base64 ``'image'`` (legacy API) is used only
        when its length exceeds 10; if it fails to decode, the string
        ``"Image decode error"`` is returned. Each message's ``'content'``
        may be a string or a list of strings / typed dicts and is normalised
        to a list before being passed to the model.
        (The parameter name shadows the builtin; kept for API compatibility.)
        """
        image = None
        uses_legacy_image = "image" in input and len(input["image"]) > 10
        if uses_legacy_image:  # legacy API
            try:
                raw_image = base64.b64decode(input['image'])
                image = Image.open(io.BytesIO(raw_image)).convert('RGB')
            except Exception:
                return "Image decode error"

        msgs = json.loads(input["question"])
        for msg in msgs:
            raw_content = msg.pop('content')  # support str or List[Dict]
            parts = [raw_content] if isinstance(raw_content, str) else raw_content
            msg['content'] = [self._materialize(part) for part in parts]
        print(f'msgs: {str(msgs)}')

        return self.model.chat(
            image=image,
            msgs=msgs,
            tokenizer=self.tokenizer,
        )
class MiniCPMVChat:
    """Facade that picks a concrete chat backend from the model path."""

    def __init__(self, model_path, multi_gpus=False) -> None:
        """Select the backend by substring match on ``model_path``.

        Markers are tried in order; the first one contained in the path
        wins, and ``MiniCPMV`` is the fallback. ``multi_gpus`` is forwarded
        only to ``MiniCPMV2_6``.
        """
        routes = (
            ('12B', lambda: OmniLMM12B(model_path)),
            ('MiniCPM-Llama3-V', lambda: MiniCPMV2_5(model_path)),
            ('MiniCPM-V-2_6', lambda: MiniCPMV2_6(model_path, multi_gpus)),
        )
        build = next(
            (factory for marker, factory in routes if marker in model_path),
            lambda: MiniCPMV(model_path),
        )
        self.model = build()

    def chat(self, input):
        """Delegate the request to the selected backend and return its reply."""
        backend = self.model
        return backend.chat(input)
if __name__ == '__main__':
    # Demo: a two-round conversation about one image with the 12B model.
    model_path = 'openbmb/OmniLMM-12B'
    chat_model = MiniCPMVChat(model_path)

    im_64 = img2base64('./assets/worldmap_ck.jpg')

    msgs = []
    answer = None
    for question in ("What is interesting about this image?",
                     "Where is China in the image"):
        # From the second round on, record the previous reply first so the
        # model sees the whole dialogue.
        if msgs:
            msgs.append({"role": "assistant", "content": answer})
        msgs.append({"role": "user", "content": question})

        payload = {"image": im_64, "question": json.dumps(msgs, ensure_ascii=True)}
        answer = chat_model.chat(payload)
        print(msgs[-1]["content"]+'\n', answer)