Commit

[Fix] Refine Qwen2-VL device assignment
kennymckormick committed Dec 30, 2024
1 parent c08ab64 commit 9c85881
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions vlmeval/vlm/qwen2_vl/model.py
@@ -108,22 +108,22 @@ def __init__(
         assert max_gpu_mem > 0
 
         # If only one process and GPU memory is less than 40GB
-        if auto_split_flag():
+        if '72b' in self.model_path.lower():
+            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
+            )
+            self.model.eval()
+        elif auto_split_flag():
             assert world_size == 1, 'Only support world_size == 1 when AUTO_SPLIT is set for non-72B Qwen2-VL'
             # Will Use All GPUs to run one model
             self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                 model_path, torch_dtype='auto', device_map='auto', attn_implementation='flash_attention_2'
             )
-        elif '72b' not in self.model_path.lower():
+        else:
             self.model = Qwen2VLForConditionalGeneration.from_pretrained(
                 model_path, torch_dtype='auto', device_map='cpu', attn_implementation='flash_attention_2'
             )
             self.model.cuda().eval()
-        else:
-            self.model = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_path, torch_dtype='auto', device_map=split_model(), attn_implementation='flash_attention_2'
-            )
-            self.model.eval()
 
         torch.cuda.empty_cache()
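
For context, here is a minimal sketch of the two helpers the branch above depends on. This is not the repository's actual code: in VLMEvalKit, auto_split_flag and split_model are defined elsewhere in vlmeval, and the AUTO_SPLIT environment variable plus the layer-count arithmetic below are assumptions for illustration only.

import math
import os

import torch


def auto_split_flag():
    # Assumed behaviour: the flag is driven by an AUTO_SPLIT environment
    # variable, matching the assert message in the hunk above.
    return os.environ.get('AUTO_SPLIT', '0') == '1'


def split_model(num_layers=80):
    # Assumed behaviour: build a device_map that spreads the 72B checkpoint's
    # decoder layers evenly over all visible GPUs, pinning the vision tower,
    # embeddings, final norm, and LM head to GPU 0.
    num_gpus = max(torch.cuda.device_count(), 1)
    per_gpu = math.ceil(num_layers / num_gpus)
    device_map = {'visual': 0, 'model.embed_tokens': 0, 'model.norm': 0, 'lm_head': 0}
    for layer in range(num_layers):
        device_map[f'model.layers.{layer}'] = layer // per_gpu
    return device_map

With helpers along these lines, the refined control flow reads: a 72B checkpoint always gets the hand-built split map, AUTO_SPLIT routes non-72B models through device_map='auto' on a single process, and everything else is loaded on CPU and then moved to a single GPU.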

