🧺 [1/N] Refactor _generate in GRPO/RLOO: list of ints instead of tensors (#4146)
Changes from 55 commits
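At a glance, the PR changes what `_generate` returns. The sketch below summarizes the old and new contracts as stub signatures; it is a simplified reading of the diff that follows (bodies elided, `self` omitted), not code from the repository.

```python
from typing import Any, Optional

import torch

# Sketch of the interface change (signatures only, bodies elided, `self` omitted).
# The shapes and element types in the comments follow the diff below.

def _generate_old(prompts: list[str], images: Optional[list]) -> tuple[
    torch.Tensor,            # prompt_ids, shape (B, P), left-padded
    torch.Tensor,            # completion_ids, shape (B, C), right-padded
    torch.Tensor,            # prompt_mask
    torch.Tensor,            # completion_mask
    torch.Tensor,            # num_items_in_batch
    Optional[torch.Tensor],  # sampling_per_token_logps
    dict[str, Any],          # forward_kwargs
]:
    ...

def _generate_new(prompts: list[str], images: Optional[list]) -> tuple[
    list[list[int]],              # prompt_ids, one ragged list per prompt
    list[list[int]],              # completion_ids, one ragged list per completion
    torch.Tensor,                 # total_completion_tokens (= num_items_in_batch, for the DAPO loss)
    Optional[list[list[float]]],  # sampling per-token logprobs, or None
    dict[str, Any],               # forward_kwargs
]:
    ...
```

Padding and mask construction move out of generation and into `_generate_and_score_completions`, as the last hunk of the diff shows.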
```diff
@@ -1070,9 +1070,8 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list):
         rewards_per_func = gather(rewards_per_func)
         return rewards_per_func

-    def _generate(self, prompts: list[str], images: Optional[list]):
+    def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"

         # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
         # [{"role": "user", "content": "What color is the sky?"}] to
@@ -1088,15 +1087,7 @@ def _generate(self, prompts: list[str], images: Optional[list]):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]

-        prompt_inputs = self.processing_class(
-            text=prompts_text,
-            return_tensors="pt",
-            padding=True,
-            padding_side="left",
-            add_special_tokens=False,
-            **kwargs,
-        )
-        prompt_inputs = super()._prepare_inputs(prompt_inputs)
+        prompt_inputs = self.processing_class(text=prompts_text, add_special_tokens=False, **kwargs)
         prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
         forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
@@ -1192,14 +1183,14 @@ def _generate(self, prompts: list[str], images: Optional[list]):
                 # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
                 obj_list = [payload]
                 broadcast_object_list(obj_list, from_process=0)
-                completion_ids, all_logprobs = obj_list[0]
+                all_completion_ids, all_logprobs = obj_list[0]
```

Comment: more consistent naming
```diff
                 process_slice = slice(
                     self.accelerator.process_index * len(prompts),
                     (self.accelerator.process_index + 1) * len(prompts),
                 )
-                completion_ids = completion_ids[process_slice]
-                all_logprobs = all_logprobs[process_slice]
+                completion_ids = all_completion_ids[process_slice]
+                logprobs = all_logprobs[process_slice]
```

Comment on lines -1204 to +1202: same, better naming

```diff
             # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts
             elif self.vllm_mode == "colocate":
```
```diff
@@ -1252,7 +1243,7 @@ def _generate(self, prompts: list[str], images: Optional[list]):
                 with profiling_context(self, "vLLM.generate"):
                     all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)

-                completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
+                all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
```

Comment on lines -1258 to +1255: same, better naming

```diff
                 all_logprobs = [
                     [next(iter(lp.values())).logprob for lp in output.logprobs]
                     for outputs in all_outputs
```
```diff
@@ -1264,22 +1255,15 @@ def _generate(self, prompts: list[str], images: Optional[list]):
                     # Each rank generates all outputs — we keep only our share.
                     local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                     tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
-                    completion_ids = completion_ids[tp_slice]
-                    all_logprobs = all_logprobs[tp_slice]
+                    completion_ids = all_completion_ids[tp_slice]
+                    logprobs = all_logprobs[tp_slice]
```

Comment on lines -1270 to +1268: same, better naming

```diff
+                else:
+                    completion_ids = all_completion_ids
+                    logprobs = all_logprobs

                 if self.args.vllm_enable_sleep_mode:
                     self.llm.sleep(level=1)

-            # Pad the completions, and concatenate them with the prompts
-            completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-            completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids]
-            completion_ids = pad(completion_ids, padding_value=self.pad_token_id)
-            completion_mask = pad(completion_mask, padding_value=0)
-            sampling_per_token_logps = [
-                torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs
-            ]
-            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0)
```

Comment on lines -1276 to -1285: here we want generate to return a list of ints instead of a padded torch tensor; for the vLLM case, it's pretty easy: just remove the tensor-building part.
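For context, this is the tensor-building step being dropped from the vLLM branch; the same work now happens once in `_generate_and_score_completions` (last hunk). The snippet below is a self-contained sketch with made-up token ids and a minimal `pad` helper standing in for TRL's `pad` utility, not the trainer's actual code.

```python
import torch

def pad(tensors, padding_value=0, padding_side="right"):
    # Minimal stand-in for trl.trainer.utils.pad: stack ragged 1-D tensors into a (B, T) batch.
    max_len = max(t.size(0) for t in tensors)
    out = torch.full((len(tensors), max_len), padding_value, dtype=tensors[0].dtype)
    for i, t in enumerate(tensors):
        if padding_side == "right":
            out[i, : t.size(0)] = t
        else:
            out[i, max_len - t.size(0):] = t
    return out

pad_token_id = 0
# Toy ragged outputs, in the shape the vLLM branch now returns them (values made up).
completion_ids_list = [[5, 7, 2], [9, 4, 6, 8, 3]]
logprobs_list = [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6, -0.7, -0.8]]

# The tensor building that used to live inside _generate and is now deferred to the caller:
completion_ids = [torch.tensor(ids) for ids in completion_ids_list]
completion_mask = [torch.ones(len(ids), dtype=torch.long) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=pad_token_id)
completion_mask = pad(completion_mask, padding_value=0)
sampling_per_token_logps = pad(
    [torch.tensor(lps, dtype=torch.float32) for lps in logprobs_list], padding_value=0.0
)
print(completion_ids.shape, completion_mask.shape, sampling_per_token_logps.shape)  # all (2, 5)
```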
```diff
         elif self.use_transformers_paged:
             # Re-process inputs for paged generation if needed
             # Note: images are already validated and preprocessed above
@@ -1309,13 +1293,10 @@ def _generate(self, prompts: list[str], images: Optional[list]):
                 )
                 unwrapped_model.train()  # restore training mode, as generate_batch forces eval mode
                 completion_ids = [output.generated_tokens for output in all_outputs.values()]
-                completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
-                completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
-                prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
-                prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+                prompt_ids = paged_prompt_inputs.input_ids
```

Comment on lines -1315 to +1305: same as vLLM, just remove the tensor-building step

```diff
             # Restore the original attention implementation, training mode
             self.model_wrapped.config._attn_implementation = previous_attn
-            sampling_per_token_logps = None  # not used in this case
+            logprobs = None  # not used in this case

         else:
             # Regular generation path
```
```diff
@@ -1338,29 +1319,36 @@ def _generate(self, prompts: list[str], images: Optional[list]):
             prompt_length = prompt_ids.size(1)
             prompt_ids = prompt_completion_ids[:, :prompt_length]
             completion_ids = prompt_completion_ids[:, prompt_length:]
-            sampling_per_token_logps = None  # not used in this case

-        # Mask everything after the first EOS token
-        is_eos = completion_ids == self.eos_token_id
-        eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
-        eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
-        sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
-        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+            # Mask everything after the first EOS token
+            is_eos = completion_ids == self.eos_token_id
+            eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+            eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+            sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+            completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+            prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+            completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
+            logprobs = None  # not used in this case

-        # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging
-        completion_lengths = completion_mask.sum(1)
-        agg_completion_lengths = self.accelerator.gather(completion_lengths)
-        num_items_in_batch = agg_completion_lengths.sum()  # this is required for the DAPO loss
+        return prompt_ids, completion_ids, logprobs, forward_kwargs

-        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
-        if self.mask_truncated_completions:
-            truncated_completions = ~is_eos.any(dim=1)
-            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
+    def _generate(self, prompts: list[str], images: Optional[list]):
+        device = self.accelerator.device
+        mode = "train" if self.model.training else "eval"
+
+        prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
+
+        # Get completion length per sequence, used for logging
+        prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
+        completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device)
+        agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
+        agg_completion_lengths = self.accelerator.gather(completion_lengths)
+        total_prompt_tokens = agg_prompt_lengths.sum()
+        total_completion_tokens = agg_completion_lengths.sum()  # = num_items_in_batch, required for the DAPO loss

         # Log the metrics
         if mode == "train":
-            attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
-            self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item()
+            self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item()
         self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen]

         # Log completion lengths, mean, min, max
```
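The regular generation path now trims the padded `generate()` output into ragged lists by masking everything after the first EOS, as in the added lines above. A minimal, single-device sketch of that trimming step with made-up token ids:

```python
import torch

# Toy batch of padded completions (token ids are made up for illustration).
eos_token_id = 2
completion_ids = torch.tensor([
    [5, 7, 2, 0, 0],   # terminates with EOS at position 2
    [9, 4, 6, 8, 3],   # never emits EOS (stopped at max length)
])

# Mask everything after the first EOS token, as in the diff above.
is_eos = completion_ids == eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

# Convert the padded tensor into ragged lists of ints, the new return type.
completion_ids_list = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
print(completion_ids_list)  # [[5, 7, 2], [9, 4, 6, 8, 3]]
```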
```diff
@@ -1369,25 +1357,18 @@ def _generate(self, prompts: list[str], images: Optional[list]):
         self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item())

         # Identify sequences that terminated with EOS and log their lengths
-        agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1))
-        term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos]
-        clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths)
-        self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio)
+        eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
+        is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device)
+        agg_is_truncated = self.accelerator.gather(is_truncated)
+        self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item())
+        term_completion_lengths = agg_completion_lengths[~agg_is_truncated]
         if len(term_completion_lengths) == 0:  # edge case where no terminated sequences are found
             term_completion_lengths = torch.zeros(1, device=device)
         self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item())
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
```
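With list-of-ints completions there is no `is_eos` tensor to reuse, so the clipped ratio is now derived from the last token of each completion. A toy, single-process sketch (the `accelerator.gather` step is omitted); token ids are made up:

```python
import torch

# Hypothetical special-token ids for illustration.
eos_token_id, pad_token_id = 2, 0
completion_ids_list = [
    [5, 7, 2],        # ends with EOS -> terminated naturally
    [9, 4, 6, 8, 3],  # ends with an ordinary token -> hit the max length (clipped)
]

eos_and_pad = [eos_token_id, pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list])
clipped_ratio = is_truncated.float().mean().item()
print(clipped_ratio)  # 0.5
```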
```diff
-        return (
-            prompt_ids,
-            completion_ids,
-            prompt_mask,
-            completion_mask,
```

Comment on lines -1388 to -1389: prompt and completion masks are later inferred from the sequence lengths

```diff
-            num_items_in_batch,
-            sampling_per_token_logps,
-            forward_kwargs,
-        )
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs

     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
```
```diff
@@ -1405,18 +1386,33 @@ def _generate_and_score_completions(
             images = None

         (
-            prompt_ids,
-            completion_ids,
-            prompt_mask,
-            completion_mask,
+            prompt_ids_list,
+            completion_ids_list,
             num_items_in_batch,
-            sampling_per_token_logps,
+            sampling_per_token_logps_list,
             forward_kwargs,
         ) = self._generate(prompts, images)

-        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
-        # to re-tokenize completions if the reward is computed from tokens.
-        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
+        # Convert lists of token IDs to padded tensors
+        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
+        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
+        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
+        completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
+        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
+        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
+        if sampling_per_token_logps_list is not None:
+            sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
+            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
+        else:
+            sampling_per_token_logps = None
+
+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
+            is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
+            completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
```

Comment: the function must now return a list of ints, so we must remove the padding.
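The caller now rebuilds the padded tensors and masks from the sequence lengths: prompts are left-padded, completions are right-padded, and each mask is simply ones padded with zeros. A standalone sketch of that reconstruction, using `torch.nn.utils.rnn.pad_sequence` instead of TRL's `pad` helper and made-up token ids:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0  # assumed value for illustration
prompt_ids_list = [[11, 12, 13], [21, 22]]          # made-up prompt token ids
completion_ids_list = [[5, 7, 2], [9, 4, 6, 8, 3]]  # made-up completion token ids

def left_pad(seqs, padding_value):
    # Right-pad the reversed sequences, then flip back to obtain left padding.
    flipped = [torch.tensor(s).flip(0) for s in seqs]
    return pad_sequence(flipped, batch_first=True, padding_value=padding_value).flip(1)

def right_pad(seqs, padding_value):
    return pad_sequence([torch.tensor(s) for s in seqs], batch_first=True, padding_value=padding_value)

# Masks are inferred from the lengths: ones for real tokens, zeros where padding was added.
prompt_ids = left_pad(prompt_ids_list, pad_token_id)
prompt_mask = left_pad([[1] * len(s) for s in prompt_ids_list], 0)
completion_ids = right_pad(completion_ids_list, pad_token_id)
completion_mask = right_pad([[1] * len(s) for s in completion_ids_list], 0)

print(prompt_ids)       # tensor([[11, 12, 13], [ 0, 21, 22]])
print(completion_mask)  # tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
```

The `mask_truncated_completions` branch then zeroes out whole rows of `completion_mask` for completions whose last token is neither EOS nor pad, matching the `is_truncated` check sketched earlier.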