diff --git a/vlmeval/smp/misc.py b/vlmeval/smp/misc.py
index 81b1ae5c4..d0d9e909f 100644
--- a/vlmeval/smp/misc.py
+++ b/vlmeval/smp/misc.py
@@ -277,4 +277,15 @@ def get_gpu_memory():
 
 def auto_split_flag():
     flag = os.environ.get('AUTO_SPLIT', '0')
-    return flag == '1'
+    if flag == '1':
+        return True
+    _, world_size = get_rank_and_world_size()
+    try:
+        import torch
+        device_count = torch.cuda.device_count()
+        if device_count > world_size and device_count % world_size == 0:
+            return True
+        else:
+            return False
+    except:
+        return False
diff --git a/vlmeval/vlm/internvl/utils.py b/vlmeval/vlm/internvl/utils.py
index a4e16c8d8..7e3a917fd 100644
--- a/vlmeval/vlm/internvl/utils.py
+++ b/vlmeval/vlm/internvl/utils.py
@@ -119,7 +119,7 @@ def get_local_rank_and_local_world_size():
 
 
 def split_model(model_path):
-    num_gpus_per_node = 8
+    num_gpus_per_node = torch.cuda.device_count()
     rank, world_size = get_rank_and_world_size()
     try:
         local_rank, local_world_size = get_local_rank_and_local_world_size()
@@ -130,7 +130,7 @@ def split_model(model_path):
         gpus_per_process = int(os.environ['GPUS_PER_PROCESS'])
     else:
         gpus_per_process = 8  # default to use 8 GPUs for one model
-
+    gpus_per_process = min(gpus_per_process, num_gpus_per_node // local_world_size)
     start_gpu = local_rank * gpus_per_process
     end_gpu = start_gpu + gpus_per_process
@@ -159,6 +159,7 @@
     device_map['language_model.model.embed_tokens'] = visible_devices[0]
     device_map['language_model.output'] = visible_devices[0]
     device_map['language_model.model.norm'] = visible_devices[0]
+    device_map['language_model.model.rotary_emb'] = visible_devices[0]
     device_map['language_model.lm_head'] = visible_devices[0]
     device_map[f'language_model.model.layers.{num_layers - 1}'] = visible_devices[0]
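
Illustrative note (not part of the patch): when AUTO_SPLIT is not set to '1', the revised auto_split_flag falls back to a heuristic that enables splitting only when the visible GPUs outnumber the launched ranks and divide evenly among them. The minimal sketch below isolates that decision; should_auto_split is a hypothetical name, and device_count / world_size are passed in explicitly so the logic can be checked without GPUs or torch.distributed.

# Hypothetical standalone sketch of the decision added above; not part of the patch.
def should_auto_split(flag: str, device_count: int, world_size: int) -> bool:
    # Explicit opt-in via AUTO_SPLIT=1 always enables splitting.
    if flag == '1':
        return True
    # Otherwise split only when GPUs outnumber ranks and divide evenly among them.
    return device_count > world_size and device_count % world_size == 0

if __name__ == '__main__':
    assert should_auto_split('0', device_count=8, world_size=2)      # 4 GPUs per rank
    assert not should_auto_split('0', device_count=8, world_size=8)  # 1 GPU per rank
    assert not should_auto_split('0', device_count=6, world_size=4)  # uneven split
    assert should_auto_split('1', device_count=1, world_size=1)      # explicit opt-in

The same idea drives the split_model change: gpus_per_process is clamped to num_gpus_per_node // local_world_size, so a process never claims more than its even share of the node's GPUs.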