From bd9b9758722495c5018e9495be637dc47a7665dc Mon Sep 17 00:00:00 2001
From: kim-youngjune
Date: Mon, 23 Dec 2024 19:49:46 +0900
Subject: [PATCH] fix def generate_inner_image

---
 vlmeval/config.py          | 1 -
 vlmeval/vlm/llava/llava.py | 3 +++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/vlmeval/config.py b/vlmeval/config.py
index 6c346b57b..9efd1f3a2 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -396,4 +396,3 @@
 
 for grp in model_groups:
     supported_VLM.update(grp)
-
diff --git a/vlmeval/vlm/llava/llava.py b/vlmeval/vlm/llava/llava.py
index 9e37022cf..8137f4bbc 100644
--- a/vlmeval/vlm/llava/llava.py
+++ b/vlmeval/vlm/llava/llava.py
@@ -796,6 +796,7 @@ def __init__(self, model_path="llava-hf/llava-onevision-qwen2-0.5b-ov-hf", **kwa
         self.force_sample = self.video_kwargs.get("force_sample", False)
         self.nframe = kwargs.get("nframe", 8)
         self.fps = 1
+        self.model_path = model_path
 
     def generate_inner_image(self, message, dataset=None):
         content, images = "", []
@@ -823,6 +824,8 @@ def generate_inner_image(self, message, dataset=None):
 
         inputs = self.processor(images=images, text=prompt, return_tensors="pt").to(0, torch.float16)
         output = self.model.generate(**inputs, max_new_tokens=100)
+        if self.model_path == "NCSOFT/VARCO-VISION-14B-HF":
+            return self.processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        return self.processor.decode(output[0], skip_special_tokens=True)
 
     def generate_inner_video(self, message, dataset=None):
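
Context for the change: with Hugging Face decoder-only VLMs, model.generate() returns the prompt token ids followed by the newly generated ids, so decoding output[0] directly echoes the prompt back into the answer; the two added lines slice past inputs.input_ids.shape[1] before decoding for the VARCO-VISION checkpoint. Below is a minimal standalone sketch of the same pattern; the LlavaOnevisionForConditionalGeneration class, the chat-template prompt, and the placeholder image are illustrative assumptions, not taken from the patch.

# Minimal sketch of the prompt-stripping decode the patch adds.
# Model id comes from the patch; everything else is assumed standard
# transformers usage for a LLaVA-OneVision-style checkpoint.
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

model_id = "NCSOFT/VARCO-VISION-14B-HF"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda:0"
)

image = Image.new("RGB", (336, 336), color="white")  # placeholder image for illustration
conversation = [{"role": "user",
                 "content": [{"type": "image"},
                             {"type": "text", "text": "Describe the image."}]}]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

inputs = processor(images=[image], text=prompt, return_tensors="pt").to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=100)

# generate() returns [prompt ids | new ids]; slicing past the input length
# keeps only the model's answer instead of prompt + answer.
answer = processor.decode(output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
print(answer)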