Commit: Support LLaVANext

kennymckormick committed Mar 21, 2024
1 parent bbf8a3e commit 1cafa6a

Showing 2 changed files with 86 additions and 0 deletions.
7 changes: 7 additions & 0 deletions vlmeval/smp/misc.py
@@ -181,3 +181,10 @@ def pip_install_robust(package):
subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
retry -= 1
return False


def version_cmp(v1, v2, op='eq'):
from packaging import version
import operator
op_func = getattr(operator, op)
return op_func(version.parse(v1), version.parse(v2))
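A minimal usage sketch of the helper above (values are illustrative): `version_cmp` maps the `op` string ('eq', 'ge', 'lt', ...) onto the matching function in the standard `operator` module and compares `packaging` version objects, so multi-digit components compare numerically rather than lexically. Assuming the package is importable as vlmeval:

    from vlmeval.smp.misc import version_cmp

    print(version_cmp('4.39.0', '4.39', 'ge'))    # True: trailing zero components are normalized
    print(version_cmp('4.38.2', '4.39.0', 'ge'))  # False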
79 changes: 79 additions & 0 deletions vlmeval/vlm/llava.py
@@ -123,3 +123,82 @@ def generate(self, image_path, prompt, dataset=None):

output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
return output


class LLaVA_Next(CustomPrompt):

def __init__(self, model_pth='llava-hf/llava-v1.6-vicuna-7b-hf', **kwargs):
import transformers
assert version_cmp(transformers.__version__, '4.39.0', 'ge')
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
self.model_pth = model_pth
self.processor = LlavaNextProcessor.from_pretrained(self.model_pth)
model = LlavaNextForConditionalGeneration.from_pretrained(
self.model_pth, torch_dtype=torch.float16, low_cpu_mem_usage=True)
model = model.eval()
self.model = model.cuda()
kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=512, top_p=None, num_beams=1)
kwargs_default.update(kwargs)
self.kwargs = kwargs_default
warnings.warn(f'Following kwargs received: {self.kwargs}, will use as generation config. ')

def apply_prompt_template(self, prompt):
model_pth = self.model_pth.lower()
if 'mistral' in model_pth:
s = f'[INST] <image>\n {prompt} [/INST]'
elif 'vicuna' in model_pth:
s = (
'A chat between a curious human and an artificial intelligence assistant. '
"The assistant gives helpful, detailed, and polite answers to the human's questions. "
f'USER: <image>\n{prompt} ASSISTANT:'
)
elif '34b' in model_pth:
s = (
f'<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|>'
'<|im_start|>assistant\n'
)
else:
raise NotImplementedError(f'Prompt template for {model_pth} not implemented.')
return s

def use_custom_prompt(self, dataset):
assert dataset is not None
if DATASET_TYPE(dataset) == 'multi-choice':
return True
return False

def build_prompt(self, line, dataset=None):
assert self.use_custom_prompt(dataset)
assert dataset is None or isinstance(dataset, str)
tgt_path = self.dump_image(line, dataset)

question = line['question']
hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
if hint is not None:
question = hint + '\n' + question

options = {
cand: line[cand]
for cand in string.ascii_uppercase
if cand in line and not pd.isna(line[cand])
}
for key, item in options.items():
question += f'\n{key}. {item}'
prompt = question

if len(options):
prompt += (
'\n请直接回答选项字母。' if cn_string(prompt) else
"\nAnswer with the option's letter from the given choices directly."
)
else:
prompt += '\n请直接回答问题。' if cn_string(prompt) else '\nAnswer the question directly.'
return {'image': tgt_path, 'text': prompt}

def generate(self, image_path, prompt, dataset=None):
image = Image.open(image_path)
prompt_wtmpl = self.apply_prompt_template(prompt)
inputs = self.processor(prompt_wtmpl, image, return_tensors='pt').to('cuda')
output = self.model.generate(**inputs, **self.kwargs)
        answer = self.processor.decode(output[0], skip_special_tokens=True)
        return answer
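
A minimal usage sketch for the new class (the image path and question are hypothetical, and the surrounding llava.py module is assumed to already import torch, warnings, pandas, string, and PIL.Image as the existing LLaVA wrapper does):

    from vlmeval.vlm.llava import LLaVA_Next

    model = LLaVA_Next('llava-hf/llava-v1.6-vicuna-7b-hf')  # downloads weights on first use, runs on CUDA
    answer = model.generate('demo.jpg', 'What is shown in this image?')
    print(answer)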
