Conversation

@qgallouedec (Member) commented on Sep 26, 2025

This PR belongs to a sequence of PRs that aim to refactor the generation part of GRPO/RLOO, to allow for easier customization and, ultimately, tool calling.

Previous:

Next:

Instead of extracting prompt_ids in one central place for all generation methods, like this:

prompt_inputs = super()._prepare_inputs(prompt_inputs)
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
# everything except the ids/mask is forwarded to the model unchanged
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
# strip padding so each prompt becomes a plain list of token ids
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]

we now rely on each generation method to provide the prompt_ids itself.
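
For illustration, here is a minimal sketch of that shape. The method names (_generate_with_vllm, _generate_with_transformers) and attributes (self.llm, self.sampling_params, self.processing_class, self.max_completion_length) are assumptions for the example, not TRL's actual internals: each backend hands back the prompt ids it actually consumed, alongside the completion ids.

# Sketch only: hypothetical method and attribute names, not the actual TRL implementation.
def _generate_with_vllm(self, prompts: list[str]):
    # vLLM reports the tokenized prompt on each RequestOutput, so the backend
    # can return exactly the prompt ids it generated from
    outputs = self.llm.generate(prompts, sampling_params=self.sampling_params)
    prompt_ids = [list(out.prompt_token_ids) for out in outputs]
    completion_ids = [list(out.outputs[0].token_ids) for out in outputs]
    return prompt_ids, completion_ids

def _generate_with_transformers(self, prompts: list[str]):
    # assumes left padding, so every completion starts right after the padded prompt block
    enc = self.processing_class(prompts, return_tensors="pt", padding=True).to(self.model.device)
    sequences = self.model.generate(**enc, max_new_tokens=self.max_completion_length)
    prompt_len = enc["input_ids"].shape[1]
    prompt_ids = [ids[mask.bool()].tolist() for ids, mask in zip(enc["input_ids"], enc["attention_mask"])]
    completion_ids = [seq[prompt_len:].tolist() for seq in sequences]
    return prompt_ids, completion_ids

With this shape, a custom generation path (e.g. one that does tool calling) only needs to return its own prompt_ids/completion_ids rather than having to match the trainer's central re-tokenization.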

Benchmark

I ran a quick benchmark with two configs; results are here:

I can't remember which run corresponds to before and which one corresponds to after 😅

import os

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

os.environ["TRACKIO_PROJECT"] = "4152"
os.environ["TRACKIO_SPACE_ID"] = "qgallouedec/trackio"

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train[:500]")
dataset = dataset.select_columns(["prompt"])  # select_columns returns a new dataset, so reassign it

training_args = GRPOConfig(
    output_dir="tmp",
    per_device_train_batch_size=4,  # reduce the batch size to reduce memory usage
    gradient_accumulation_steps=8,
    num_generations=4,  # reduce the number of generations to reduce memory usage
    max_completion_length=256,  # reduce the completion length to reduce memory usage
    steps_per_generation=4,
    logging_steps=1,
    num_train_epochs=1,
    vllm_mode="colocate",
    use_vllm=True,
    report_to="trackio"
)
trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B-Base",
    reward_funcs=[reward_num_unique_chars],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

https://huggingface.co/spaces/qgallouedec/trackio?project=4152

qgallouedec and others added 18 commits October 1, 2025 08:49
Co-authored-by: Quentin Gallouédec, sergiopaniego, lewtun, Kashif Rasul
@lewtun (Member) left a comment

Thanks for the clean refactor!

Overall LGTM, with one question: is the difference across the benchmark runs within the variance of a repeated run (i.e. if you run one of these configs again on main, does it also fluctuate from the original main run)?

[Screenshot 2025-10-07 at 07:46:10]
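
For what it's worth, one way to estimate that variance (a sketch only, not something run as part of this PR) would be to repeat the identical config on main, changing nothing but the seed, and compare the resulting reward curves:

# Sketch: rerun the same benchmark config twice on main, varying only the seed, to gauge
# run-to-run noise before attributing any gap to the refactor. In practice each run would
# likely be launched as its own process.
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train[:500]").select_columns(["prompt"])

for seed in (42, 123):  # two arbitrary seeds
    training_args = GRPOConfig(
        output_dir=f"tmp-seed-{seed}",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        num_generations=4,
        max_completion_length=256,
        steps_per_generation=4,
        logging_steps=1,
        num_train_epochs=1,
        use_vllm=True,
        vllm_mode="colocate",
        seed=seed,  # the only thing that changes between the two runs
        report_to="trackio",
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen3-0.6B-Base",
        reward_funcs=[reward_num_unique_chars],
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()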
