Commit b988ce0

hiyouga committed Feb 3, 2024
1 parent 38e63bf commit b988ce0
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions src/llmtuner/train/ppo/utils.py
@@ -1,7 +1,9 @@
 import json
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional
 
 import torch
+from transformers.integrations import is_deepspeed_zero3_enabled
 
 from ...extras.packages import is_requests_available
 
@@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
 
 
 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    if target == "reward":  # save default head temporarily
-        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
-        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
-        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
-
-    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
-    model.v_head.load_state_dict(
-        {
-            "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
-            "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
-        }
-    )
+    if is_deepspeed_zero3_enabled():
+        import deepspeed  # type: ignore
+
+        params = [model.v_head.summary.weight, model.v_head.summary.bias]
+        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+    else:
+        context_maybe_zero3 = nullcontext()
+
+    with context_maybe_zero3:
+        if target == "reward":  # save default head temporarily
+            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+
+        model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
+        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
 
 
 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
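Context for the change: under DeepSpeed ZeRO-3, every parameter is partitioned across ranks, so reading `model.v_head.state_dict()` or calling `load_state_dict()` outside a gather would operate on sharded (effectively empty) tensors. The new code wraps the value-head swap in `deepspeed.zero.GatheredParameters`, which temporarily materializes the full tensors and re-partitions any writes made on `modifier_rank` when the context exits; when ZeRO-3 is off, `nullcontext()` leaves the original behavior unchanged.

Below is a minimal sketch of the same gather-then-modify pattern on a plain `torch.nn.Linear`, assuming `transformers` is installed and, for the ZeRO-3 branch, an initialized DeepSpeed run; the helper name `swap_head_weights` is hypothetical and not part of this repository:

from contextlib import nullcontext

import torch
from transformers.integrations import is_deepspeed_zero3_enabled


def swap_head_weights(head: torch.nn.Linear, weight: torch.Tensor, bias: torch.Tensor) -> None:
    if is_deepspeed_zero3_enabled():
        import deepspeed  # type: ignore

        # Materialize the full (un-partitioned) tensors on all ranks; writes
        # made on modifier_rank are re-sharded when the context exits.
        ctx = deepspeed.zero.GatheredParameters([head.weight, head.bias], modifier_rank=0)
    else:
        ctx = nullcontext()  # no-op outside ZeRO-3

    with ctx:
        # Assign to .data so the existing Parameter objects stay in place.
        head.weight.data = weight.detach().clone().to(head.weight.device)
        head.bias.data = bias.detach().clone().to(head.bias.device)


# usage (single process, no ZeRO-3):
# head = torch.nn.Linear(8, 1)
# swap_head_weights(head, torch.zeros(1, 8), torch.zeros(1))

Assigning to `.data` rather than calling `load_state_dict` keeps the existing `Parameter` objects, along with the partitioning hooks DeepSpeed attaches to them, intact, which is presumably why the commit also moves away from `model.v_head.load_state_dict(...)`.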
