
Commit 4f7051f

Partial Rollout (#350)
* Support Partial Rollout for all backends
* Support dynamic train batch size
* Using OmegaConf to parse command line inputs
1 parent 5c8c1a6 commit 4f7051f

23 files changed: +492 -192 lines

chatlearn/algorithm/grpo.py

Lines changed: 29 additions & 6 deletions
@@ -28,15 +28,17 @@
     RewardConfig,
     PolicyConfig,
     RuntimeConfig,
-    RuntimeEnvConfig
+    RuntimeEnvConfig,
+    RolloutManagerConfig
 )
 from chatlearn.configs.fsdp_config import FSDPPolicyTrainerConfig, FSDPRefPolicyConfig
 
-from chatlearn.algorithm.grpo_utils.advantage_compute import compute_grpo_adv
+from chatlearn.algorithm.grpo_utils.advantage_compute import AdvantageComputer
 from chatlearn.algorithm.grpo_utils.policy_trainer import PolicyTrainer
 from chatlearn.algorithm.grpo_utils.vllm_policy_inference import \
     VLLMPolicyInference
 from chatlearn.algorithm.grpo_utils.sglang_policy_inference import SGLangPolicyInference, AsyncSGLangPolicyInference
+from chatlearn.algorithm.grpo_utils.rollout_manager import RolloutManager
 from chatlearn.data.data import read_data_path_list
 from chatlearn.models.reward.rule_reward import RuleReward
 from chatlearn.runtime.environment import Environment
@@ -65,6 +67,9 @@ class GrpoModelConfig(BaseConfig):
     reward: RewardConfig = field(
         default_factory=RewardConfig, metadata={"help": "Reward config."}
     )
+    rollout_manager: RolloutManagerConfig = field(
+        default=RolloutManagerConfig, metadata={"help": "Rollout manager config. Only useful when partial_rollout is enabled"}
+    )
     ref_policy: Any = field(
         default=None,
         metadata={
@@ -186,6 +191,7 @@ def __init__(
         reward: RuleReward,
         ref_policy: PolicyTrainer,
         policy_trainer: PolicyTrainer,
+        rollout_manager: RolloutManager = None
     ):
         def env_compute_flow(batch):
             policy_out = policy.forward_step(batch)
@@ -194,6 +200,15 @@ def env_compute_flow(batch):
             reward_out = reward.forward_step(ref_logprobs_out)
             return reward_out
 
+        def env_compute_flow_with_partial(batch):
+            batch = rollout_manager.get_sample_for_rollout(batch)
+            batch = policy.forward_step(batch)
+            batch = rollout_manager.post_process_rollout_results(batch)
+            old_logprobs_out = policy_trainer.forward_step(batch)
+            ref_logprobs_out = ref_policy.forward_step(old_logprobs_out)
+            reward_out = reward.forward_step(ref_logprobs_out)
+            return reward_out
+
         def trainer_compute_flow(batch):
             policy_trainer.train_step(batch)
 
@@ -202,15 +217,17 @@ def evaluator_flow(batch):
             reward_out = reward.eval_forward(policy_out)
             return reward_out
 
-        env = Environment(env_compute_flow)
+        if rollout_manager is None:
+            env = Environment(env_compute_flow)
+        else:
+            env = Environment(env_compute_flow_with_partial)
         trainer = Trainer(trainer_compute_flow)
         evaluator = GRPOEvaluator(evaluator_flow)
         super().__init__(
             environment=env, trainer=trainer, evaluator=evaluator, name="grpo"
         )
         self.set_parameter_sync(policy_trainer, policy)
 
-
 class GrpoAlgorithm(BaseAlgorithm):
     """GrpoAlgorithm"""
 
@@ -235,7 +252,13 @@ def run(self) -> None:
             RolloutModule_cls = SGLangPolicyInference if self.cfg.models.policy.is_sync_mode else AsyncSGLangPolicyInference
             policy = RolloutModule_cls("policy")
         reward = RuleReward("reward")
-        engine = GRPOEngine(policy, reward, ref_policy, policy_trainer)
+
+        if self.cfg.runtime_args.partial_rollout:
+            rollout_manager = RolloutManager("rollout_manager")
+        else:
+            rollout_manager = None
+
+        engine = GRPOEngine(policy, reward, ref_policy, policy_trainer, rollout_manager)
 
         # get train and evaluation data
         train_data_path_list = [
@@ -250,7 +273,7 @@ def run(self) -> None:
         # put data in engine._all_datasets
         engine.set_dataset(train_data)
         engine.evaluator.set_dataset(eval_data)
-        engine.set_replay_sample_manager(compute_grpo_adv)
+        engine.set_replay_sample_manager(AdvantageComputer(self.cfg.runtime_args.num_inference_per_prompt))
        engine.learn()
 
     def validate(self):
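
The run() flow above reads self.cfg.runtime_args.partial_rollout, and the commit switches command-line parsing to OmegaConf. The sketch below only illustrates how OmegaConf dotlist overrides could toggle such a flag; the dataclass defaults and the entry-point name are assumptions for illustration, not part of this commit.

import sys
from dataclasses import dataclass, field
from omegaconf import OmegaConf

@dataclass
class RuntimeArgs:
    # Fields referenced by the diff; defaults here are assumed for illustration.
    partial_rollout: bool = False
    num_inference_per_prompt: int = 8

@dataclass
class Config:
    runtime_args: RuntimeArgs = field(default_factory=RuntimeArgs)

if __name__ == "__main__":
    base = OmegaConf.structured(Config)
    # e.g. python train_grpo.py runtime_args.partial_rollout=true  (hypothetical entry point)
    overrides = OmegaConf.from_dotlist(sys.argv[1:])
    cfg = OmegaConf.merge(base, overrides)
    print(cfg.runtime_args.partial_rollout)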
chatlearn/algorithm/grpo_utils/advantage_compute.py

Lines changed: 23 additions & 23 deletions
@@ -1,31 +1,31 @@
 """compute advantage for grpo"""
-from collections import defaultdict
 from typing import List, Dict, Any
+from collections import defaultdict
 
 import numpy as np
 
-def compute_grpo_adv(episode_replay_buffers: List[Dict[str, Any]]):
-    buffers = episode_replay_buffers[-1].buffer
-    queryids2samples = defaultdict(list)
-    sample_id = 0
-    for s in buffers:
-        s['sample_id'] = sample_id
-        queryids2samples[hash(",".join(map(str, s["prompt_token_ids"])))].append(s)
-        sample_id += 1
 
-    res_buffers = []
-    # TODO: torch and numpy have difference result, not knowing consequence
-    for _, l in queryids2samples.items():
-        rewards = np.array([each["rule_reward"] for each in l])
-        mean = np.mean(rewards)
-        std = np.std(rewards)
+class AdvantageComputer:
+    """advantage computer"""
+    def __init__(self, num_inference_per_prompt):
+        self.rule_reward_buffer = defaultdict(list)
+        self.num_inference_per_prompt = num_inference_per_prompt
+
+    def __call__(self, episode_replay_buffers: List[Dict[str, Any]]):
+        buffers = episode_replay_buffers[-1].buffer
+        # Update buffer first
+        for s in buffers:
+            sample_id = s['prompt_uid']
+            self.rule_reward_buffer[sample_id].append(s["rule_reward"])
 
-        for li in l:
-            li["advantages"] = (li["rule_reward"] - mean) / (std + 1e-5)
-        res_buffers.extend(l)
+        # Calculate advantage for all samples
+        for s in buffers:
+            sample_id = s['prompt_uid']
+            avg = np.mean(self.rule_reward_buffer[sample_id])
+            std = np.std(self.rule_reward_buffer[sample_id])
+            s['advantages'] = (s["rule_reward"] - avg) / (std + 1e-5)
 
-    # Sort samples by original order in buffer
-    res_buffers.sort(key=lambda x: x["sample_id"])
-    for data in res_buffers:
-        data.pop("sample_id")
-    return res_buffers
+        # clean buffer
+        self.rule_reward_buffer = {k: v for k, v in self.rule_reward_buffer.items() if len(v) < self.num_inference_per_prompt}
+        self.rule_reward_buffer = defaultdict(list, self.rule_reward_buffer)
+        return buffers
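
A toy illustration of the new AdvantageComputer, assuming only what the diff shows: rewards are grouped by prompt_uid, each sample's advantage is (reward - group mean) / (group std + 1e-5), and prompt groups that have not yet collected num_inference_per_prompt responses stay in the internal buffer across episodes. The SimpleNamespace stand-in for the replay-buffer object (which only needs a .buffer attribute) is an assumption made for this example.

from types import SimpleNamespace
from chatlearn.algorithm.grpo_utils.advantage_compute import AdvantageComputer

samples = [
    {"prompt_uid": "p0", "rule_reward": 1.0},
    {"prompt_uid": "p0", "rule_reward": 0.0},
    {"prompt_uid": "p1", "rule_reward": 0.5},
    {"prompt_uid": "p1", "rule_reward": 0.5},
]
computer = AdvantageComputer(num_inference_per_prompt=2)
# __call__ reads episode_replay_buffers[-1].buffer, so wrap the samples accordingly.
out = computer([SimpleNamespace(buffer=samples)])
# p0 rewards {1.0, 0.0}: advantages of roughly +1 and -1 (mean 0.5, std 0.5).
# p1 rewards are identical, so both advantages are ~0.
print([round(s["advantages"], 2) for s in out])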

chatlearn/algorithm/grpo_utils/megatron_policy_trainer.py

Lines changed: 4 additions & 4 deletions
@@ -17,6 +17,7 @@
 import os
 from contextlib import nullcontext
 from functools import partial
+import itertools
 from typing import List, Union, Dict, Any, Sequence
 from collections import defaultdict
 import numpy as np
@@ -47,6 +48,7 @@
 
 import chatlearn
 from chatlearn import MegatronModule
+from chatlearn.utils.utils import even_slice
 from chatlearn.runtime.decorator import timeit, compute_decorator, monitor_error
 from chatlearn.algorithm.grpo_utils.megatron_utils import (
     PolicyModel,
@@ -393,10 +395,8 @@ def forward_step(self, data: List[Dict[str, Any]], **kwargs) -> List[Dict[str, A
             ]
             # Split by num_train_global_batch first
             microbatch_list = []
-            train_global_batch_size = len(data) // self.num_train_global_batch
-            for train_batch_id in range(self.num_train_global_batch):
-                start_idx = train_batch_id * train_global_batch_size
-                end_idx = (train_batch_id + 1) * train_global_batch_size
+            slice_index = even_slice(len(data), self.num_train_global_batch)
+            for start_idx, end_idx in itertools.pairwise(slice_index):
                 microbatch_list.extend(split_microbatch(data_list=data[start_idx: end_idx], max_train_token=self.module_args.max_token_in_packing, process_group_list=process_group_list, offset=start_idx, packing=self.module_args.packing))
         else:
             microbatch_list = split_microbatch(data_list=data, micro_batch_size=args.micro_batch_size, packing=self.module_args.packing)
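
even_slice is imported from chatlearn.utils.utils but its body is not part of this diff. The sketch below is an assumed reference implementation consistent with how it is used here: it returns num_slices + 1 boundary indices that cut total items into nearly equal contiguous chunks, so itertools.pairwise yields (start, end) ranges even when the sample count is not divisible by num_train_global_batch, which is what makes the dynamic train batch size work.

import itertools

def even_slice(total: int, num_slices: int) -> list:
    # Assumed behaviour: boundary indices for nearly equal chunks,
    # e.g. even_slice(10, 3) -> [0, 4, 7, 10].
    base, rem = divmod(total, num_slices)
    bounds = [0]
    for i in range(num_slices):
        bounds.append(bounds[-1] + base + (1 if i < rem else 0))
    return bounds

for start_idx, end_idx in itertools.pairwise(even_slice(10, 3)):
    print(start_idx, end_idx)   # (0, 4), (4, 7), (7, 10)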
chatlearn/algorithm/grpo_utils/rollout_manager.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+"""Rollout Manager"""
+import random
+from typing import Dict, List, Any
+from collections import defaultdict
+
+import numpy as np
+from transformers import AutoTokenizer
+
+from chatlearn.data.prompt_dataset import PromptPipeline
+from chatlearn.runtime.decorator import timeit, compute_decorator, monitor_error
+from chatlearn import BaseModule
+
+class RolloutManager(BaseModule):
+    """Rollout Manager"""
+    def setup(self):
+        self._metric_prefix = "rollout_manager"
+        self.rollout_finished_no_train = defaultdict(list)
+        self.num_response_track = defaultdict(int)
+        self.rollout_not_finished = []
+        self.max_rollout_round = self.module_args.max_rollout_round
+        self.max_gen_len = self.module_args.max_gen_len
+        self.ratio = self.module_args.rollout_ratio
+        self.max_token_per_round = [int(self.max_gen_len * ratio) for ratio in self.ratio]
+        self.num_inference_per_prompt = self.module_args.num_inference_per_prompt
+        self.mini_response_per_prompt = self.module_args.mini_response_per_prompt
+        # Logging metric dict for this module
+        # It will be appended to self._metric_list after logging all metrics
+        self.metric_dict = {}
+
+    def build_dataset(self, prompts: List[Dict], is_eval=False):
+        # prompts is the full dataset set by engine.set_dataset(dataset)
+        # TODO: move dataset to separate node
+        self.tokenizer = AutoTokenizer.from_pretrained(self.module_args.load, trust_remote_code=True)
+        prompts_dataset = PromptPipeline(
+            prompts,
+            sum(self.max_token_per_round),
+            self.tokenizer,
+            enable_thinking=self.module_args.get("enable_thinking", False),
+        )
+        return prompts_dataset
+
+    def initialze_data(self, data: List[Dict[str, Any]]):
+        for sample in data:
+            sample["rollout_round"] = 0
+            sample["max_generate_token_length"] = self.max_token_per_round[sample["rollout_round"]]
+        return data
+
+    @monitor_error()
+    @compute_decorator(trainable=False, rollout=False)
+    @timeit()
+    def get_sample_for_rollout(self, data: List[Dict[str, Any]], **kwargs): # pylint: disable=unused-argument
+        # Get sample_per_episode samples from prompts_dataset
+        # Add these samples into self.rollout_not_finished for future rollout
+        # Send all samples to rollout engine
+        data = self.initialze_data(data)
+        self.rollout_not_finished.extend(data)
+        train_batch = self.rollout_not_finished
+        # Record start episode id
+        for single_data in train_batch:
+            if "start_episode" not in single_data:
+                single_data["start_episode"] = self._episode_id
+        random.shuffle(train_batch)
+        round_track = {f"round_{i}_samples": 0 for i in range(self.max_rollout_round)}
+        for d in train_batch:
+            round_track[f"round_{d['rollout_round']}_samples"] += 1
+        self.metric_dict.update(round_track)
+        return train_batch
+
+    def is_finished(self, data_b: Dict[str, Any]):
+        # Determine whether the rollout is finished. It is finished if
+        # 1. response_token_length is less than this round's max_generate_token_length
+        # 2. it reaches the max rollout round
+
+        return (data_b["response_token_length"] < data_b["max_generate_token_length"]) or \
+            (data_b["rollout_round"] == self.max_rollout_round)
+
+    def find_index_by_uid(self, uid):
+        idx = next(i for i, d in enumerate(self.rollout_not_finished) if d['uid'] == uid)
+        return idx
+
+    def update_data(self, data: Dict[str, Any], rollout_result: Dict[str, Any]):
+        # Merge data in the self.rollout_not_finished buffer and rollout_result with the same uid
+        assert data["uid"] == rollout_result["uid"]
+        data.update({
+            "str_outputs": data.get("str_outputs", "") + rollout_result["str_outputs"],
+            "rollout_round": rollout_result["rollout_round"],
+            "response_token_length": data.get("response_token_length", 0) + rollout_result["response_token_length"],
+            "input_ids": rollout_result["all_tokens"].tolist(),
+            "all_tokens": rollout_result["all_tokens"],
+            "max_generate_token_length": self.max_token_per_round[rollout_result["rollout_round"]] \
+                if rollout_result["rollout_round"] < self.max_rollout_round else 0
+        })
+        return data
+
+    def logging_generate_by_round(self, rollout_result_list: List[Dict[str, Any]]):
+        # Logging generate metrics
+        logging_generate = {f"round_{i}_response": [] for i in range(self.max_rollout_round)}
+        for data in rollout_result_list:
+            logging_generate[f"round_{data['rollout_round'] - 1}_response"].append(data["response_token_length"])
+        update_dict = {}
+        for key in logging_generate:
+            arr = np.array(logging_generate[key] or [0])
+            update_dict[f"{key}_mean"] = np.mean(arr)
+            update_dict[f"{key}_max"] = np.max(arr)
+            update_dict[f"{key}_min"] = np.min(arr)
+        self.metric_dict.update(update_dict)
+
+    @monitor_error()
+    @compute_decorator(trainable=False, rollout=False)
+    @timeit()
+    def post_process_rollout_results(self, rollout_result_list: List[Dict[str, Any]], **kwargs): # pylint: disable=unused-argument
+        self.logging_generate_by_round(rollout_result_list)
+        unfinished_data = []
+        for sample in rollout_result_list:
+            uid = sample["uid"]
+            prompt_uid = sample["prompt_uid"]
+            finished = self.is_finished(sample)
+            data_idx = self.find_index_by_uid(uid)
+            # Merge data from buffer and data from rollout
+            data_b = self.update_data(self.rollout_not_finished[data_idx], sample)
+            if finished:
+                # Finished, add data to self.rollout_finished_no_train[prompt_uid]
+                self.rollout_finished_no_train[prompt_uid].append(data_b)
+                self.num_response_track[prompt_uid] += 1
+            else:
+                # If not finished, update data in rollout_not_finished
+                unfinished_data.append(data_b)
+        # update remaining data
+        self.rollout_not_finished = unfinished_data
+        train_data = []
+        pop_keys = []
+        for key, data_list in self.rollout_finished_no_train.items():
+            if self.num_response_track[key] > self.mini_response_per_prompt:
+                train_data.extend(data_list)
+                pop_keys.append(key)
+        for key in pop_keys:
+            self.rollout_finished_no_train.pop(key)
+            if self.num_response_track[key] == self.num_inference_per_prompt:
+                self.num_response_track.pop(key)
+        random.shuffle(train_data)
+        total_train_token = sum(d['response_token_length'] + d['prompt_token_length'] for d in train_data)
+        self.metric_dict.update({'total_valid_tokens': total_train_token, 'num_train_samples': len(train_data)})
+        self._metric_list.append(self.metric_dict)
+        return train_data
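
The per-round budgeting behind partial rollout, restated outside the class: setup() derives one token budget per round from max_gen_len and rollout_ratio, and is_finished() treats a sample as done when it stops short of its round budget or has used max_rollout_round rounds; everything else stays in rollout_not_finished and is requeued with the next round's budget. A standalone illustration with made-up numbers:

# Made-up numbers; mirrors the budget logic in setup()/is_finished() above.
max_gen_len = 2048
rollout_ratio = [0.25, 0.25, 0.5]                     # assumed 3-round split
max_rollout_round = len(rollout_ratio)
max_token_per_round = [int(max_gen_len * r) for r in rollout_ratio]
print(max_token_per_round)                            # [512, 512, 1024]

def is_finished(sample):
    # Done if the round produced fewer tokens than its budget (the model
    # stopped on its own) or the sample has used all rollout rounds.
    return (sample["response_token_length"] < sample["max_generate_token_length"]
            or sample["rollout_round"] == max_rollout_round)

# A sample that filled round 0's budget (rollout_round already bumped to 1 by
# the rollout backend) is not finished and is requeued with round 1's budget:
sample = {"rollout_round": 1, "response_token_length": 512,
          "max_generate_token_length": max_token_per_round[0]}
print(is_finished(sample))                            # False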

chatlearn/algorithm/grpo_utils/sglang_policy_inference.py

Lines changed: 13 additions & 7 deletions
@@ -17,6 +17,7 @@
 from typing import Any, Dict, List
 
 import torch
+import numpy as np
 from transformers import AutoTokenizer
 
 from chatlearn.configs import BaseConfig
@@ -50,18 +51,17 @@ def sglang_postprocess_func(
         prompt_token_ids = input_data["input_ids"]
         output_tokens = output["output_ids"]
         response_token_length = output["meta_info"]["completion_tokens"]
-        prompt_token_length = output["meta_info"]["prompt_tokens"]
         str_outputs = tokenizer.decode(output_tokens, skip_special_tokens=True)
         all_tokens = torch.tensor(prompt_token_ids + output_tokens)
         input_data.update(
             {
-                "prompt_token_ids": prompt_token_ids,
                 "all_tokens": all_tokens,
                 "response_token_length": response_token_length,
-                "prompt_token_length": prompt_token_length,
                 "str_outputs": str_outputs,
             }
         )
+        if "rollout_round" in input_data:
+            input_data["rollout_round"] += 1
         data_output.append(input_data)
 
     print("str_outputs", data_output[0]["str_outputs"])
@@ -74,15 +74,20 @@ def metric_collect(rets, seq_length):
     # collect metric
     response_token_length = [ret["response_token_length"] for ret in rets]
     prompt_token_length = [ret["prompt_token_length"] for ret in rets]
-    seq_len = [
-        ret["response_token_length"] + ret["prompt_token_length"] for ret in rets
-    ]
-    clip_ratio = sum(l >= seq_length for l in seq_len) / len(seq_len)
+    seq_len = seq_length
+    clip_ratio = sum(
+        ret["response_token_length"] >= ret.get("max_generate_token_length", seq_len) for ret in rets
+    ) / len(rets)
+    response_token_length.sort()
     inference_stats = {
         "response_token_length": sum(response_token_length)
         / len(response_token_length),
         "prompt_token_length": sum(prompt_token_length) / len(prompt_token_length),
         "response_clip_ratio": clip_ratio,
+        "response_max": max(response_token_length),
+        "response_25_percentile": np.percentile(response_token_length, 25),
+        "response_50_percentile": np.percentile(response_token_length, 50),
+        "response_75_percentile": np.percentile(response_token_length, 75),
     }
     return inference_stats
 
@@ -91,6 +96,7 @@ class SGLangPolicyInference(SGLangModule):
     """sglang rollout"""
 
     def build_dataset(self, prompts: List[Dict], is_eval=False):
+        # TODO: move dataset to separate node
         return build_dataset_func(self.module_args, self.tokenizer, prompts, is_eval)
 
     @compute_decorator(trainable=False, rollout=True)
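
The reworked metric_collect measures the clip ratio against each sample's per-round budget (max_generate_token_length, falling back to seq_length when partial rollout is off) and adds max/percentile statistics over response lengths. A toy call with made-up numbers, assuming metric_collect is importable from this module:

from chatlearn.algorithm.grpo_utils.sglang_policy_inference import metric_collect

rets = [
    {"response_token_length": 100, "prompt_token_length": 20, "max_generate_token_length": 512},
    {"response_token_length": 512, "prompt_token_length": 30, "max_generate_token_length": 512},
    {"response_token_length": 300, "prompt_token_length": 25},   # no budget -> compared to seq_length
    {"response_token_length": 400, "prompt_token_length": 40, "max_generate_token_length": 512},
]
stats = metric_collect(rets, seq_length=2048)
# Only the 512-token sample hit its budget, so response_clip_ratio == 0.25;
# response_max / response_{25,50,75}_percentile summarize the sorted lengths.
print(stats["response_clip_ratio"], stats["response_max"], stats["response_50_percentile"])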
