diff --git a/vlmeval/vlm/llama_vision.py b/vlmeval/vlm/llama_vision.py
index 9abbcfd9..649f963f 100644
--- a/vlmeval/vlm/llama_vision.py
+++ b/vlmeval/vlm/llama_vision.py
@@ -79,7 +79,7 @@ def __init__(self, model_path='meta-llama/Llama-3.2-11B-Vision-Instruct', **kwar
         self.device = 'cuda'
         self.processor = AutoProcessor.from_pretrained(model_path)
-        if 'Instruct' in model_path:
+        if 'Instruct' in model_path or 'cot' in model_path or 'CoT' in model_path:
             kwargs_default = dict(do_sample=True, temperature=0.6, top_p=0.9)
         else:
             kwargs_default = dict(do_sample=False, max_new_tokens=512, temperature=0.0, top_p=None, num_beams=1)
@@ -200,5 +200,7 @@ def generate_inner(self, message, dataset=None):
             self.kwargs['max_new_tokens'] = 128
         else:
             self.kwargs['max_new_tokens'] = 512
+        if "cot" in self.model_name or "CoT" in self.model_name:
+            self.kwargs['max_new_tokens'] = 2048
         output = self.model.generate(**inputs, **self.kwargs)
         return self.processor.decode(output[0][inputs['input_ids'].shape[1]:]).replace('<|eot_id|>', '')
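
Note for reviewers: the two hunks interact, since the per-dataset token budget set in generate_inner is overridden by the new CoT branch. Below is a minimal standalone sketch of the combined effect; it is an illustration, not code from the patch. The select_generation_kwargs name and the model_path/model_name/dataset_max parameters are assumptions standing in for the instance attributes and the dataset-dependent branch in the diff above.

def select_generation_kwargs(model_path: str, model_name: str, dataset_max: int) -> dict:
    # Hunk 1: Instruct and CoT checkpoints use sampling; plain base
    # checkpoints decode greedily with a single beam.
    if 'Instruct' in model_path or 'cot' in model_path or 'CoT' in model_path:
        kwargs = dict(do_sample=True, temperature=0.6, top_p=0.9)
    else:
        kwargs = dict(do_sample=False, max_new_tokens=512, temperature=0.0,
                      top_p=None, num_beams=1)
    # generate_inner first picks a dataset-dependent budget (128 or 512)...
    kwargs['max_new_tokens'] = dataset_max
    # ...which hunk 2 then overrides for CoT models, whose chain-of-thought
    # traces are far longer than plain answers.
    if 'cot' in model_name or 'CoT' in model_name:
        kwargs['max_new_tokens'] = 2048
    return kwargs

# Example: a hypothetical CoT fine-tune gets the long budget regardless of dataset.
assert select_generation_kwargs('my-org/Llama-3.2-11B-Vision-cot',
                                'Llama-3.2-11B-Vision-cot',
                                dataset_max=512)['max_new_tokens'] == 2048

One design note: checking both 'cot' and 'CoT' literals, as the patch does, misses other casings such as 'COT'; a single case-insensitive test like 'cot' in model_name.lower() would cover them, at the cost of also matching unrelated names that merely contain the substring.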