Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions ovis/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,24 +91,47 @@ def args2dict(args):
elif len(module_name_lr) == 1:
module_name = module_name_lr[0]
else:
raise ValueError
match module_name:
case 'all':
module = model
case 'llm':
module = model.llm
case 'visual_tokenizer':
module = model.visual_tokenizer
case 'visual_tokenizer.head':
module = model.visual_tokenizer.head
case 'visual_tokenizer.vit':
module = model.visual_tokenizer.vit
case 'visual_tokenizer.vit.last_block':
module = model.visual_tokenizer._get_last_block()
case 'vte':
module = model.vte
case _:
raise ValueError(f'Invalid train module name: {module_name}')
raise ValueError(f"Invalid module format: {module_name_lr}")

# 解析模块名称,支持指定LLM的特定层
module_parts = module_name.split('.')
current_module = model

try:
# 逐层查找模块
for part in module_parts:
# 处理数字索引(如layers.31)
if part.isdigit():
part = int(part)
current_module = current_module[part]
else:
current_module = getattr(current_module, part)

module = current_module
rank0_print(f"Selected module: {module_name} (found)")

except AttributeError as e:
# 保留原有匹配模式作为备份
match module_name:
case 'all':
module = model
case 'llm':
module = model.llm
print(module)
case 'visual_tokenizer':
module = model.visual_tokenizer
case 'visual_tokenizer.head':
module = model.visual_tokenizer.head
case 'visual_tokenizer.vit':
module = model.visual_tokenizer.vit
case 'visual_tokenizer.vit.last_block':
module = model.visual_tokenizer._get_last_block()
case 'vte':
module = model.vte
case _:
raise ValueError(f'Invalid train module name: {module_name}, error: {str(e)}')

# 启用该模块的梯度计算
module.requires_grad_(True)
parameters.append({'params': module.parameters(), 'lr': module_lr})

Expand Down Expand Up @@ -152,8 +175,7 @@ def args2dict(args):
trainer = Trainer(
model=model,
args=training_args,
callbacks=[MonitorCallback],
**data_module
callbacks=[MonitorCallback],** data_module
)
rankN_print(BEGIN_LINE)
rankN_print(f'model_accepts_loss_kwargs: {trainer.model_accepts_loss_kwargs}')
Expand Down