35 changes: 29 additions & 6 deletions chatlearn/algorithm/grpo.py
@@ -28,15 +28,17 @@
RewardConfig,
PolicyConfig,
RuntimeConfig,
RuntimeEnvConfig
RuntimeEnvConfig,
RolloutManagerConfig
)
from chatlearn.configs.fsdp_config import FSDPPolicyTrainerConfig, FSDPRefPolicyConfig

from chatlearn.algorithm.grpo_utils.advantage_compute import compute_grpo_adv
from chatlearn.algorithm.grpo_utils.advantage_compute import AdvantageComputer
from chatlearn.algorithm.grpo_utils.policy_trainer import PolicyTrainer
from chatlearn.algorithm.grpo_utils.vllm_policy_inference import \
VLLMPolicyInference
from chatlearn.algorithm.grpo_utils.sglang_policy_inference import SGLangPolicyInference, AsyncSGLangPolicyInference
from chatlearn.algorithm.grpo_utils.rollout_manager import RolloutManager
from chatlearn.data.data import read_data_path_list
from chatlearn.models.reward.rule_reward import RuleReward
from chatlearn.runtime.environment import Environment
@@ -65,6 +67,9 @@ class GrpoModelConfig(BaseConfig):
reward: RewardConfig = field(
default_factory=RewardConfig, metadata={"help": "Reward config."}
)
rollout_manager: RolloutManagerConfig = field(
default_factory=RolloutManagerConfig, metadata={"help": "Rollout manager config. Only used when partial_rollout is enabled."}
)
ref_policy: Any = field(
default=None,
metadata={
@@ -186,6 +191,7 @@ def __init__(
reward: RuleReward,
ref_policy: PolicyTrainer,
policy_trainer: PolicyTrainer,
rollout_manager: RolloutManager = None
):
def env_compute_flow(batch):
policy_out = policy.forward_step(batch)
@@ -194,6 +200,15 @@ def env_compute_flow(batch):
reward_out = reward.forward_step(ref_logprobs_out)
return reward_out

def env_compute_flow_with_partial(batch):
batch = rollout_manager.get_sample_for_rollout(batch)
batch = policy.forward_step(batch)
batch = rollout_manager.post_process_rollout_results(batch)
old_logprobs_out = policy_trainer.forward_step(batch)
ref_logprobs_out = ref_policy.forward_step(old_logprobs_out)
reward_out = reward.forward_step(ref_logprobs_out)
return reward_out

def trainer_compute_flow(batch):
policy_trainer.train_step(batch)

@@ -202,15 +217,17 @@ def evaluator_flow(batch):
reward_out = reward.eval_forward(policy_out)
return reward_out

env = Environment(env_compute_flow)
if rollout_manager is None:
env = Environment(env_compute_flow)
else:
env = Environment(env_compute_flow_with_partial)
trainer = Trainer(trainer_compute_flow)
evaluator = GRPOEvaluator(evaluator_flow)
super().__init__(
environment=env, trainer=trainer, evaluator=evaluator, name="grpo"
)
self.set_parameter_sync(policy_trainer, policy)


class GrpoAlgorithm(BaseAlgorithm):
"""GrpoAlgorithm"""

@@ -235,7 +252,13 @@ def run(self) -> None:
RolloutModule_cls = SGLangPolicyInference if self.cfg.models.policy.is_sync_mode else AsyncSGLangPolicyInference
policy = RolloutModule_cls("policy")
reward = RuleReward("reward")
engine = GRPOEngine(policy, reward, ref_policy, policy_trainer)

if self.cfg.runtime_args.partial_rollout:
rollout_manager = RolloutManager("rollout_manager")
else:
rollout_manager = None

engine = GRPOEngine(policy, reward, ref_policy, policy_trainer, rollout_manager)

# get train and evaluation data
train_data_path_list = [
@@ -250,7 +273,7 @@
# put data in engine._all_datasets
engine.set_dataset(train_data)
engine.evaluator.set_dataset(eval_data)
engine.set_replay_sample_manager(compute_grpo_adv)
engine.set_replay_sample_manager(AdvantageComputer(self.cfg.runtime_args.num_inference_per_prompt))
engine.learn()

def validate(self):
46 changes: 23 additions & 23 deletions chatlearn/algorithm/grpo_utils/advantage_compute.py
@@ -1,31 +1,31 @@
"""compute advantage for grpo"""
from collections import defaultdict
from typing import List, Dict, Any
from collections import defaultdict

import numpy as np

def compute_grpo_adv(episode_replay_buffers: List[Dict[str, Any]]):
buffers = episode_replay_buffers[-1].buffer
queryids2samples = defaultdict(list)
sample_id = 0
for s in buffers:
s['sample_id'] = sample_id
queryids2samples[hash(",".join(map(str, s["prompt_token_ids"])))].append(s)
sample_id += 1

res_buffers = []
# TODO: torch and numpy give different results; the consequences are unknown
for _, l in queryids2samples.items():
rewards = np.array([each["rule_reward"] for each in l])
mean = np.mean(rewards)
std = np.std(rewards)
class AdvantageComputer:
"""advantage computer"""
def __init__(self, num_inference_per_prompt):
self.rule_reward_buffer = defaultdict(list)
self.num_inference_per_prompt = num_inference_per_prompt

def __call__(self, episode_replay_buffers: List[Dict[str, Any]]):
buffers = episode_replay_buffers[-1].buffer
# Update buffer first
for s in buffers:
sample_id = s['prompt_uid']
self.rule_reward_buffer[sample_id].append(s["rule_reward"])

for li in l:
li["advantages"] = (li["rule_reward"] - mean) / (std + 1e-5)
res_buffers.extend(l)
# Calculate advantage for all samples
for s in buffers:
sample_id = s['prompt_uid']
avg = np.mean(self.rule_reward_buffer[sample_id])
std = np.std(self.rule_reward_buffer[sample_id])
s['advantages'] = (s["rule_reward"] - avg) / (std + 1e-5)

# Sort samples by original order in buffer
res_buffers.sort(key=lambda x: x["sample_id"])
for data in res_buffers:
data.pop("sample_id")
return res_buffers
# clean buffer
self.rule_reward_buffer = {k: v for k, v in self.rule_reward_buffer.items() if len(v) < self.num_inference_per_prompt}
self.rule_reward_buffer = defaultdict(list, self.rule_reward_buffer)
return buffers
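For reference, AdvantageComputer keeps the group-normalized GRPO advantage, but now groups responses by prompt_uid and accumulates rule rewards across episodes until num_inference_per_prompt responses have arrived for a prompt. Written out, with r_{p,i} the rule reward of the i-th buffered response to prompt p and G the number of responses buffered so far (the 1e-5 matches the epsilon in the code above):

\[
A_{p,i} \;=\; \frac{r_{p,i} - \operatorname{mean}\!\left(r_{p,1}, \ldots, r_{p,G}\right)}{\operatorname{std}\!\left(r_{p,1}, \ldots, r_{p,G}\right) + 10^{-5}}
\]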
8 changes: 4 additions & 4 deletions chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py
@@ -17,6 +17,7 @@
import os
from contextlib import nullcontext
from functools import partial
import itertools
from typing import List, Union, Dict, Any, Sequence
from collections import defaultdict
import numpy as np
@@ -47,6 +48,7 @@

import chatlearn
from chatlearn import MegatronModule
from chatlearn.utils.utils import even_slice
from chatlearn.runtime.decorator import timeit, compute_decorator, monitor_error
from chatlearn.algorithm.grpo_utils.megatron_utils import (
PolicyModel,
@@ -393,10 +395,8 @@ def forward_step(self, data: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]:
]
# Split by num_train_global_batch first
microbatch_list = []
train_global_batch_size = len(data) // self.num_train_global_batch
for train_batch_id in range(self.num_train_global_batch):
start_idx = train_batch_id * train_global_batch_size
end_idx = (train_batch_id + 1) * train_global_batch_size
slice_index = even_slice(len(data), self.num_train_global_batch)
for start_idx, end_idx in itertools.pairwise(slice_index):
microbatch_list.extend(split_microbatch(data_list=data[start_idx: end_idx], max_train_token=self.module_args.max_token_in_packing, process_group_list=process_group_list, offset=start_idx, packing=self.module_args.packing))
else:
microbatch_list = split_microbatch(data_list=data, micro_batch_size=args.micro_batch_size, packing=self.module_args.packing)
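The slicing change above replaces fixed-size slices of len(data) // num_train_global_batch (which can drop trailing samples when the division is not exact) with boundary indices from even_slice, consumed pairwise. A minimal self-contained sketch of that idea, assuming even_slice(n, k) returns k + 1 near-even boundary indices (the real helper lives in chatlearn.utils.utils and may differ in detail):

import itertools

def even_slice(n: int, k: int) -> list:
    # Assumed behavior: k + 1 boundaries splitting n items into k near-equal chunks.
    base, rem = divmod(n, k)
    sizes = [base + 1 if i < rem else base for i in range(k)]
    bounds = [0]
    for size in sizes:
        bounds.append(bounds[-1] + size)
    return bounds

data = list(range(10))
for start_idx, end_idx in itertools.pairwise(even_slice(len(data), 3)):
    print(data[start_idx:end_idx])  # -> [0, 1, 2, 3], [4, 5, 6], [7, 8, 9]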
145 changes: 145 additions & 0 deletions chatlearn/algorithm/grpo_utils/rollout_manager.py
@@ -0,0 +1,145 @@
"""Rollout Manager"""
import random
from typing import Dict, List, Any
from collections import defaultdict

import numpy as np
from transformers import AutoTokenizer

from chatlearn.data.prompt_dataset import PromptPipeline
from chatlearn.runtime.decorator import timeit, compute_decorator, monitor_error
from chatlearn import BaseModule

class RolloutManager(BaseModule):
"""Rollout Manager"""
def setup(self):
self._metric_prefix = "rollout_manager"
self.rollout_finished_no_train = defaultdict(list)
self.num_response_track = defaultdict(int)
self.rollout_not_finished = []
self.max_rollout_round = self.module_args.max_rollout_round
self.max_gen_len = self.module_args.max_gen_len
self.ratio = self.module_args.rollout_ratio
self.max_token_per_round = [int(self.max_gen_len * ratio) for ratio in self.ratio]
self.num_inference_per_prompt = self.module_args.num_inference_per_prompt
self.mini_response_per_prompt = self.module_args.mini_response_per_prompt
# Metric dict for logging from this module;
# it will be appended to self._metric_list after all metrics are logged
self.metric_dict = {}

def build_dataset(self, prompts: List[Dict], is_eval=False):
# prompts appears to be the full dataset passed in via engine.set_dataset(dataset)
# TODO: move dataset to a separate node
self.tokenizer = AutoTokenizer.from_pretrained(self.module_args.load, trust_remote_code=True)
prompts_dataset = PromptPipeline(
prompts,
sum(self.max_token_per_round),
self.tokenizer,
enable_thinking=self.module_args.get("enable_thinking", False),
)
return prompts_dataset

def initialze_data(self, data: List[Dict[str, Any]]):
for sample in data:
sample["rollout_round"] = 0
sample["max_generate_token_length"] = self.max_token_per_round[sample["rollout_round"]]
return data

@monitor_error()
@compute_decorator(trainable=False, rollout=False)
@timeit()
def get_sample_for_rollout(self, data: List[Dict[str, Any]], **kwargs): # pylint: disable=unused-argument
# Get sample_per_episode samples from prompts_dataset
# Add these samples into self.rollout_not_finished for future rollout
# Send all samples to rollout engine
data = self.initialze_data(data)
self.rollout_not_finished.extend(data)
train_batch = self.rollout_not_finished
# Record start episode id
for single_data in train_batch:
if "start_episode" not in single_data:
single_data["start_episode"] = self._episode_id
random.shuffle(train_batch)
round_track = {f"round_{i}_samples": 0 for i in range(self.max_rollout_round)}
for d in train_batch:
round_track[f"round_{d['rollout_round']}_samples"] += 1
self.metric_dict.update(round_track)
return train_batch

def is_finished(self, data_b: Dict[str, Any]):
# Determine whether this rollout is finished. A rollout is finished if:
# 1. response_token_length is less than this round's max_generate_token_length, or
# 2. the maximum rollout round has been reached

return (data_b["response_token_length"] < data_b["max_generate_token_length"]) or \
(data_b["rollout_round"] == self.max_rollout_round)

def find_index_by_uid(self, uid):
idx = next(i for i,d in enumerate(self.rollout_not_finished) if d['uid'] == uid)
return idx

def update_data(self, data: Dict[str, Any], rollout_result: Dict[str, Any]):
# Merge the buffered entry in self.rollout_not_finished with the rollout_result that has the same uid
assert data["uid"] == rollout_result["uid"]
data.update({
"str_outputs": data.get("str_outputs", "") + rollout_result["str_outputs"],
"rollout_round": rollout_result["rollout_round"],
"response_token_length": data.get("response_token_length", 0) + rollout_result["response_token_length"],
"input_ids": rollout_result["all_tokens"].tolist(),
"all_tokens": rollout_result["all_tokens"],
"max_generate_token_length": self.max_token_per_round[rollout_result["rollout_round"]] \
if rollout_result["rollout_round"] < self.max_rollout_round else 0
})
return data

def logging_generate_by_round(self, rollout_result_list: List[Dict[str, Any]]):
# Log per-round generation metrics
logging_generate = {f"round_{i}_response": [] for i in range(self.max_rollout_round)}
for data in rollout_result_list:
logging_generate[f"round_{data['rollout_round'] - 1}_response"].append(data["response_token_length"])
update_dict = {}
for key in logging_generate:
arr = np.array(logging_generate[key] or [0])
update_dict[f"{key}_mean"] = np.mean(arr)
update_dict[f"{key}_max"] = np.max(arr)
update_dict[f"{key}_min"] = np.min(arr)
self.metric_dict.update(update_dict)

@monitor_error()
@compute_decorator(trainable=False, rollout=False)
@timeit()
def post_process_rollout_results(self, rollout_result_list: List[Dict[str, Any]], **kwargs): # pylint: disable=unused-argument
self.logging_generate_by_round(rollout_result_list)
unfinished_data = []
for sample in rollout_result_list:
uid = sample["uid"]
prompt_uid = sample["prompt_uid"]
finished = self.is_finished(sample)
data_idx = self.find_index_by_uid(uid)
# Merge data from buffer and data from rollout
data_b = self.update_data(self.rollout_not_finished[data_idx], sample)
if finished:
# Finished, add data to self.rollout_finished_no_train[prompt_uid]
self.rollout_finished_no_train[prompt_uid].append(data_b)
self.num_response_track[prompt_uid] += 1
else:
# If not finished, update data in rollout_not_finished
unfinished_data.append(data_b)
# update remaining data
self.rollout_not_finished = unfinished_data
train_data = []
pop_keys = []
for key, data_list in self.rollout_finished_no_train.items():
if self.num_response_track[key] > self.mini_response_per_prompt:
train_data.extend(data_list)
pop_keys.append(key)
for key in pop_keys:
self.rollout_finished_no_train.pop(key)
if self.num_response_track[key] == self.num_inference_per_prompt:
self.num_response_track.pop(key)
random.shuffle(train_data)
total_train_token = sum(d['response_token_length'] + d['prompt_token_length'] for d in train_data)
self.metric_dict.update({'total_valid_tokens': total_train_token, 'num_train_samples': len(train_data)})
self._metric_list.append(self.metric_dict)
return train_data
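To make the multi-round budgeting concrete: each round only gets a fraction of the total generation budget, and a sample leaves the rollout loop once it either stops short of its round budget (EOS) or runs out of rounds. A minimal sketch of the logic in setup() and is_finished(); the max_gen_len and rollout_ratio values below are hypothetical, not defaults shipped by this PR:

# Hypothetical config values, for illustration only.
max_gen_len = 8192
rollout_ratio = [0.25, 0.25, 0.5]
max_rollout_round = len(rollout_ratio)

# Per-round generation budgets, mirroring RolloutManager.setup().
max_token_per_round = [int(max_gen_len * r) for r in rollout_ratio]  # [2048, 2048, 4096]

def is_finished(response_token_length, max_generate_token_length, rollout_round):
    # Finished if generation stopped before exhausting this round's budget (EOS),
    # or if the sample has already used its last round.
    return (response_token_length < max_generate_token_length) or \
           (rollout_round == max_rollout_round)

# After round 0, the rollout engine has incremented rollout_round to 1:
print(is_finished(2048, max_token_per_round[0], rollout_round=1))  # False -> roll out again next episode
print(is_finished(731, max_token_per_round[0], rollout_round=1))   # True  -> eligible for training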
20 changes: 13 additions & 7 deletions chatlearn/algorithm/grpo_utils/sglang_policy_inference.py
@@ -17,6 +17,7 @@
from typing import Any, Dict, List

import torch
import numpy as np
from transformers import AutoTokenizer

from chatlearn.configs import BaseConfig
@@ -50,18 +51,17 @@ def sglang_postprocess_func(
prompt_token_ids = input_data["input_ids"]
output_tokens = output["output_ids"]
response_token_length = output["meta_info"]["completion_tokens"]
prompt_token_length = output["meta_info"]["prompt_tokens"]
str_outputs = tokenizer.decode(output_tokens, skip_special_tokens=True)
all_tokens = torch.tensor(prompt_token_ids + output_tokens)
input_data.update(
{
"prompt_token_ids": prompt_token_ids,
"all_tokens": all_tokens,
"response_token_length": response_token_length,
"prompt_token_length": prompt_token_length,
"str_outputs": str_outputs,
}
)
if "rollout_round" in input_data:
input_data["rollout_round"] += 1
data_output.append(input_data)

print("str_outputs", data_output[0]["str_outputs"])
@@ -74,15 +74,20 @@ def metric_collect(rets, seq_length):
# collect metric
response_token_length = [ret["response_token_length"] for ret in rets]
prompt_token_length = [ret["prompt_token_length"] for ret in rets]
seq_len = [
ret["response_token_length"] + ret["prompt_token_length"] for ret in rets
]
clip_ratio = sum(l >= seq_length for l in seq_len) / len(seq_len)
seq_len = seq_length
clip_ratio = sum(
ret["response_token_length"] >= ret.get("max_generate_token_length", seq_len) for ret in rets
) / len(rets)
response_token_length.sort()
inference_stats = {
"response_token_length": sum(response_token_length)
/ len(response_token_length),
"prompt_token_length": sum(prompt_token_length) / len(prompt_token_length),
"response_clip_ratio": clip_ratio,
"response_max": max(response_token_length),
"response_25_percentile": np.percentile(response_token_length, 25),
"response_50_percentile": np.percentile(response_token_length, 50),
"response_75_percentile": np.percentile(response_token_length, 75),
}
return inference_stats
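Pulled together, the updated metric_collect after this change roughly looks as follows; this is a sketch reconstructed from the diff (seq_length is only used as a fallback clipping threshold when a sample carries no max_generate_token_length), not the verbatim merged source:

import numpy as np

def metric_collect(rets, seq_length):
    # Collect metrics over a batch of rollout results.
    response_token_length = [ret["response_token_length"] for ret in rets]
    prompt_token_length = [ret["prompt_token_length"] for ret in rets]
    # A response counts as clipped if it used up its per-round generation budget;
    # fall back to the overall seq_length when no per-round budget is present.
    clip_ratio = sum(
        ret["response_token_length"] >= ret.get("max_generate_token_length", seq_length)
        for ret in rets
    ) / len(rets)
    response_token_length.sort()
    return {
        "response_token_length": sum(response_token_length) / len(response_token_length),
        "prompt_token_length": sum(prompt_token_length) / len(prompt_token_length),
        "response_clip_ratio": clip_ratio,
        "response_max": max(response_token_length),
        "response_25_percentile": np.percentile(response_token_length, 25),
        "response_50_percentile": np.percentile(response_token_length, 50),
        "response_75_percentile": np.percentile(response_token_length, 75),
    }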

@@ -91,6 +96,7 @@ class SGLangPolicyInference(SGLangModule):
"""sglang rollout"""

def build_dataset(self, prompts: List[Dict], is_eval=False):
# TODO: move dataset to a separate node
return build_dataset_func(self.module_args, self.tokenizer, prompts, is_eval)

@compute_decorator(trainable=False, rollout=True)