From b13a37db5e3718779b8e8b40b182e96c9093f3a9 Mon Sep 17 00:00:00 2001
From: kennymckormick
Date: Mon, 30 Dec 2024 04:05:53 +0000
Subject: [PATCH] [Improvement] Better `AUTO_SPLIT` and model split for InternVL2

---
 vlmeval/smp/misc.py           | 13 ++++++++++++-
 vlmeval/vlm/internvl/utils.py |  5 +++--
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/vlmeval/smp/misc.py b/vlmeval/smp/misc.py
index 81b1ae5c4..d0d9e909f 100644
--- a/vlmeval/smp/misc.py
+++ b/vlmeval/smp/misc.py
@@ -277,4 +277,15 @@ def get_gpu_memory():
 
 def auto_split_flag():
     flag = os.environ.get('AUTO_SPLIT', '0')
-    return flag == '1'
+    if flag == '1':
+        return True
+    _, world_size = get_rank_and_world_size()
+    try:
+        import torch
+        device_count = torch.cuda.device_count()
+        if device_count > world_size and device_count % world_size == 0:
+            return True
+        else:
+            return False
+    except:
+        return False
diff --git a/vlmeval/vlm/internvl/utils.py b/vlmeval/vlm/internvl/utils.py
index a4e16c8d8..7e3a917fd 100644
--- a/vlmeval/vlm/internvl/utils.py
+++ b/vlmeval/vlm/internvl/utils.py
@@ -119,7 +119,7 @@ def get_local_rank_and_local_world_size():
 
 
 def split_model(model_path):
-    num_gpus_per_node = 8
+    num_gpus_per_node = torch.cuda.device_count()
     rank, world_size = get_rank_and_world_size()
     try:
         local_rank, local_world_size = get_local_rank_and_local_world_size()
@@ -130,7 +130,7 @@ def split_model(model_path):
         gpus_per_process = int(os.environ['GPUS_PER_PROCESS'])
     else:
         gpus_per_process = 8 # default to use 8 GPUs for one model
-
+    gpus_per_process = min(gpus_per_process, num_gpus_per_node // local_world_size)
     start_gpu = local_rank * gpus_per_process
     end_gpu = start_gpu + gpus_per_process
 
@@ -159,6 +159,7 @@ def split_model(model_path):
     device_map['language_model.model.embed_tokens'] = visible_devices[0]
     device_map['language_model.output'] = visible_devices[0]
     device_map['language_model.model.norm'] = visible_devices[0]
+    device_map['language_model.model.rotary_emb'] = visible_devices[0]
     device_map['language_model.lm_head'] = visible_devices[0]
     device_map[f'language_model.model.layers.{num_layers - 1}'] = visible_devices[0]
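
Note (not part of the patch): a minimal sketch of how the updated splitting logic plays out. `demo_gpu_assignment`, its arguments, and the single-node assumption (local_world_size == world_size) are hypothetical names used only for illustration; the real code lives in `auto_split_flag` and `split_model` above.

# Sketch of the patched behaviour, under a single-node assumption.
def demo_gpu_assignment(device_count, world_size, local_rank, gpus_per_process=8):
    # auto_split_flag(): even without AUTO_SPLIT=1, splitting is enabled when the
    # visible GPUs outnumber the ranks and divide evenly among them.
    auto_split = device_count > world_size and device_count % world_size == 0

    # split_model(): cap GPUs per process by what the node actually offers
    # (device_count // local_world_size) instead of assuming 8 GPUs per node.
    local_world_size = world_size  # single-node assumption
    gpus = min(gpus_per_process, device_count // local_world_size)
    start_gpu = local_rank * gpus
    return auto_split, list(range(start_gpu, start_gpu + gpus))

# 8 GPUs shared by 2 ranks: auto-split is on, rank 0 gets GPUs 0-3, rank 1 gets 4-7.
print(demo_gpu_assignment(device_count=8, world_size=2, local_rank=0))  # (True, [0, 1, 2, 3])
print(demo_gpu_assignment(device_count=8, world_size=2, local_rank=1))  # (True, [4, 5, 6, 7])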