"""Rollout Manager"""
import random
from typing import Dict, List, Any
from collections import defaultdict

import numpy as np
from transformers import AutoTokenizer

from chatlearn.data.prompt_dataset import PromptPipeline
from chatlearn.runtime.decorator import timeit, compute_decorator, monitor_error
from chatlearn import BaseModule


class RolloutManager(BaseModule):
    """Rollout Manager"""

    def setup(self):
        self._metric_prefix = "rollout_manager"
        self.rollout_finished_no_train = defaultdict(list)
        self.num_response_track = defaultdict(int)
        self.rollout_not_finished = []
        self.max_rollout_round = self.module_args.max_rollout_round
        self.max_gen_len = self.module_args.max_gen_len
        self.ratio = self.module_args.rollout_ratio
        self.max_token_per_round = [int(self.max_gen_len * ratio) for ratio in self.ratio]
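        # Hypothetical example: with max_gen_len=2048 and rollout_ratio=[0.25, 0.25, 0.5],
        # this yields a per-round budget of [512, 512, 1024] tokens, so each partial
        # rollout round gets a fixed slice of the total generation length.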
        self.num_inference_per_prompt = self.module_args.num_inference_per_prompt
        self.mini_response_per_prompt = self.module_args.mini_response_per_prompt
        # Logging metric dict for this module.
        # It will be appended to self._metric_list after all metrics are logged.
        self.metric_dict = {}

    def build_dataset(self, prompts: List[Dict], is_eval=False):
        # `prompts` appears to be the full dataset registered via engine.set_dataset(dataset)
        # TODO: move dataset handling to a separate node
        self.tokenizer = AutoTokenizer.from_pretrained(self.module_args.load, trust_remote_code=True)
        prompts_dataset = PromptPipeline(
            prompts,
            sum(self.max_token_per_round),
            self.tokenizer,
            enable_thinking=self.module_args.get("enable_thinking", False),
        )
        return prompts_dataset

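    # Note: the second argument passed to PromptPipeline above is the total generation
    # budget across all rollout rounds, i.e. roughly max_gen_len when rollout_ratio sums to 1.
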
    def initialize_data(self, data: List[Dict[str, Any]]):
        # Mark every incoming sample as being at round 0 and assign its first-round
        # generation budget.
        for sample in data:
            sample["rollout_round"] = 0
            sample["max_generate_token_length"] = self.max_token_per_round[sample["rollout_round"]]
        return data

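    # Illustrative shape of a sample after initialize_data (fields other than the two
    # set above are assumptions based on how they are read later in this class):
    #   {"uid": ..., "prompt_uid": ..., "rollout_round": 0,
    #    "max_generate_token_length": <first-round budget>, ...}
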
    @monitor_error()
    @compute_decorator(trainable=False, rollout=False)
    @timeit()
    def get_sample_for_rollout(self, data: List[Dict[str, Any]], **kwargs):  # pylint: disable=unused-argument
        # Take the sample_per_episode samples drawn from prompts_dataset,
        # add them to self.rollout_not_finished for future rollout rounds,
        # and send all buffered samples to the rollout engine.
        data = self.initialize_data(data)
        self.rollout_not_finished.extend(data)
        train_batch = self.rollout_not_finished
        # Record the episode in which each sample started
        for single_data in train_batch:
            if "start_episode" not in single_data:
                single_data["start_episode"] = self._episode_id
        random.shuffle(train_batch)
        # Track how many samples are at each rollout round
        round_track = {f"round_{i}_samples": 0 for i in range(self.max_rollout_round)}
        for d in train_batch:
            round_track[f"round_{d['rollout_round']}_samples"] += 1
        self.metric_dict.update(round_track)
        return train_batch

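    # Note: in get_sample_for_rollout, train_batch is the same list object as
    # self.rollout_not_finished, so the in-place shuffle also reorders the pending buffer.
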
    def is_finished(self, data_b: Dict[str, Any]):
        # Determine whether the rollout of this sample is finished.
        # It is finished if either:
        # 1. response_token_length is less than this round's max_generate_token_length, or
        # 2. the sample has reached the max rollout round.
        return (data_b["response_token_length"] < data_b["max_generate_token_length"]) or \
            (data_b["rollout_round"] == self.max_rollout_round)

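    # Hypothetical example for is_finished: with a 512-token budget for the round,
    # a sample whose response stopped after 180 tokens is finished, while one that
    # used the full 512 tokens is carried into the next round unless it has already
    # reached max_rollout_round.
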
    def find_index_by_uid(self, uid):
        # Locate the buffered sample with this uid; raises StopIteration if it is absent.
        idx = next(i for i, d in enumerate(self.rollout_not_finished) if d['uid'] == uid)
        return idx

    def update_data(self, data: Dict[str, Any], rollout_result: Dict[str, Any]):
        # Merge the buffered sample in self.rollout_not_finished with the rollout_result
        # that carries the same uid.
        assert data["uid"] == rollout_result["uid"]
        data.update({
            "str_outputs": data.get("str_outputs", "") + rollout_result["str_outputs"],
            "rollout_round": rollout_result["rollout_round"],
            "response_token_length": data.get("response_token_length", 0) + rollout_result["response_token_length"],
            "input_ids": rollout_result["all_tokens"].tolist(),
            "all_tokens": rollout_result["all_tokens"],
            "max_generate_token_length": self.max_token_per_round[rollout_result["rollout_round"]]
                if rollout_result["rollout_round"] < self.max_rollout_round else 0
        })
        return data

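    # Note: str_outputs and response_token_length accumulate across rounds, while
    # input_ids / all_tokens are replaced wholesale with the latest rollout output,
    # presumably so the next round resumes generation from the full sequence so far.
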
    def logging_generate_by_round(self, rollout_result_list: List[Dict[str, Any]]):
        # Logging generate metrics
        logging_generate = {f"round_{i}_response": [] for i in range(self.max_rollout_round)}
        for data in rollout_result_list:
            logging_generate[f"round_{data['rollout_round'] - 1}_response"].append(data["response_token_length"])
        update_dict = {}
        for key in logging_generate:
            arr = np.array(logging_generate[key] or [0])
            update_dict[f"{key}_mean"] = np.mean(arr)
            update_dict[f"{key}_max"] = np.max(arr)
            update_dict[f"{key}_min"] = np.min(arr)
        self.metric_dict.update(update_dict)

    @monitor_error()
    @compute_decorator(trainable=False, rollout=False)
    @timeit()
    def post_process_rollout_results(self, rollout_result_list: List[Dict[str, Any]], **kwargs):  # pylint: disable=unused-argument
        self.logging_generate_by_round(rollout_result_list)
        unfinished_data = []
        for sample in rollout_result_list:
            uid = sample["uid"]
            prompt_uid = sample["prompt_uid"]
            finished = self.is_finished(sample)
            data_idx = self.find_index_by_uid(uid)
            # Merge the buffered sample with the data returned by the rollout engine
            data_b = self.update_data(self.rollout_not_finished[data_idx], sample)
            if finished:
                # Finished: move the sample to self.rollout_finished_no_train[prompt_uid]
                self.rollout_finished_no_train[prompt_uid].append(data_b)
                self.num_response_track[prompt_uid] += 1
            else:
                # Not finished: keep the sample for the next rollout round
                unfinished_data.append(data_b)
        # Keep only the still-unfinished samples in the buffer
        self.rollout_not_finished = unfinished_data
        # Release prompts that have collected enough finished responses for training
        train_data = []
        pop_keys = []
        for key, data_list in self.rollout_finished_no_train.items():
            if self.num_response_track[key] > self.mini_response_per_prompt:
                train_data.extend(data_list)
                pop_keys.append(key)
        for key in pop_keys:
            self.rollout_finished_no_train.pop(key)
            if self.num_response_track[key] == self.num_inference_per_prompt:
                self.num_response_track.pop(key)
        random.shuffle(train_data)
        total_train_token = sum(d['response_token_length'] + d['prompt_token_length'] for d in train_data)
        self.metric_dict.update({'total_valid_tokens': total_train_token, 'num_train_samples': len(train_data)})
        self._metric_list.append(self.metric_dict)
        return train_data
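
# Expected call order (an inference from the methods above, not from the chatlearn
# engine itself): build_dataset() prepares the prompt dataset once; each episode then
# calls get_sample_for_rollout() to hand the buffered samples to the rollout engine,
# and post_process_rollout_results() to merge the partial generations, re-queue
# unfinished samples, and release finished samples for training once a prompt has
# accumulated enough responses.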