Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 28, 2024
1 parent b99eb6b commit 0a54bb8
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 70 deletions.
2 changes: 1 addition & 1 deletion swift/llm/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from .model_arch import MODEL_ARCH_MAPPING, ModelArch, ModelKeys, MultiModelKeys, get_model_arch, register_model_arch
from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map,
get_default_torch_dtype, get_model_info_meta, get_model_name, get_model_tokenizer,
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, get_model_with_value_head,
get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn,
load_by_unsloth, register_model)
from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download
4 changes: 3 additions & 1 deletion swift/llm/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ class LLMModelType:
mamba = 'mamba'
polylm = 'polylm'
aya = 'aya'

# bert
modern_bert = 'modern_bert'
bert = 'bert'
# reward model
reward_model = 'reward_model'


class MLLMModelType:
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/model/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft,
minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi)
minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi, reward_model)
33 changes: 33 additions & 0 deletions swift/llm/model/model/reward_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from transformers import AutoConfig
from transformers import AutoModel
from swift.utils import get_logger
from ..constant import LLMModelType
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model

logger = get_logger()


def get_model_tokenizer_reward_model(model_dir, *args, **kwargs):
    """Load a reward model and its tokenizer from a local checkpoint directory.

    The checkpoint's config is read first (with ``trust_remote_code=True``).
    When the config carries an ``auto_map`` that contains an ``'AutoModel'``
    entry, ``kwargs['automodel_class']`` is set to ``AutoModel`` before
    delegating, overwriting any value the caller supplied for that key.

    Args:
        model_dir: Directory of the checkpoint to load.
        *args: Forwarded unchanged to ``get_model_tokenizer_from_local``.
        **kwargs: Forwarded to ``get_model_tokenizer_from_local``;
            ``automodel_class`` may be overridden as described above.

    Returns:
        Whatever ``get_model_tokenizer_from_local`` returns.
    """
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    # `auto_map` may be missing or None on configs without remote code.
    auto_map = getattr(config, 'auto_map', None)
    if auto_map and 'AutoModel' in auto_map:
        kwargs['automodel_class'] = AutoModel
    return get_model_tokenizer_from_local(model_dir, *args, **kwargs)


