Partial Rollout #350
Merged
Changes from all commits (22 commits)
All commits are by yytang0204:

68590ba  [WIP] Partial Rollout
6869ff1  Support config for rollout
bc8506d  merge main
649bf58  support dynamic rule reward
5a436c0  support sglang for partial rollout; support megatron for dynamic trai…
61bd84f  remove redundant files
2d4fa53  cleanup code
247858f  fix pylint
2c041fb  fix bug
95b91ea  fix pylint; fix sglang
2e3ce5e  fix pylint
a33b8d7  fix adv calculation
d2ddb7e  fix pylint
3a2b03e  fix pylint
ecb1c7e  fix redundant import
caac8d0  remove bash
8b769a2  merge main for cp and vl
50391d9  fix comments
4856ab6  fix pylint
806a93c  fix comments; fix pylint
d739788  Add explain for data dict
ab5e9b5  Add explain for data dict
Modified file (the GRPO advantage computation), @@ -1,31 +1,31 @@

Before:

```python
"""compute advantage for grpo"""
from typing import List, Dict, Any
from collections import defaultdict

import numpy as np


def compute_grpo_adv(episode_replay_buffers: List[Dict[str, Any]]):
    buffers = episode_replay_buffers[-1].buffer
    queryids2samples = defaultdict(list)
    sample_id = 0
    for s in buffers:
        s['sample_id'] = sample_id
        queryids2samples[hash(",".join(map(str, s["prompt_token_ids"])))].append(s)
        sample_id += 1

    res_buffers = []
    # TODO: torch and numpy have difference result, not knowing consequence
    for _, l in queryids2samples.items():
        rewards = np.array([each["rule_reward"] for each in l])
        mean = np.mean(rewards)
        std = np.std(rewards)
        for li in l:
            li["advantages"] = (li["rule_reward"] - mean) / (std + 1e-5)
        res_buffers.extend(l)

    # Sort samples by original order in buffer
    res_buffers.sort(key=lambda x: x["sample_id"])
    for data in res_buffers:
        data.pop("sample_id")
    return res_buffers
```

After:

```python
"""compute advantage for grpo"""
from collections import defaultdict
from typing import List, Dict, Any

import numpy as np


class AdvantageComputer:
    """advantage computer"""
    def __init__(self, num_inference_per_prompt):
        self.rule_reward_buffer = defaultdict(list)
        self.num_inference_per_prompt = num_inference_per_prompt

    def __call__(self, episode_replay_buffers: List[Dict[str, Any]]):
        buffers = episode_replay_buffers[-1].buffer
        # Update buffer first
        for s in buffers:
            sample_id = s['prompt_uid']
            self.rule_reward_buffer[sample_id].append(s["rule_reward"])

        # Calculate advantage for all samples
        for s in buffers:
            sample_id = s['prompt_uid']
            avg = np.mean(self.rule_reward_buffer[sample_id])
            std = np.std(self.rule_reward_buffer[sample_id])
            s['advantages'] = (s["rule_reward"] - avg) / (std + 1e-5)

        # clean buffer: keep only prompts that have not yet collected
        # all of their responses
        self.rule_reward_buffer = {k: v for k, v in self.rule_reward_buffer.items()
                                   if len(v) < self.num_inference_per_prompt}
        self.rule_reward_buffer = defaultdict(list, self.rule_reward_buffer)
        return buffers
```
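A minimal usage sketch of the new stateful advantage computer. The `EpisodeBuffer` wrapper below is a hypothetical stand-in for ChatLearn's replay-buffer object (only its `.buffer` attribute is used here); the sample dicts carry only the fields the computer reads:

```python
from collections import defaultdict

import numpy as np


class AdvantageComputer:
    """Stateful GRPO advantage computer, as added in this PR."""
    def __init__(self, num_inference_per_prompt):
        self.rule_reward_buffer = defaultdict(list)
        self.num_inference_per_prompt = num_inference_per_prompt

    def __call__(self, episode_replay_buffers):
        buffers = episode_replay_buffers[-1].buffer
        # Accumulate rewards per prompt (possibly across multiple calls)
        for s in buffers:
            self.rule_reward_buffer[s['prompt_uid']].append(s["rule_reward"])
        # Normalize each sample's reward within its prompt group
        for s in buffers:
            group = self.rule_reward_buffer[s['prompt_uid']]
            s['advantages'] = (s["rule_reward"] - np.mean(group)) / (np.std(group) + 1e-5)
        # Drop prompts that have collected all of their responses
        self.rule_reward_buffer = defaultdict(list, {
            k: v for k, v in self.rule_reward_buffer.items()
            if len(v) < self.num_inference_per_prompt})
        return buffers


class EpisodeBuffer:  # hypothetical stand-in for the engine's replay buffer
    def __init__(self, buffer):
        self.buffer = buffer


samples = [
    {"prompt_uid": "p0", "rule_reward": 1.0},
    {"prompt_uid": "p0", "rule_reward": 0.0},
]
ac = AdvantageComputer(num_inference_per_prompt=2)
out = ac([EpisodeBuffer(samples)])
# Rewards 1.0 and 0.0 have mean 0.5 and std 0.5, so advantages are ~ +1 and -1
print(round(out[0]["advantages"], 3), round(out[1]["advantages"], 3))  # -> 1.0 -1.0
print(len(ac.rule_reward_buffer))  # -> 0 (both responses for p0 arrived)
```

Because the per-prompt reward history persists across calls, partially rolled-out prompts can contribute responses in different episodes and still be normalized against their full group.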
New file (the rollout manager), @@ -0,0 +1,145 @@

```python
"""Rollout Manager"""
import random
from typing import Dict, List, Any
from collections import defaultdict

import numpy as np
from transformers import AutoTokenizer

from chatlearn.data.prompt_dataset import PromptPipeline
from chatlearn.runtime.decorator import timeit, compute_decorator, monitor_error
from chatlearn import BaseModule


class RolloutManager(BaseModule):
    """Rollout Manager"""
    def setup(self):
        self._metric_prefix = "rollout_manager"
        self.rollout_finished_no_train = defaultdict(list)
        self.num_response_track = defaultdict(int)
        self.rollout_not_finished = []
        self.max_rollout_round = self.module_args.max_rollout_round
        self.max_gen_len = self.module_args.max_gen_len
        self.ratio = self.module_args.rollout_ratio
        self.max_token_per_round = [int(self.max_gen_len * ratio) for ratio in self.ratio]
        self.num_inference_per_prompt = self.module_args.num_inference_per_prompt
        self.mini_response_per_prompt = self.module_args.mini_response_per_prompt
        # Logging metric dict for this module.
        # It will be appended to self._metric_list after all metrics are logged.
        self.metric_dict = {}

    def build_dataset(self, prompts: List[Dict], is_eval=False):
        # prompts is the full dataset set by engine.set_dataset(dataset)
        # TODO: move dataset to a separate node
        self.tokenizer = AutoTokenizer.from_pretrained(self.module_args.load, trust_remote_code=True)
        prompts_dataset = PromptPipeline(
            prompts,
            sum(self.max_token_per_round),
            self.tokenizer,
            enable_thinking=self.module_args.get("enable_thinking", False),
        )
        return prompts_dataset

    def initialize_data(self, data: List[Dict[str, Any]]):
        for sample in data:
            sample["rollout_round"] = 0
            sample["max_generate_token_length"] = self.max_token_per_round[sample["rollout_round"]]
        return data

    @monitor_error()
    @compute_decorator(trainable=False, rollout=False)
    @timeit()
    def get_sample_for_rollout(self, data: List[Dict[str, Any]], **kwargs):  # pylint: disable=unused-argument
        # Get sample_per_episode samples from prompts_dataset,
        # add them to self.rollout_not_finished for future rollout,
        # then send all buffered samples to the rollout engine.
        data = self.initialize_data(data)
        self.rollout_not_finished.extend(data)
        train_batch = self.rollout_not_finished
        # Record start episode id
        for single_data in train_batch:
            if "start_episode" not in single_data:
                single_data["start_episode"] = self._episode_id
        random.shuffle(train_batch)
        round_track = {f"round_{i}_samples": 0 for i in range(self.max_rollout_round)}
        for d in train_batch:
            round_track[f"round_{d['rollout_round']}_samples"] += 1
        self.metric_dict.update(round_track)
        return train_batch

    def is_finished(self, data_b: Dict[str, Any]):
        # The rollout is finished if either:
        # 1. response_token_length is less than this round's
        #    max_generate_token_length (the model stopped early), or
        # 2. the sample reached the max rollout round.
        return (data_b["response_token_length"] < data_b["max_generate_token_length"]) or \
               (data_b["rollout_round"] == self.max_rollout_round)

    def find_index_by_uid(self, uid):
        idx = next(i for i, d in enumerate(self.rollout_not_finished) if d['uid'] == uid)
        return idx

    def update_data(self, data: Dict[str, Any], rollout_result: Dict[str, Any]):
        # Merge the sample in the self.rollout_not_finished buffer with the
        # rollout_result that has the same uid.
        assert data["uid"] == rollout_result["uid"]
        data.update({
            "str_outputs": data.get("str_outputs", "") + rollout_result["str_outputs"],
            "rollout_round": rollout_result["rollout_round"],
            "response_token_length": data.get("response_token_length", 0) + rollout_result["response_token_length"],
            "input_ids": rollout_result["all_tokens"].tolist(),
            "all_tokens": rollout_result["all_tokens"],
            "max_generate_token_length": self.max_token_per_round[rollout_result["rollout_round"]] \
                if rollout_result["rollout_round"] < self.max_rollout_round else 0
        })
        return data

    def logging_generate_by_round(self, rollout_result_list: List[Dict[str, Any]]):
        # Log per-round generation metrics
        logging_generate = {f"round_{i}_response": [] for i in range(self.max_rollout_round)}
        for data in rollout_result_list:
            logging_generate[f"round_{data['rollout_round'] - 1}_response"].append(data["response_token_length"])
        update_dict = {}
        for key in logging_generate:
            arr = np.array(logging_generate[key] or [0])
            update_dict[f"{key}_mean"] = np.mean(arr)
            update_dict[f"{key}_max"] = np.max(arr)
            update_dict[f"{key}_min"] = np.min(arr)
        self.metric_dict.update(update_dict)

    @monitor_error()
    @compute_decorator(trainable=False, rollout=False)
    @timeit()
    def post_process_rollout_results(self, rollout_result_list: List[Dict[str, Any]], **kwargs):  # pylint: disable=unused-argument
        self.logging_generate_by_round(rollout_result_list)
        unfinished_data = []
        for sample in rollout_result_list:
            uid = sample["uid"]
            prompt_uid = sample["prompt_uid"]
            finished = self.is_finished(sample)
            data_idx = self.find_index_by_uid(uid)
            # Merge data from buffer and data from rollout
            data_b = self.update_data(self.rollout_not_finished[data_idx], sample)
            if finished:
                # Finished: add data to self.rollout_finished_no_train[prompt_uid]
                self.rollout_finished_no_train[prompt_uid].append(data_b)
                self.num_response_track[prompt_uid] += 1
            else:
                # Not finished: keep the sample for the next round
                unfinished_data.append(data_b)
        # Update remaining data
        self.rollout_not_finished = unfinished_data
        train_data = []
        pop_keys = []
        for key, data_list in self.rollout_finished_no_train.items():
            if self.num_response_track[key] > self.mini_response_per_prompt:
                train_data.extend(data_list)
                pop_keys.append(key)
        for key in pop_keys:
            self.rollout_finished_no_train.pop(key)
            if self.num_response_track[key] == self.num_inference_per_prompt:
                self.num_response_track.pop(key)
        random.shuffle(train_data)
        total_train_token = sum(d['response_token_length'] + d['prompt_token_length'] for d in train_data)
        self.metric_dict.update({'total_valid_tokens': total_train_token, 'num_train_samples': len(train_data)})
        self._metric_list.append(self.metric_dict)
        return train_data
```
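The per-round budgeting and stop condition at the heart of the partial rollout can be sketched in isolation. The constant values below are illustrative, not the PR's defaults; in the PR they come from `module_args`:

```python
# Illustrative config; the real values come from module_args in the PR.
max_gen_len = 1024
rollout_ratio = [0.25, 0.25, 0.5]   # fraction of the generation budget per round
max_rollout_round = len(rollout_ratio)

# Per-round generation budget, as computed in RolloutManager.setup
max_token_per_round = [int(max_gen_len * r) for r in rollout_ratio]


def is_finished(sample):
    # Mirrors RolloutManager.is_finished: a sample is done if the model
    # stopped early (fewer tokens than this round's budget) or it has
    # used up all rollout rounds.
    return (sample["response_token_length"] < sample["max_generate_token_length"]) \
        or (sample["rollout_round"] == max_rollout_round)


early_stop = {"response_token_length": 100, "max_generate_token_length": 256, "rollout_round": 1}
budget_hit = {"response_token_length": 256, "max_generate_token_length": 256, "rollout_round": 1}
last_round = {"response_token_length": 512, "max_generate_token_length": 512, "rollout_round": 3}

print(max_token_per_round)      # -> [256, 256, 512]
print(is_finished(early_stop))  # -> True  (model emitted an EOS before the budget)
print(is_finished(budget_hit))  # -> False (continues into the next round)
print(is_finished(last_round))  # -> True  (no rounds left)
```

A sample that exhausts its budget is kept in `rollout_not_finished` and resumed next episode with the next round's budget, which is what lets long generations be spread across training episodes.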