[Improvement] Better AUTO_SPLIT and model split for InternVL2
kennymckormick committed Dec 30, 2024
1 parent 14385c5 commit b13a37d
Showing 2 changed files with 15 additions and 3 deletions.
13 changes: 12 additions & 1 deletion vlmeval/smp/misc.py
@@ -277,4 +277,15 @@ def get_gpu_memory():

 def auto_split_flag():
     flag = os.environ.get('AUTO_SPLIT', '0')
-    return flag == '1'
+    if flag == '1':
+        return True
+    _, world_size = get_rank_and_world_size()
+    try:
+        import torch
+        device_count = torch.cuda.device_count()
+        if device_count > world_size and device_count % world_size == 0:
+            return True
+        else:
+            return False
+    except:
+        return False
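In effect, AUTO_SPLIT=1 still forces splitting; otherwise splitting is inferred whenever the visible GPUs outnumber the launched processes by an integer factor. A minimal sketch of that inference with hypothetical values (not part of the commit):

# Hypothetical illustration of the new auto_split_flag() heuristic:
# torchrun launches 2 processes on a node with 8 visible GPUs.
world_size = 2       # processes launched (from get_rank_and_world_size())
device_count = 8     # torch.cuda.device_count()
auto_split = device_count > world_size and device_count % world_size == 0
print(auto_split)    # True -> each process may split its model across 8 // 2 = 4 GPUs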
5 changes: 3 additions & 2 deletions vlmeval/vlm/internvl/utils.py
@@ -119,7 +119,7 @@ def get_local_rank_and_local_world_size():


 def split_model(model_path):
-    num_gpus_per_node = 8
+    num_gpus_per_node = torch.cuda.device_count()
     rank, world_size = get_rank_and_world_size()
     try:
         local_rank, local_world_size = get_local_rank_and_local_world_size()
@@ -130,7 +130,7 @@ def split_model(model_path):
         gpus_per_process = int(os.environ['GPUS_PER_PROCESS'])
     else:
         gpus_per_process = 8  # default to use 8 GPUs for one model
-
+    gpus_per_process = min(gpus_per_process, num_gpus_per_node // local_world_size)
     start_gpu = local_rank * gpus_per_process
     end_gpu = start_gpu + gpus_per_process

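For illustration, here is how the new clamp plays out on a hypothetical node (the values below are assumptions, not from the commit):

# Hypothetical node: 8 GPUs shared by 4 model processes.
num_gpus_per_node = 8    # torch.cuda.device_count()
local_world_size = 4     # model processes running on this node
gpus_per_process = 8     # default when GPUS_PER_PROCESS is unset
gpus_per_process = min(gpus_per_process, num_gpus_per_node // local_world_size)
print(gpus_per_process)  # 2 -> each process is confined to 2 GPUs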
@@ -159,6 +159,7 @@ def split_model(model_path):
     device_map['language_model.model.embed_tokens'] = visible_devices[0]
     device_map['language_model.output'] = visible_devices[0]
     device_map['language_model.model.norm'] = visible_devices[0]
+    device_map['language_model.model.rotary_emb'] = visible_devices[0]
     device_map['language_model.lm_head'] = visible_devices[0]
     device_map[f'language_model.model.layers.{num_layers - 1}'] = visible_devices[0]

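The rotary embedding module is now pinned to the first visible device, alongside the embedding, norm, and output layers. A minimal sketch of how such a device_map is typically consumed, assuming split_model returns the mapping built above and an InternVL2-style HuggingFace checkpoint (the model name and loading call are illustrative assumptions, not code from this commit):

# Illustrative usage only; the checkpoint name and loading call are assumptions.
from transformers import AutoModel

device_map = split_model('OpenGVLab/InternVL2-8B')  # {submodule name: GPU index}
model = AutoModel.from_pretrained(
    'OpenGVLab/InternVL2-8B',
    device_map=device_map,      # shard submodules onto the assigned GPUs
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)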