Skip to content

Commit

Permalink
Update llama_vision.py (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
XuGW-Kevin authored Dec 31, 2024
1 parent ea28715 commit 6e1a59a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion vlmeval/vlm/llama_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __init__(self, model_path='meta-llama/Llama-3.2-11B-Vision-Instruct', **kwar

self.device = 'cuda'
self.processor = AutoProcessor.from_pretrained(model_path)
- if 'Instruct' in model_path:
+ if 'Instruct' in model_path or 'cot' in model_path or 'CoT' in model_path:
kwargs_default = dict(do_sample=True, temperature=0.6, top_p=0.9)
else:
kwargs_default = dict(do_sample=False, max_new_tokens=512, temperature=0.0, top_p=None, num_beams=1)
Expand Down Expand Up @@ -200,5 +200,7 @@ def generate_inner(self, message, dataset=None):
self.kwargs['max_new_tokens'] = 128
else:
self.kwargs['max_new_tokens'] = 512
+ if "cot" in self.model_name or "CoT" in self.model_name:
+ self.kwargs['max_new_tokens'] = 2048
output = self.model.generate(**inputs, **self.kwargs)
return self.processor.decode(output[0][inputs['input_ids'].shape[1]:]).replace('<|eot_id|>', '')

0 comments on commit 6e1a59a

Please sign in to comment.