From 0a54bb8fa20941c2b6cc036fd79011c86b277ee1 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 28 Dec 2024 10:14:17 +0800 Subject: [PATCH] update --- swift/llm/model/__init__.py | 2 +- swift/llm/model/constant.py | 4 +- swift/llm/model/model/__init__.py | 2 +- swift/llm/model/model/reward_model.py | 33 ++++++++++++++++ swift/llm/model/register.py | 56 --------------------------- swift/llm/train/rlhf.py | 11 ------ 6 files changed, 38 insertions(+), 70 deletions(-) create mode 100644 swift/llm/model/model/reward_model.py diff --git a/swift/llm/model/__init__.py b/swift/llm/model/__init__.py index 754d71520..939db750a 100644 --- a/swift/llm/model/__init__.py +++ b/swift/llm/model/__init__.py @@ -4,6 +4,6 @@ from .model_arch import MODEL_ARCH_MAPPING, ModelArch, ModelKeys, MultiModelKeys, get_model_arch, register_model_arch from .register import (MODEL_MAPPING, Model, ModelGroup, ModelMeta, fix_do_sample_warning, get_default_device_map, get_default_torch_dtype, get_model_info_meta, get_model_name, get_model_tokenizer, - get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, get_model_with_value_head, + get_model_tokenizer_multimodal, get_model_tokenizer_with_flash_attn, load_by_unsloth, register_model) from .utils import HfConfigFactory, ModelInfo, git_clone_github, safe_snapshot_download diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py index a87f901c7..82a89bd03 100644 --- a/swift/llm/model/constant.py +++ b/swift/llm/model/constant.py @@ -93,9 +93,11 @@ class LLMModelType: mamba = 'mamba' polylm = 'polylm' aya = 'aya' - + # bert modern_bert = 'modern_bert' bert = 'bert' + # reward model + reward_model = 'reward_model' class MLLMModelType: diff --git a/swift/llm/model/model/__init__.py b/swift/llm/model/model/__init__.py index a972ec64e..82ebf432f 100644 --- a/swift/llm/model/model/__init__.py +++ b/swift/llm/model/model/__init__.py @@ -1,2 +1,2 @@ from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft, - minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi) + minicpm, mistral, mllm, mplug, openbuddy, qwen, telechat, yi, reward_model) diff --git a/swift/llm/model/model/reward_model.py b/swift/llm/model/model/reward_model.py new file mode 100644 index 000000000..63b0bb0c9 --- /dev/null +++ b/swift/llm/model/model/reward_model.py @@ -0,0 +1,33 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+from transformers import AutoConfig +from transformers import AutoModel +from swift.utils import get_logger +from ..constant import LLMModelType +from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_from_local, register_model + +logger = get_logger() + + +def get_model_tokenizer_reward_model(model_dir, *args, **kwargs): + model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) + if 'AutoModel' in (getattr(model_config, 'auto_map', None) or {}): + kwargs['automodel_class'] = AutoModel + return get_model_tokenizer_from_local(model_dir, *args, **kwargs) + + +register_model( + ModelMeta( + LLMModelType.reward_model, [ + ModelGroup([ + Model('Qwen/Qwen2.5-Math-RM-72B', 'Qwen/Qwen2.5-Math-RM-72B'), + Model('Qwen/Qwen2-Math-RM-72B', 'Qwen/Qwen2-Math-RM-72B'), + ]), + ModelGroup([ + Model('Shanghai_AI_Laboratory/internlm2-1_8b-reward', 'internlm/internlm2-1_8b-reward'), + Model('Shanghai_AI_Laboratory/internlm2-7b-reward', 'internlm/internlm2-7b-reward'), + Model('Shanghai_AI_Laboratory/internlm2-20b-reward', 'internlm/internlm2-20b-reward'), + ]), + ], + None, + get_model_tokenizer_reward_model, + tags=['reward_model'])) diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py index a98406eb8..8cf83a554 100644 --- a/swift/llm/model/register.py +++ b/swift/llm/model/register.py @@ -196,62 +196,6 @@ def get_model_tokenizer_from_local(model_dir: str, return model, tokenizer -def get_model_with_value_head(model) -> 'AutoModelForCausalLMWithValueHead': - from trl import AutoModelForCausalLMWithValueHead - lm_head_namings = ['lm_head', 'embed_out'] - if not any(hasattr(model, attribute) for attribute in lm_head_namings): - setattr(model, 'lm_head', None) # avoid ValueError - - model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - - def patch_valuehead_model(model): - attr_list = [ - 'get_input_embeddings', 'vis_processor', 'extract_feature', 'get_rope_index', 'model', 'vision_tower', - 'img2emb', '_encode_image', '_merge_input_ids_with_image_features', 'prepare_inputs_embeds', - 'build_conversation_input_ids', 'config', 'get_slice_image_placeholder', 'transform', 'get_vllm_embedding', - 'forward_image', 'dtype', 'base_model_prefix', 'device', 'visual' - ] - for attr in attr_list: - if hasattr(model.pretrained_model, attr) and not hasattr(model, attr): - setattr(model, attr, getattr(model.pretrained_model, attr)) - - # PPO compatible - if not hasattr(model, 'score'): - setattr(model, 'score', model.v_head) - if model.base_model_prefix == '' and hasattr(model.pretrained_model, 'language_model'): - model.base_model_prefix = model.pretrained_model.language_model.base_model_prefix - - base_model_prefix = model.pretrained_model.base_model_prefix - if hasattr(model.pretrained_model, base_model_prefix): - setattr(model, base_model_prefix, getattr(model.pretrained_model, base_model_prefix)) - - patch_valuehead_model(model) - - # try to load local vhead weights - vhead_params = None - try: - from safetensors import safe_open - vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.safetensors') - with safe_open(vhead_file, framework='pt', device='cpu') as f: - vhead_params = {key: f.get_tensor(key) for key in f.keys()} - except Exception: - pass - - try: - vhead_file = os.path.join(model.pretrained_model.model_dir, 'value_head.bin') - vhead_params = torch.load(vhead_file, map_location='cpu') - except Exception: - pass - - if vhead_params is not None: - model.load_state_dict(vhead_params, strict=False) - logger.info(f'Loading 
value head weights from {vhead_file}') - else: - logger.info('The local value head weight file was not detected.' - 'Ignore it if this is during the reward modeling phase,') - return model - - def get_model_tokenizer_with_flash_attn(model_dir: str, model_info: ModelInfo, model_kwargs: Dict[str, Any], diff --git a/swift/llm/train/rlhf.py b/swift/llm/train/rlhf.py index 906a3e116..3ec858b1b 100644 --- a/swift/llm/train/rlhf.py +++ b/swift/llm/train/rlhf.py @@ -32,17 +32,6 @@ def _prepare_template(self) -> None: # Avoid padding labels during the model's forward pass in multimodal models. self.template.loss_scale = 'last_round' - @classmethod - def prepare_model(cls, args, model, *_args, **kwargs): - model = super().prepare_model(args, model, *_args, **kwargs) - if args.rlhf_type == 'rm': - from trl import AutoModelForCausalLMWithValueHead - lm_head_namings = ['lm_head', 'embed_out'] - if not any(hasattr(model, attribute) for attribute in lm_head_namings): - model.lm_head = None # avoid error - model = AutoModelForCausalLMWithValueHead.from_pretrained(model) - patch_getattr(AutoModelForCausalLMWithValueHead, 'pretrained_model') - return model def _get_dataset(self): args = self.args
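
Usage note (not part of the patch): a minimal sketch of how a checkpoint registered under the new reward_model type could be loaded through swift's public API after this change. It assumes `get_model_tokenizer` is importable from `swift.llm` and accepts a `model_type` keyword; both are assumptions, and the explicit `model_type` may be unnecessary if the id is already matched by the registration in reward_model.py.

# Minimal sketch, assuming ms-swift with this patch applied and that
# get_model_tokenizer is exported from swift.llm with a `model_type` keyword.
from swift.llm import get_model_tokenizer

# One of the ModelScope ids registered in the new reward_model.py.
model, tokenizer = get_model_tokenizer(
    'Shanghai_AI_Laboratory/internlm2-1_8b-reward',
    model_type='reward_model',
)

# Because the checkpoint's config exposes an `auto_map` entry for AutoModel,
# get_model_tokenizer_reward_model routes loading through AutoModel, so `model`
# is the repo-defined reward model rather than a causal LM with a value head.
print(type(model).__name__)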