Example/data_augment #1

lingzhq wants to merge 9 commits into feature/tuner_enhance from example/data
Commits (9):

- `8f6c1a1` V0: add basic readme and yaml
- `6859906` V1: basic implementation of data-augmentation
- `a97208c` minor fix of readme
- `ec6f380` add prepare_data.py
- `cd098c8` update readme
- `2b5e67e` update config yaml for more notations
- `edbb705` minor fix of py file
- `13cff49` Update README, refactor config process
- `476bb38` minor fix of README
# Training Math Agent with Data-Augment Strategies

This example demonstrates how to use **AgentScope-Tuner** to enhance a math problem-solving agent. We will focus on leveraging **Data-Centric** features, such as the `difficulty_based` task selector, to improve data utility and training efficiency.

## Task Setting

We use the foundational [math-agent example](../react_agent/main.py) as our baseline to demonstrate the data enhancement capabilities. Notably, these data-centric techniques are generic and customizable, making them adaptable to other agent workflows.

### Agent Goal and Type
The agent's objective is to solve mathematical reasoning problems, learning to produce a correct final answer through a step-by-step thought process. The agent is implemented as a **`ReActAgent`**, which follows a reasoning-acting loop to solve tasks iteratively.
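
For intuition, here is a minimal sketch of such a reasoning-acting loop. It is illustrative only; the names and structure are hypothetical and do not reflect AgentScope's actual `ReActAgent` API (see [main.py](../react_agent/main.py) for the real implementation).

```python
# Illustrative ReAct-style loop; all names here are hypothetical and do
# not mirror AgentScope's actual classes or methods.
def react_loop(question, llm, tools, max_steps=10):
    """Alternate reasoning and acting until a final answer is produced."""
    context = [question]
    for _ in range(max_steps):
        step = llm(context)                 # reason: decide the next action
        if step.get("final_answer"):        # the model commits to an answer
            return step["final_answer"]
        observation = tools[step["tool"]](**step["args"])  # act: call a tool
        context.append((step, observation))                # observe the result
    return None  # no answer within the step budget
```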

### Environment
Each task is a question-answer pair from a math dataset. The agent's performance is evaluated based on its final answer.

### Objective of the Data-Centric Approach

Training can be inefficient if tasks are too easy or too hard. This example addresses the problem by providing **selectors** that dynamically select tasks using **data feedback**. This empowers users to explore and implement their own data-centric strategies, such as focusing on "productively challenging" samples, to maximize training efficiency.

## Dataset Preparation

To enable difficulty-based sampling, our training data needs to include features that represent the "difficulty" of each task.

1. **Base Dataset**: You can use any standard math problem dataset. A good example is the math data in [LLM360/guru-RL-92k](https://huggingface.co/datasets/LLM360/guru-RL-92k), which comes pre-annotated with pass rates from different LLMs, serving as direct difficulty features.
2. **Build Your Own Features**: If you use your own dataset, you can generate these features by pre-running several models of varying capabilities and recording their pass rates (see the sketch after this list). This can be done within the [Trinity](https://github.com/modelscope/Trinity-RFT/pull/440) framework.
3. **Data Format**: The final dataset should be in HuggingFace format. In this example, the data is converted to the *GSM8K format* expected by the [workflow](../react_agent/main.py). Besides the task content, it must include the difficulty feature columns you've defined (e.g., `qwen2.5_7b_pass_rate`, `qwen3_30b_pass_rate`).
4. **Example Data Preparation**: We provide a script for this example. Simply execute `python prepare_data.py` to generate the required dataset.
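
As a hedged illustration of step 2, the sketch below annotates each problem with per-model pass rates. The helper `query_model` is a placeholder for your own inference call, and the whole snippet is an assumption-laden toy, not the actual `prepare_data.py`.

```python
# Hypothetical sketch of building pass-rate difficulty features.
# `query_model` is a stub standing in for a real inference call.
import random

def query_model(model_name: str, question: str) -> str:
    """Placeholder: ask `model_name` to solve `question`, return its answer."""
    return random.choice(["42", "7"])  # stubbed so the sketch runs end to end

def pass_rate(model_name: str, question: str, gold: str, n_trials: int = 8) -> float:
    """Fraction of n_trials rollouts whose final answer matches the gold answer."""
    hits = sum(query_model(model_name, question).strip() == gold for _ in range(n_trials))
    return hits / n_trials

# Annotate a GSM8K-style record with one feature column per probe model.
record = {"question": "What is 6 * 7?", "answer": "42"}
for model_name, key in [("Qwen2.5-7B", "qwen2.5_7b_pass_rate"),
                        ("Qwen3-30B", "qwen3_30b_pass_rate")]:
    record[key] = pass_rate(model_name, record["question"], record["answer"])
print(record)  # the record now carries the difficulty features used below
```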

## Code Implementation

### Agent Workflow & Judge Function

This example follows the foundational [math-agent example](../react_agent/main.py), adopting its `run_react_agent` and `gsm8k_judge` as the `workflow_func` and `judge_func`, respectively. This highlights a key benefit: you can apply training strategies without altering your core agent logic.
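
For orientation, the rough shape of that contract is sketched below. The signatures are assumptions for illustration only; consult [main.py](../react_agent/main.py) for the real definitions.

```python
# Hypothetical signatures -- not the actual code in ../react_agent/main.py.
def run_react_agent(task, model):
    """workflow_func: roll out the ReAct agent on one task."""
    ...

def gsm8k_judge(task, response):
    """judge_func: score a rollout, e.g. 1.0 if the final answer is correct."""
    ...
```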

### Design of Data-Centric Features

Leveraging the powerful data processing capabilities of **`Trinity`**, **AgentScope-Tuner** provides interfaces for advanced operations like task selection and experience processing.

#### Task Selector

The `Task Selector` determines how samples are selected from a dataset. It is configured within the `Dataset` object in your Python script.

- **Built-in Selectors**:
  - `sequential`: Samples are selected in a fixed order.
  - `shuffle`: The dataset is shuffled at the beginning of each epoch.
  - `random`: Samples are randomly chosen with replacement for each batch.
  - `offline_easy2hard`: Samples are sorted by a predefined feature for curriculum learning.
  - `difficulty_based` (customized): An adaptive sampler based on task difficulty.

> For more details on the `Task Selector`, including how to implement a custom selector based on feedback signals, please refer to Trinity's **[Selector Development Guide](https://github.com/modelscope/Trinity-RFT/blob/main/docs/sphinx_doc/source/tutorial/develop_selector.md)**.

#### Data Processor

The `Data Processor` allows real-time processing of **Task** and **Experience** objects during training, enabling operations such as calculating feedback metrics, data augmentation, or filtering.

For example, the `difficulty_based` selector requires a `pass_rate_calculator` operator to compute the agent's success rate for each task. This feedback is then used to adjust the sampling strategy.
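
Conceptually, the operator reduces each task's group of rollouts to a scalar. A minimal sketch, assuming binary rewards and `repeat_times` rollouts per task (this is not Trinity's actual `pass_rate_calculator` implementation):

```python
# Toy pass-rate feedback: average a task's binary rollout rewards.
# Not the real `pass_rate_calculator`; for the shape of the idea only.
def task_pass_rate(rewards: list[float]) -> float:
    """Average reward over one task's rollout group."""
    return sum(rewards) / len(rewards)

# With repeat_times = 8, a task solved 6 times has pass rate 0.75; a
# difficulty-based selector can steer sampling toward tasks whose observed
# pass rate sits near a target (0.8 in this example's config).
print(task_pass_rate([1, 0, 1, 1, 1, 0, 1, 1]))  # 0.75
```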

> For more details on the `Data Processor`, please refer to Trinity's **[Operator Development Guide](https://github.com/modelscope/Trinity-RFT/blob/main/docs/sphinx_doc/source/tutorial/develop_operator.md)**.

### Configuring the Experiments

We demonstrate how to set up two experiments to compare the baseline `random` selector against the `difficulty_based` selector.

**Experiment 1: Baseline with Random Selector**

In `main_random.py`, we configure the `task_selector` for random sampling.

```python
# In main_random.py
train_dataset = Dataset(
    path="path/to/your/augmented/math_data",
    split="train",
    task_selector={'selector_type': 'random'},
)

tune(
    workflow_func=run_react_agent,
    judge_func=gsm8k_judge,
    config_path="config_random.yaml",
    train_dataset=train_dataset,
    ...
)
```

**Experiment 2: Advanced Training with Difficulty-Based Selector**

In `main_difficulty.py`, we switch the selector to `difficulty_based` and provide the initial feature keys.

```python
# In main_difficulty.py
train_dataset = Dataset(
    path="path/to/your/augmented/math_data",
    split="train",
    task_selector={
        'selector_type': 'difficulty_based',
        'feature_keys': ["qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate"],
        'kwargs': {...},  # Hyper-parameters for the selection algorithm
    },
)

tune(
    workflow_func=run_react_agent,
    judge_func=gsm8k_judge,
    config_path="config_difficulty.yaml",
    train_dataset=train_dataset,
    ...
)
```

> The `difficulty_based` selector in this example is an implementation of the ***BOTS*** algorithm. For details on its inner workings, please refer to the [***BOTS paper***](https://arxiv.org/abs/2510.26374) and its [***tutorials***](https://github.com/modelscope/Trinity-RFT/blob/main/examples/bots/README.md).
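
As a caricature of the general idea only (not the BOTS algorithm itself; see the paper for the actual method): maintain a per-task estimate of the success probability and prefer tasks whose estimate sits near the target reward.

```python
# Toy target-driven task selection -- NOT the BOTS algorithm; consult
# arXiv:2510.26374 for the real method.
def select_tasks(pass_rates: dict[int, float], target: float, k: int) -> list[int]:
    """Pick the k task ids whose estimated pass rate is closest to `target`."""
    return sorted(pass_rates, key=lambda t: abs(pass_rates[t] - target))[:k]

estimates = {0: 0.05, 1: 0.75, 2: 0.95, 3: 0.60}
print(select_tasks(estimates, target=0.8, k=2))  # [1, 2]
```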

The `config_difficulty.yaml` must enable the `pass_rate_calculator` to provide real-time feedback.

```yaml
# Enable the calculator to provide feedback for the selector
data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator
```

## How to Run

### Step 1: Prerequisites

Ensure you have installed AgentScope and Trinity-RFT following [the guidance](../react_agent/README.md).
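
For instance, assuming both packages are published under their default pip names (the linked guide is authoritative if they differ):

```bash
# Assumed package names -- follow ../react_agent/README.md if these differ.
pip install agentscope trinity-rft
```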

### Step 2: Prepare the Dataset

Run the data preparation script. Make sure to update the dataset paths in `main_random.py` and `main_difficulty.py` afterward.

```bash
python prepare_data.py
```

### Step 3: Start Ray Cluster

For distributed training, start a Ray cluster.

```bash
# For single node
ray start --head
```
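
If the cluster spans multiple nodes, run `ray start --head` on the head node and join each worker to it, for example:

```bash
# On each worker node; replace <head-ip> with the head node's address
# (6379 is Ray's default port for this).
ray start --address=<head-ip>:6379
```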

### Step 4: Run Training

You can now run either the baseline or the difficulty-based training experiment.

- **To run the baseline experiment with a random selector:**

  ```bash
  python main_random.py
  ```

- **To run the experiment with the difficulty-based selector:**

  ```bash
  python main_difficulty.py
  ```

## Experimental Results

The following results compare the performance of the `difficulty_based` selection strategy (red line, "bots") against a standard `random` selection strategy (black line, "random").

![examples](https://img.alicdn.com/imgextra/i2/O1CN01I7evnM23mBtnSMo32_!!6000000007298-2-tps-3103-1060.png)

### Training Reward Curve

The chart on the left shows the rollout accuracy during training. The tasks sampled by the random strategy appear to be too difficult for the model, with accuracy remaining below 0.2. In contrast, the difficulty-based selector yields a higher mean accuracy, indicating that the agent engages with more tasks it can actually solve.

### Evaluation on AIME-24

For comparison, we evaluated both selection strategies on the AIME-24 benchmark. The chart on the right shows that the difficulty-based method demonstrates a better upward trend in performance over time.

---

**`config_difficulty.yaml`**

```yaml
project: "Data-Augmentation"  # Project name
name: "Difficulty-Based-Selector"  # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}  # Directory to save model checkpoints

data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator  # Calculate the average reward and pass it back to the selector

synchronizer:
  sync_style: dynamic_by_explorer  # Sync triggered dynamically by the explorer
  sync_method: 'nccl'
  sync_interval: 4  # Sync every N steps
  sync_timeout: 7200  # Timeout for synchronization (seconds)

monitor:
  monitor_type: tensorboard  # Can also use wandb, mlflow or swanlab

# The config below has been set in the Python file

algorithm:
  algorithm_type: multi_step_grpo  # GRPO series for multi-step scenarios
  repeat_times: 8  # Number of rollouts per prompt for advantage estimation
  optimizer:
    lr: 1e-6  # Learning rate

model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B}  # Base model path
  max_model_len: 24576  # Max context length
  max_response_tokens: 16384  # Max tokens per response
  temperature: 1.0  # Generation temperature

cluster:
  node_num: 1  # Number of nodes used
  gpu_per_node: 8  # Number of GPUs per node

buffer:
  total_epochs: 1  # Total training epochs
  batch_size: 16  # Batch size per explore step
  explorer_input:
    taskset:
      path: "path/to/your/augmented/math_data"  # Training data path
      split: "train"  # Training data split
      task_selector:
        selector_type: difficulty_based  # Task selection strategy
        feature_keys: ["qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate"]  # Pass-rate feature keys to use
        kwargs:  # Hyperparameters from BOTS: https://github.com/modelscope/Trinity-RFT/blob/main/examples/bots/README.md
          m: 8
          lamb: 0.1
          rho: 0.1
          target_reward: 0.8
          tau: 0
          do_sample: true
    eval_tasksets:
      path: "path/to/aime24_data"  # Evaluation data path
      split: "test"  # Evaluation data split

explorer:
  eval_interval: 20  # Evaluate every N steps
  runner_per_model: 16  # Runners per inference engine
  max_timeout: 1200  # Max timeout per rollout (seconds)
  rollout_model:
    engine_num: 4  # Number of vLLM engines for the rollout model
    tensor_parallel_size: 1  # TP size per engine for the rollout model
    enable_openai_api: true  # Enable OpenAI-compatible API
    enable_history: true  # Enable conversation history
    enable_auto_tool_choice: true  # Enable automatic tool selection
    tool_call_parser: hermes  # Parser for tool calls
    reasoning_parser: deepseek_r1  # Parser for reasoning content

trainer:
  save_interval: 100  # Save a checkpoint every N steps
  use_dynamic_bsz: true  # Use dynamic batch size
  ulysses_sequence_parallel_size: 1  # Sequence parallel size for Ulysses
```

---

**`config_random.yaml`**

```yaml
project: "Data-Augmentation"  # Project name
name: "Random-Selector"  # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}  # Directory to save model checkpoints

synchronizer:
  sync_style: dynamic_by_explorer  # Sync triggered dynamically by the explorer
  sync_method: 'nccl'
  sync_interval: 4  # Sync every N steps
  sync_timeout: 7200  # Timeout for synchronization (seconds)

monitor:
  monitor_type: tensorboard  # Can also use wandb, mlflow or swanlab

# The config below has been set in the Python file

algorithm:
  algorithm_type: multi_step_grpo  # GRPO series for multi-step scenarios
  repeat_times: 8  # Number of rollouts per prompt for advantage estimation
  optimizer:
    lr: 1e-6  # Learning rate

model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B}  # Base model path
  max_model_len: 24576  # Max context length
  max_response_tokens: 16384  # Max tokens per response
  temperature: 1.0  # Generation temperature

cluster:
  node_num: 1  # Number of nodes used
  gpu_per_node: 8  # Number of GPUs per node

buffer:
  total_epochs: 1  # Total training epochs
  batch_size: 16  # Batch size per explore step
  explorer_input:
    taskset:
      path: "path/to/your/augmented/math_data"  # Training data path
      split: "train"  # Training data split
      task_selector:
        selector_type: random  # Task selection strategy
    eval_tasksets:
      path: "path/to/aime24_data"  # Evaluation data path
      split: "test"  # Evaluation data split

explorer:
  eval_interval: 20  # Evaluate every N steps
  runner_per_model: 16  # Runners per inference engine
  max_timeout: 1200  # Max timeout per rollout (seconds)
  rollout_model:
    engine_num: 4  # Number of vLLM engines for the rollout model
    tensor_parallel_size: 1  # TP size per engine for the rollout model
    enable_openai_api: true  # Enable OpenAI-compatible API
    enable_history: true  # Enable conversation history
    enable_auto_tool_choice: true  # Enable automatic tool selection
    tool_call_parser: hermes  # Parser for tool calls
    reasoning_parser: deepseek_r1  # Parser for reasoning content

trainer:
  save_interval: 100  # Save a checkpoint every N steps
  use_dynamic_bsz: true  # Use dynamic batch size
  ulysses_sequence_parallel_size: 1  # Sequence parallel size for Ulysses
```