# Register the generic `reward_model` model type.
# Two groups of checkpoints are listed: the Qwen math reward models and the
# InternLM2 reward models. Each `Model(...)` takes two repository ids
# (presumably the ModelScope id followed by the Hugging Face id — confirm
# against the `Model` definition in `..register`).
register_model(
ModelMeta(
LLMModelType.reward_model, [
ModelGroup([
Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'),
Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'),
]),
ModelGroup([
Model('Shanghai_AI_Laboratory/internlm2-1_8b-reward', 'internlm/internlm2-1_8b-reward'),
Model('Shanghai_AI_Laboratory/internlm2-7b-reward', 'internlm/internlm2-7b-reward'),
Model('Shanghai_AI_Laboratory/internlm2-20b-reward', 'internlm/internlm2-20b-reward'),
]),
],
# NOTE(review): the third positional argument is None — presumably "no
# default template" for this model type; confirm against ModelMeta's signature.
None,
# Loader used for every model in the groups above; it switches to AutoModel
# when the checkpoint config's `auto_map` contains an 'AutoModel' entry.
get_model_tokenizer_reward_model,
tags=['reward_model']))
56 changes: 0 additions & 56 deletions swift/llm/model/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,62 +196,6 @@ def get_model_tokenizer_from_local(model_dir: str,
return model, tokenizer


def _load_local_vhead_params(model_dir):
    """Best-effort load of value-head weights stored in ``model_dir``.

    ``value_head.safetensors`` is tried first, then ``value_head.bin``. Each
    file that exists and loads replaces the previous result, so a loadable
    ``value_head.bin`` takes precedence when both are present.

    Args:
        model_dir: Directory expected to contain the value-head file(s). A
            falsy or non-path value yields ``(None, None)``.

    Returns:
        A ``(params, path)`` tuple: the loaded state dict and the file it was
        read from, or ``(None, None)`` when nothing could be loaded.
    """
    if not model_dir or not isinstance(model_dir, (str, os.PathLike)):
        return None, None

    def _load_safetensors(path):
        from safetensors import safe_open
        with safe_open(path, framework='pt', device='cpu') as f:
            return {key: f.get_tensor(key) for key in f.keys()}

    def _load_bin(path):
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from trusted sources.
        return torch.load(path, map_location='cpu')

    candidates = (
        ('value_head.safetensors', _load_safetensors),
        ('value_head.bin', _load_bin),
    )
    vhead_params, loaded_file = None, None
    for file_name, loader in candidates:
        vhead_file = os.path.join(model_dir, file_name)
        if not os.path.isfile(vhead_file):
            continue
        try:
            params = loader(vhead_file)
        except Exception as e:  # best-effort: report the failure and keep going
            logger.warning(f'Failed to load value head weights from {vhead_file}: {e}')
            continue
        # Record the path together with the params so the caller reports the
        # file that was actually loaded.
        vhead_params, loaded_file = params, vhead_file
    return vhead_params, loaded_file


def get_model_with_value_head(model) -> 'AutoModelForCausalLMWithValueHead':
    """Wrap ``model`` in trl's ``AutoModelForCausalLMWithValueHead``.

    After wrapping, selected attributes of the underlying model are mirrored
    onto the wrapper, ``score`` is aliased to ``v_head`` when absent, and
    value-head weights found in ``model.pretrained_model.model_dir`` are
    loaded non-strictly when available.

    Args:
        model: The base model to wrap. If it exposes neither ``lm_head`` nor
            ``embed_out``, ``lm_head`` is set to ``None`` on it first.

    Returns:
        The wrapping ``AutoModelForCausalLMWithValueHead`` instance.
    """
    from trl import AutoModelForCausalLMWithValueHead
    lm_head_namings = ['lm_head', 'embed_out']
    if not any(hasattr(model, attribute) for attribute in lm_head_namings):
        model.lm_head = None  # avoid ValueError

    model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

    def patch_valuehead_model(model):
        # Mirror attributes of the wrapped model onto the wrapper so that code
        # reaching for them on the outer object keeps working.
        attr_list = [
            'get_input_embeddings', 'vis_processor', 'extract_feature', 'get_rope_index', 'model', 'vision_tower',
            'img2emb', '_encode_image', '_merge_input_ids_with_image_features', 'prepare_inputs_embeds',
            'build_conversation_input_ids', 'config', 'get_slice_image_placeholder', 'transform', 'get_vllm_embedding',
            'forward_image', 'dtype', 'base_model_prefix', 'device', 'visual'
        ]
        for attr in attr_list:
            if hasattr(model.pretrained_model, attr) and not hasattr(model, attr):
                setattr(model, attr, getattr(model.pretrained_model, attr))

        # PPO compatible
        if not hasattr(model, 'score'):
            model.score = model.v_head
        if model.base_model_prefix == '' and hasattr(model.pretrained_model, 'language_model'):
            model.base_model_prefix = model.pretrained_model.language_model.base_model_prefix

        base_model_prefix = model.pretrained_model.base_model_prefix
        if hasattr(model.pretrained_model, base_model_prefix):
            setattr(model, base_model_prefix, getattr(model.pretrained_model, base_model_prefix))

    patch_valuehead_model(model)

    # try to load local vhead weights
    model_dir = getattr(model.pretrained_model, 'model_dir', None)
    vhead_params, vhead_file = _load_local_vhead_params(model_dir)

    if vhead_params is not None:
        model.load_state_dict(vhead_params, strict=False)
        logger.info(f'Loading value head weights from {vhead_file}')
    else:
        logger.info('The local value head weight file was not detected. '
                    'Ignore it if this is during the reward modeling phase.')
    return model


def get_model_tokenizer_with_flash_attn(model_dir: str,
model_info: ModelInfo,
model_kwargs: Dict[str, Any],
Expand Down
11 changes: 0 additions & 11 deletions swift/llm/train/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,6 @@ def _prepare_template(self) -> None:
# Avoid padding labels during the model's forward pass in multimodal models.
self.template.loss_scale = 'last_round'

@classmethod
def prepare_model(cls, args, model, *_args, **kwargs):
    """Prepare ``model`` via the parent class, adding a value head for 'rm'.

    The parent's ``prepare_model`` always runs first. Only when
    ``args.rlhf_type == 'rm'`` is the result wrapped in trl's
    ``AutoModelForCausalLMWithValueHead``; ``patch_getattr`` is then applied
    to that class with ``'pretrained_model'``.

    Args:
        args: Arguments object; ``rlhf_type`` is the only attribute read here.
        model: The model to prepare.
        *_args: Forwarded unchanged to the parent implementation.
        **kwargs: Forwarded unchanged to the parent implementation.

    Returns:
        The prepared model, wrapped with a value head when training a reward
        model.
    """
    prepared = super().prepare_model(args, model, *_args, **kwargs)
    if args.rlhf_type != 'rm':
        return prepared

    from trl import AutoModelForCausalLMWithValueHead
    has_lm_head = hasattr(prepared, 'lm_head') or hasattr(prepared, 'embed_out')
    if not has_lm_head:
        prepared.lm_head = None  # avoid error
    wrapped = AutoModelForCausalLMWithValueHead.from_pretrained(prepared)
    patch_getattr(AutoModelForCausalLMWithValueHead, 'pretrained_model')
    return wrapped

def _get_dataset(self):
args = self.args
Expand Down

0 comments on commit 0a54bb8

Please sign in to comment.