[Model] Update TeleMM (#672)
* [Model] Update TeleMM

* Restore the deleted geminiflash2.0 configuration.

* Update lint

* Update retry

* Update .gitignore

---------

Co-authored-by: Arno.Wei <[email protected]>
Co-authored-by: kennymckormick <[email protected]>
3 people authored Dec 17, 2024
1 parent 5316c5f commit 547b36f
Showing 3 changed files with 187 additions and 102 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -168,3 +168,4 @@ outputs/*
demo.ipynb
*json
.vscode
*.swp
286 changes: 185 additions & 101 deletions vlmeval/api/siliconflow.py
@@ -1,79 +1,99 @@
import math
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from vlmeval.dataset import img_root_map
from vlmeval.dataset import DATASET_TYPE

API_BASE = 'https://api.siliconflow.cn/v1/chat/completions'
API_BASE = "https://api.siliconflow.cn/v1/chat/completions"


def resize_image(image: Image.Image, max_height: int, max_width: int) -> Image.Image:
width, height = image.size
if min(width, height) < 50:
scale = 50 / min(width, height)
image = image.resize((int(width * scale), int(height * scale)))
current_pixels = width * height

if current_pixels <= max_height * max_width:
return image

scale = math.sqrt(max_height * max_width / current_pixels)
new_width = int(width * scale)
new_height = int(height * scale)

return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


def encode_image(path: str, max_height: int = 1024, max_width: int = 1024) -> str:
image = Image.open(path).convert("RGB")
image = resize_image(image, max_height, max_width)
width, height = image.size
if min(height, width) < 50:
scale = 50 / min(width, height)
image = image.resize((int(width * scale), int(height * scale)))
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_bytes = buffered.getvalue()
img_base64 = base64.b64encode(img_bytes).decode("utf-8")
return img_base64


class SiliconFlowAPI(BaseAPI):

is_api: bool = True

def __init__(self,
model: str = 'deepseek-ai/DeepSeek-V2.5',
retry: int = 5,
wait: int = 5,
key: str = None,
api_base: str = API_BASE,
verbose: bool = True,
system_prompt: str = None,
timeout: int = 60,
**kwargs):
def __init__(
self,
model: str = "deepseek-ai/DeepSeek-V2.5",
retry: int = 5,
wait: int = 5,
key: str = None,
api_base: str = API_BASE,
verbose: bool = True,
system_prompt: str = None,
timeout: int = 60,
**kwargs,
):

self.model = model
self.api_base = api_base

default_kwargs = {
'stream': False,
'temperature': 0,
'frequency_penalty': 0,
'n': 1,
'max_tokens': 1024,
"stream": False,
"temperature": 0,
"n": 1,
"max_tokens": 1280,
}
for k, v in default_kwargs.items():
if k not in kwargs:
kwargs[k] = default_kwargs[k]
if key is not None:
self.key = key
else:
self.key = os.environ.get('SiliconFlow_API_KEY', '')
headers = {
"Authorization": 'Bearer {}',
"Content-Type": "application/json"
}
headers['Authorization'] = headers['Authorization'].format(self.key)
self.key = os.environ.get("SiliconFlow_API_KEY", "")
headers = {"Authorization": "Bearer {}", "Content-Type": "application/json"}
headers["Authorization"] = headers["Authorization"].format(self.key)
self.headers = headers
super().__init__(wait=wait, retry=retry, system_prompt=system_prompt, verbose=verbose, **kwargs)
super().__init__(
wait=wait,
retry=retry,
system_prompt=system_prompt,
verbose=verbose,
**kwargs,
)

@staticmethod
def build_msgs(msgs_raw):
messages = []
message = {'role': 'user', 'content': []}

def encode_image_to_base64_PNG(image_dir):
image = Image.open(image_dir)
from io import BytesIO
byte_stream = BytesIO()
image.save(byte_stream, format="PNG")
byte_data = byte_stream.getvalue()
base64_encoded_data = base64.b64encode(byte_data)
base64_string = base64_encoded_data.decode("utf-8")

return base64_string
message = {"role": "user", "content": []}
image_b64 = None
for msg in msgs_raw:
if msg['type'] == 'image' and not image_b64:
image_b64 = encode_image_to_base64_PNG(msg['value'])
message['content'].append({
'image_url': {'url': image_b64},
'type': 'image_url'
})
elif msg['type'] == 'text':
message['content'].append({
'text': msg['value'],
'type': 'text'
})
if msg["type"] == "image" and not image_b64:
image_b64 = encode_image(msg["value"])
message["content"].append(
{"image_url": {"url": image_b64}, "type": "image_url"}
)
elif msg["type"] == "text":
message["content"].append({"text": msg["value"], "type": "text"})

messages.append(message)
return messages
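
For reference, the new resize_image path first upscales any image whose shorter side is under 50 px, then, if the pixel count still exceeds max_height * max_width, downscales by area while preserving the aspect ratio. Below is a minimal standalone sketch of that scaling math (Pillow only; the 4000x3000 input and the variable names are illustrative assumptions, not part of the diff):

import math
from PIL import Image

img = Image.new("RGB", (4000, 3000))            # illustrative input image
max_h, max_w = 1024, 1024                       # defaults used by encode_image
w, h = img.size
if w * h > max_h * max_w:
    scale = math.sqrt(max_h * max_w / (w * h))  # ~0.2956 for this input
    img = img.resize((int(w * scale), int(h * scale)), Image.Resampling.LANCZOS)
print(img.size)                                 # (1182, 886), within the 1024*1024 pixel budget
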
@@ -85,16 +105,19 @@ def generate_inner(self, inputs, **kwargs) -> str:
payload = dict(
model=self.model,
messages=self.build_msgs(msgs_raw=inputs),
**default_kwargs)
**default_kwargs,
)

response = requests.post(self.api_base, headers=self.headers, data=json.dumps(payload))
response = requests.post(
self.api_base, headers=self.headers, data=json.dumps(payload)
)
ret_code = response.status_code
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code

answer = self.fail_msg
try:
resp_struct = json.loads(response.text)
answer = resp_struct['choices'][0]['message']['content'].strip()
answer = resp_struct["choices"][0]["message"]["content"].strip()
except:
pass
return ret_code, answer, response
@@ -104,11 +127,17 @@ class TeleMMAPI(SiliconFlowAPI):

is_api: bool = True

def __init__(self,
model: str = 'TeleAI/TeleMM',
key: str = None,
**kwargs):
def __init__(
self,
model: str = "TeleAI/TeleMM",
key: str = None,
max_height: int = 1280,
max_width: int = 784,
**kwargs,
):
super().__init__(model=model, key=key, **kwargs)
self.max_height = max_height
self.max_width = max_width

def dump_image(self, line, dataset):
"""Dump the image(s) of the input line to the corresponding dataset folder.
@@ -123,63 +152,118 @@ def dump_image(self, line, dataset):
ROOT = LMUDataRoot()
assert isinstance(dataset, str)
# img_root = osp.join(ROOT, 'images', img_root_map[dataset] if dataset in img_root_map else dataset)
img_root = osp.join(ROOT, 'images', img_root_map(dataset))
img_root = osp.join(ROOT, "images", img_root_map(dataset))
os.makedirs(img_root, exist_ok=True)
if 'image' in line:
if isinstance(line['image'], list):
if "image" in line:
if isinstance(line["image"], list):
tgt_path = []
assert 'image_path' in line
for img, im_name in zip(line['image'], line['image_path']):
assert "image_path" in line
for img, im_name in zip(line["image"], line["image_path"]):
path = osp.join(img_root, im_name)
if not read_ok(path):
decode_base64_to_image_file(img, path)
tgt_path.append(path)
else:
tgt_path = osp.join(img_root, f"{line['index']}.jpg")
if not read_ok(tgt_path):
decode_base64_to_image_file(line['image'], tgt_path)
decode_base64_to_image_file(line["image"], tgt_path)
tgt_path = [tgt_path]
else:
assert 'image_path' in line
tgt_path = toliststr(line['image_path'])
assert "image_path" in line
tgt_path = toliststr(line["image_path"])
return tgt_path

def use_custom_prompt(self, dataset):
assert dataset is not None
if listinstr(['MMDU', 'MME-RealWorld', 'MME-RealWorld-CN'], dataset):
# For Multi-Turn we don't have custom prompt
return False
if 'mmmu' in dataset.lower():
return True
return False

def build_mmmu(self, line):
question = line['question']
options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
options_prompt = 'Options:\n'
for key, item in options.items():
options_prompt += f'{key}. {item}\n'
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
prompt = ''
if hint is not None:
prompt += f'Hint: {hint}\n'
prompt += f'Question: {question}\n'
if len(options):
prompt += options_prompt
prompt += 'Please select the correct answer from the options above. \n'
return prompt

def build_prompt(self, line, dataset=None):
assert dataset is None or isinstance(dataset, str)
assert self.use_custom_prompt(dataset)
tgt_path = self.dump_image(line, dataset)
if 'mmmu' in dataset.lower():
prompt = self.build_mmmu(line)

ret = [dict(type='text', value=prompt)]
ret.extend([dict(type='image', value=s) for s in tgt_path])
return ret
def _prepare_content(
self, inputs: list[dict[str, str]], dataset: str | None = None
) -> list[dict[str, str]]:
"""
inputs list[dict[str, str]], each dict has keys: ['type', 'value']
"""
content = []
has_image = False
for s in inputs:
if s["type"] == "image":
if not has_image:
item = {
"type": "image_url",
"image_url": {
"url": encode_image(
s["value"],
max_height=self.max_height,
max_width=self.max_width,
)
},
}
has_image = True
else:
continue
elif s["type"] == "text":
prompt = s["value"]
if len(prompt) == 0:
continue
if dataset == "HallusionBench":
prompt += " Please answer yes or no directly, without any unnecessary explanation."
elif dataset == "OCRBench":
prompt = (
prompt + "\nExtract the text from the image intactly and "
+ "answer the question concisely and clearly if possible."
)

elif (
dataset == "AI2D_TEST"
or dataset == "MMStar"
or dataset == "MMBench_TEST_EN_V11"
or dataset == "MMVet"
):
prompt = prompt.replace(
"Please select the correct answer from the options above. \n",
"Please select the correct option from the above choices based on the "
+ "input image and question. The final output should only be one option, such as 'A'",
)
elif dataset == "MMBench_TEST_CN_V11":
prompt = prompt.replace(
"Please select the correct answer from the options above. \n",
"请根据输入图像和问题从上述选项中选择正确选项,最终的输出只有一个选项,例如'A'",
)
item = {"type": "text", "text": prompt}
else:
raise ValueError(f"Invalid message type: {s['type']}, {s}")
content.append(item)

return content

def generate_inner(self, inputs, **kwargs) -> str:
default_kwargs = self.default_kwargs
default_kwargs.update(kwargs)

messages = []
messages.append(
{
"role": "user",
"content": self._prepare_content(
inputs, dataset=kwargs.get("dataset", None)
),
}
)

payload = dict(model=self.model, messages=messages, **default_kwargs)

response = requests.post(
self.api_base, headers=self.headers, data=json.dumps(payload)
)
ret_code = response.status_code
ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code

answer = self.fail_msg
try:
resp_struct = json.loads(response.text)
answer = resp_struct["choices"][0]["message"]["content"].strip()
return ret_code, answer, response
except Exception as err:
import traceback

traceback.print_exc()
if self.verbose:
self.logger.error(f"{type(err)}: {err}")
self.logger.error(f"The input messages are {inputs}.")
return -1, "", ""
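
End-to-end, the updated TeleMMAPI can be smoke-tested roughly as below. This is a hedged sketch, not part of the commit: it assumes a valid SiliconFlow_API_KEY is exported, network access to the SiliconFlow endpoint, and a local demo.jpg; the message dicts follow the ['type', 'value'] convention used throughout this file.

import os
from vlmeval.api.siliconflow import TeleMMAPI

os.environ.setdefault("SiliconFlow_API_KEY", "<your-key>")    # placeholder key
model = TeleMMAPI(model="TeleAI/TeleMM", temperature=0, retry=10)

msgs = [
    dict(type="image", value="demo.jpg"),                     # placeholder image path
    dict(type="text", value="Describe this image briefly."),
]
ret_code, answer, response = model.generate_inner(msgs, dataset="MMVet")
print(ret_code, answer)
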
2 changes: 1 addition & 1 deletion vlmeval/config.py
@@ -104,7 +104,7 @@
"JTVL": partial(JTVLChatAPI, model='jt-vl-chat', temperature=0, retry=10),
"Taiyi": partial(TaiyiAPI, model='taiyi', temperature=0, retry=10),
# TeleMM
'TeleMM': partial(TeleMMAPI, model='TeleAI/TeleMM', stream=False, temperature=0.7, top_p=0.95, top_k=50, frequency_penalty=0, n=1, max_tokens=300, retry=10)
'TeleMM': partial(TeleMMAPI, model='TeleAI/TeleMM', temperature=0, retry=10)
}

mmalaya_series = {
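With the config entry trimmed to temperature=0 and retry=10, the model can then be picked up through the repository's model registry. A minimal sketch, assuming (as elsewhere in VLMEvalKit) that the entries in config.py are merged into the supported_VLM mapping and that BaseAPI exposes a generate() wrapper around generate_inner; the key setup and sample inputs are placeholders:

from vlmeval.config import supported_VLM

model = supported_VLM["TeleMM"]()            # builds TeleMMAPI via the partial above
msgs = [
    dict(type="image", value="demo.jpg"),    # placeholder image path
    dict(type="text", value="What does the sign in the image say?"),
]
print(model.generate(msgs))                  # retried call through the BaseAPI wrapper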
