diff --git a/modeling/llm_post_training/README_VERL.md b/modeling/llm_post_training/README_VERL.md
new file mode 100644
index 0000000..a0123d8
--- /dev/null
+++ b/modeling/llm_post_training/README_VERL.md
@@ -0,0 +1,284 @@
+# VERL-based GRPO Training for Reasoning Tasks
+
+This directory contains a VERL (Versatile Reinforcement Learning) implementation of GRPO (Group Relative Policy Optimization) for training language models on reasoning tasks.
+
+## Overview
+
+The VERL implementation provides several advantages over the traditional TRL-based approach:
+
+- **Better Scalability**: VERL is designed for distributed training and can handle larger models more efficiently
+- **Improved Performance**: Optimized for multi-GPU and multi-node training scenarios
+- **Flexible Architecture**: Supports different backends (FSDP, Megatron-LM) and worker configurations
+- **Advanced Features**: Built-in support for fault tolerance, checkpointing, and resource management
+
+## Files
+
+- `reasoning_grpo_verl_clean.py`: Main VERL GRPO training script
+- `reasoning_grpo.py`: Original TRL-based implementation for comparison
+- `requirements_verl.txt`: VERL-specific dependencies
+- `README_VERL_GRPO.md`: This documentation
+
+## Key Differences from TRL Implementation
+
+### 1. Reward Function Architecture
+
+**TRL Version:**
+
+```python
+# Multiple separate reward functions
+reward_funcs = [
+ self.match_format_func,
+ self.penalize_short_think_func,
+ self.check_answer_func,
+]
+```
+
+**VERL Version:**
+
+```python
+# Single unified reward manager
+class ReasoningRewardManager:
+ def compute_reward(self, completions, ground_truth, **kwargs):
+ # Combines all reward components internally
+ format_score = self._compute_format_reward(completion)
+ thinking_score = self._compute_thinking_reward(completion)
+ answer_score = self._compute_answer_reward(completion, gt)
+ return format_score + thinking_score + answer_score
+```
+
+### 2. Configuration Management
+
+**TRL Version:**
+
+```python
+# Simple configuration object
+config = GRPOConfig(
+ output_dir=output_dir,
+ learning_rate=self.learning_rate,
+ # ... other parameters
+)
+```
+
+**VERL Version:**
+
+```python
+# Comprehensive configuration with nested structure
+config = {
+ "trainer": {"type": "grpo", "n_gpus_per_node": 1, ...},
+ "actor_rollout_ref": {
+ "actor": {"strategy": "fsdp", "model": {...}, ...},
+ "rollout": {"temperature": 1.0, "num_generations": 8, ...}
+ },
+ "critic": {"strategy": "fsdp", "model": {...}, ...},
+ "data": {"train_path": dataset_path, "batch_size": 4, ...},
+ "grpo": {"cliprange": 0.2, "gamma": 0.99, ...}
+}
+```
+
+### 3. Worker and Resource Management
+
+**TRL Version:**
+
+```python
+# Simple trainer initialization
+trainer = GRPOTrainer(
+ model=self.model_name,
+ reward_funcs=reward_funcs,
+ args=training_args,
+ train_dataset=self.dataset,
+ peft_config=lora_config,
+)
+```
+
+**VERL Version:**
+
+```python
+# Complex worker setup with resource pools
+role_worker_mapping = {
+ Role.ActorRollout: ActorRolloutRefWorker,
+ Role.Critic: CriticWorker,
+ Role.RefPolicy: ActorRolloutRefWorker
+}
+
+resource_pool_spec = {
+ 'global_pool': [config.trainer.n_gpus_per_node] * config.trainer.nnodes
+}
+
+trainer = RayGRPOTrainer(
+ config=config,
+ tokenizer=tokenizer,
+ role_worker_mapping=role_worker_mapping,
+ resource_pool_manager=resource_pool_manager,
+ ray_worker_group_cls=ray_worker_group_cls,
+ reward_fn=reward_fn,
+ val_reward_fn=val_reward_fn,
+)
+```
+
+## Installation
+
+1. Install VERL and dependencies:
+
+```bash
+pip install -r requirements_verl.txt
+```
+
+2. Install VERL framework:
+
+```bash
+pip install verl
+```
+
+## Usage
+
+### Basic Training
+
+```bash
+python reasoning_grpo_verl_clean.py
+```
+
+### With LoRA
+
+```bash
+python reasoning_grpo_verl_clean.py --use-lora
+```
+
+### Custom Configuration
+
+```bash
+python reasoning_grpo_verl_clean.py \
+ --model-size 1.5B \
+ --max-steps 1000 \
+ --batch-size 8 \
+ --learning-rate 2e-5
+```
+
+### With Hugging Face Token
+
+```bash
+python reasoning_grpo_verl_clean.py --hf-token your_token_here
+```
+
+## Configuration Options
+
+| Parameter | Description | Default |
+| ------------------------------- | --------------------------------------- | ------- |
+| `--model-size` | Model size ("0.5B", "1.5B", "3B", "4B") | "4B" |
+| `--use-lora` | Enable LoRA fine-tuning | False |
+| `--disable-wandb` | Disable wandb logging | False |
+| `--max-steps` | Maximum training steps | 500 |
+| `--batch-size` | Training batch size | 4 |
+| `--learning-rate` | Learning rate | 1e-5 |
+| `--gradient-accumulation-steps` | Gradient accumulation steps | 16 |
+| `--hf-token` | Hugging Face token | None |
+
+## Architecture Components
+
+### 1. ReasoningRewardManager
+
+- **Purpose**: Computes reward scores for reasoning completions
+- **Components**:
+ - Format compliance checking
+ - Thinking quality assessment
+ - Answer correctness evaluation
+- **Integration**: Compatible with VERL's reward system
+
+### 2. ReasoningGRPOVERLTrainer
+
+- **Purpose**: Main trainer class for VERL GRPO training
+- **Features**:
+ - Dataset preparation and preprocessing
+ - VERL configuration management
+ - Worker and resource setup
+ - Training orchestration
+
+### 3. VERL Configuration
+
+- **Trainer**: Defines training parameters and resource allocation
+- **Actor/Rollout**: Model configuration and generation parameters
+- **Critic**: Value function model setup
+- **Data**: Dataset paths and batch configuration
+- **GRPO**: Algorithm-specific hyperparameters
+
+## Performance Considerations
+
+### Memory Optimization
+
+- **FSDP Strategy**: Enables efficient memory usage across GPUs
+- **LoRA Support**: Reduces memory footprint for fine-tuning
+- **Gradient Accumulation**: Allows larger effective batch sizes
+
+### Scalability
+
+- **Multi-GPU Support**: Built-in support for distributed training
+- **Resource Pools**: Flexible GPU allocation and management
+- **Ray Integration**: Enables multi-node training scenarios
+
+## Debugging and Monitoring
+
+### Debug Logging
+
+- **Location**: `debug_logs/verl_grpo_debug_YYYYMMDD_HHMMSS.txt`
+- **Content**: Detailed reward breakdowns and sample completions
+- **Frequency**: Configurable via `num_examine` parameter
+
+### Wandb Integration
+
+- **Project**: `verl-reasoning-grpo`
+- **Metrics**: Training loss, reward scores, and model performance
+- **Tags**: `["reasoning", "grpo", "verl"]`
+
+## Troubleshooting
+
+### Common Issues
+
+1. **Import Errors**: Ensure VERL is properly installed
+
+ ```bash
+ pip install verl
+ ```
+
+2. **CUDA Memory Issues**: Reduce batch size or enable LoRA
+
+ ```bash
+ python reasoning_grpo_verl_clean.py --batch-size 2 --use-lora
+ ```
+
+3. **Dataset Loading Issues**: Check internet connection and HF token
+ ```bash
+ python reasoning_grpo_verl_clean.py --hf-token your_token_here
+ ```
+
+### Performance Tuning
+
+1. **Increase Batch Size**: For better GPU utilization
+2. **Enable LoRA**: For memory-efficient fine-tuning
+3. **Adjust Learning Rate**: Based on model size and dataset
+4. **Optimize Gradient Accumulation**: Balance memory and training speed
+
+## Comparison with TRL Implementation
+
+| Aspect | TRL Version | VERL Version |
+| ----------------------- | ------------------ | ---------------------- |
+| **Scalability** | Single GPU focused | Multi-GPU/Multi-node |
+| **Memory Usage** | Higher | Optimized with FSDP |
+| **Configuration** | Simple | Comprehensive |
+| **Worker Management** | Automatic | Explicit control |
+| **Resource Allocation** | Basic | Advanced pooling |
+| **Fault Tolerance** | Limited | Built-in support |
+| **Performance** | Good | Better for large scale |
+
+## Future Enhancements
+
+1. **Multi-Modal Support**: Extend to image reasoning tasks
+2. **Advanced Algorithms**: Implement other RL algorithms (PPO, DPO)
+3. **Custom Reward Models**: Support for learned reward functions
+4. **Hyperparameter Optimization**: Automated tuning capabilities
+5. **Model Serving**: Integration with inference servers
+
+## References
+
+- [VERL Documentation](https://verl.readthedocs.io/)
+- [GRPO Paper](https://arxiv.org/abs/2406.05930)
+- [VERL GitHub Repository](https://github.com/volcengine/verl)
+- [Original TRL Implementation](./reasoning_grpo.py)
diff --git a/modeling/llm_post_training/reasoning_grpo.py b/modeling/llm_post_training/reasoning_grpo.py
index bceb497..91311a2 100644
--- a/modeling/llm_post_training/reasoning_grpo.py
+++ b/modeling/llm_post_training/reasoning_grpo.py
@@ -85,7 +85,12 @@ def __init__(
self.step_counter = 0
# Setup workspace directories
- self.workspace_dir = os.environ.get("WORKSPACE_DIR", "/workspace")
+ self.workspace_dir = os.environ.get(
+ "WORKSPACE_DIR", os.path.expanduser("~/workspace")
+ )
+ # Ensure we use a writable directory
+ if not os.access(self.workspace_dir, os.W_OK):
+ self.workspace_dir = os.path.expanduser("~/workspace")
self.models_dir = os.path.join(self.workspace_dir, "models")
self.data_dir = os.path.join(self.workspace_dir, "data")
self.cache_dir = os.path.join(self.workspace_dir, "cache")
diff --git a/modeling/llm_post_training/reasoning_grpo_verl.py b/modeling/llm_post_training/reasoning_grpo_verl.py
new file mode 100644
index 0000000..e2f7a84
--- /dev/null
+++ b/modeling/llm_post_training/reasoning_grpo_verl.py
@@ -0,0 +1,766 @@
+"""
+VERL-based GRPO Training Script for Reasoning Tasks
+
+This script implements GRPO (Group Relative Policy Optimization) training using VERL
+for improving reasoning capabilities on the mini-reasoning-dataset.
+
+VERL (Versatile Reinforcement Learning) is a framework for RLHF training that
+provides better scalability and performance compared to traditional TRL
+implementations.
+
+Usage:
+ # Single GPU training
+ python reasoning_grpo_verl.py --model-size 3B --use-lora
+ python reasoning_grpo_verl.py --model-size 4B --gradient-accumulation-steps 16
+ python reasoning_grpo_verl.py --disable-wandb
+
+Arguments:
+ --model-size: Model size to use ("0.5B", "1.5B", "3B", "4B") [default: 4B]
+ --use-lora: Enable LoRA for efficient fine-tuning [default: False]
+ --disable-wandb: Disable wandb logging [default: False]
+ --max-steps: Maximum training steps [default: 500]
+ --batch-size: Training batch size [default: 4]
+ --learning-rate: Learning rate [default: 1e-5]
+ --hf-token: Hugging Face token for accessing gated repositories
+ [default: None]
+
+Examples:
+ # Basic training
+ python reasoning_grpo_verl.py
+ python reasoning_grpo_verl.py --use-lora
+
+ # Custom configuration
+ python reasoning_grpo_verl.py --model-size 1.5B --max-steps 1000 --batch-size 8
+
+ # With Hugging Face token
+ python reasoning_grpo_verl.py --hf-token your_token_here
+"""
+
+import re
+import random
+import os
+import argparse
+from datetime import datetime
+from typing import List, Optional
+
+import pandas as pd
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+# VERL imports
+from verl.config import Config
+from verl.trainer.grpo.ray_trainer import RayGRPOTrainer
+from verl.trainer.grpo.ray_trainer import ResourcePoolManager, Role
+
+
+class ReasoningRewardManager:
+ """VERL-compatible reward manager for reasoning tasks."""
+
+ def __init__(self, tokenizer: AutoTokenizer, num_examine: int = 0):
+ """
+ Initialize the reward manager.
+
+ Args:
+ tokenizer: Tokenizer for the model
+ num_examine: Number of examples to examine for debugging
+ """
+ self.tokenizer = tokenizer
+ self.num_examine = num_examine
+ self.step_counter = 0
+
+ # Tag constants
+ self.reasoning_start = ""
+ self.reasoning_end = ""
+ self.answer_start = ""
+ self.answer_end = ""
+
+ # Setup logging
+ self.log_dir = "debug_logs"
+ os.makedirs(self.log_dir, exist_ok=True)
+ self.log_file = os.path.join(
+ self.log_dir,
+ f"verl_grpo_debug_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt",
+ )
+
+ def compute_reward(
+ self, completions: List[str], ground_truth: List[str], **kwargs
+ ) -> List[float]:
+ """
+ Compute reward scores for reasoning completions.
+
+ Args:
+ completions: List of model completions
+ ground_truth: List of ground truth answers
+ **kwargs: Additional arguments
+
+ Returns:
+ List of reward scores
+ """
+ self.step_counter += 1
+ scores = []
+
+ # Debug logging for first completion
+ if completions and self.num_examine > 0:
+ self._log_debug_info(completions[0], ground_truth[0])
+
+ # Score each completion
+ for completion, gt in zip(completions, ground_truth):
+ # Calculate individual reward components
+ format_score = self._compute_format_reward(completion)
+ thinking_score = self._compute_thinking_reward(completion)
+ answer_score = self._compute_answer_reward(completion, gt)
+
+ # Combine scores
+ total_score = format_score + thinking_score + answer_score
+ scores.append(total_score)
+
+ return scores
+
+ def _compute_format_reward(self, completion: str) -> float:
+ """Compute format compliance reward."""
+ # Regex for matching the format: contentcontent
+ match_format = re.compile(
+ rf"^[\s]{{0,}}"
+ rf"{self.reasoning_start}.*?{self.reasoning_end}"
+ rf"{self.answer_start}.*?{self.answer_end}"
+ rf"[\s]{{0,}}$",
+ flags=re.MULTILINE | re.DOTALL,
+ )
+
+ penalty = 0
+
+ # Format compliance checking
+ if match_format.search(completion) is not None:
+ # Format is perfect - no penalty
+ return penalty
+
+ # Missing or incorrect tags
+ penalty -= 1.0 if completion.count(self.reasoning_start) != 1 else 0
+ penalty -= 1.0 if completion.count(self.reasoning_end) != 1 else 0
+ penalty -= 1.0 if completion.count(self.answer_start) != 1 else 0
+ penalty -= 1.0 if completion.count(self.answer_end) != 1 else 0
+
+ # Content structure penalties
+ penalty += self._check_content_structure(completion)
+
+ return penalty
+
+ def _check_content_structure(self, completion: str) -> float:
+ """Check content structure and return penalty score."""
+ penalty = 0
+
+ # Unwrapped content (content not in tags)
+ content_without_tags = re.sub(
+ rf"{self.reasoning_start}.*?{self.reasoning_end}",
+ "",
+ completion,
+ flags=re.DOTALL,
+ )
+ content_without_tags = re.sub(
+ rf"{self.answer_start}.*?{self.answer_end}",
+ "",
+ content_without_tags,
+ flags=re.DOTALL,
+ )
+ content_without_tags = content_without_tags.strip()
+
+ if content_without_tags:
+ penalty -= 5.0 # Penalty for unwrapped content
+
+ # Wrong order (answer before thinking)
+ think_pos = completion.find(self.reasoning_start)
+ answer_pos = completion.find(self.answer_start)
+
+ if think_pos != -1 and answer_pos != -1:
+ if answer_pos < think_pos: # Answer comes before thinking
+ penalty -= 1.0
+
+ # Multiple sections (should be exactly one of each)
+ think_count = completion.count(self.reasoning_start)
+ answer_count = completion.count(self.answer_start)
+
+ if think_count > 1:
+ penalty -= 2.0
+ if answer_count > 1:
+ penalty -= 2.0
+
+ return penalty
+
+ def _compute_thinking_reward(self, completion: str) -> float:
+ """Compute thinking quality reward."""
+ # Extract thinking content
+ think_match = re.search(
+ rf"{self.reasoning_start}(.+?){self.reasoning_end}",
+ completion,
+ flags=re.DOTALL,
+ )
+
+ if think_match:
+ think_content = think_match.group(1).strip()
+ else:
+ think_content = completion
+
+ content_length = len(think_content)
+
+ # Gradual penalty for short thinking (under 200 chars)
+ if content_length < 200:
+ penalty_ratio = (200 - content_length) / 200
+ # Gradual penalty from 0 to -10.0
+ return -10.0 * penalty_ratio
+
+ return 0
+
+ def _compute_answer_reward(self, completion: str, ground_truth: str) -> float:
+ """Compute answer correctness reward."""
+ # Extract answer from completion
+ answer_match = re.search(
+ rf"{self.answer_start}\s*(.+?)\s*{self.answer_end}",
+ completion,
+ flags=re.DOTALL,
+ )
+
+ if answer_match is None:
+ # No answer tags found - treat as wrong answer
+ return -1.0
+
+ answer = answer_match.group(1).strip()
+
+ # Exact match gets full score
+ if answer.lower() == ground_truth.lower():
+ return 8.0
+ # Partial match if answer contains ground truth
+ elif ground_truth.lower() in answer.lower():
+ return 3.0
+ else:
+ return -1.0 # Penalty for wrong answers
+
+ def _log_debug_info(self, completion: str, ground_truth: str) -> None:
+ """Log debug information for monitoring training progress."""
+ # Always print when there's a full score, occasionally print other cases
+ should_print = False
+ print_reason = ""
+
+ answer_match = re.search(
+ rf"{self.answer_start}\s*(.+?)\s*{self.answer_end}",
+ completion,
+ flags=re.DOTALL,
+ )
+
+ if answer_match:
+ extracted_answer = answer_match.group(1).strip()
+ if extracted_answer.lower() == ground_truth.lower():
+ should_print = True
+ print_reason = "🎯 FULL SCORE (8.0) - Exact match!"
+ elif random.random() < 0.1: # 10% chance for other cases
+ should_print = True
+ if ground_truth.lower() in extracted_answer.lower():
+ print_reason = "✅ PARTIAL SCORE (3.0) - Contains ground truth"
+ else:
+ print_reason = "❌ WRONG ANSWER (-1.0) - No match"
+ elif random.random() < 0.1: # 10% chance for no tags case
+ should_print = True
+ print_reason = "❌ No answer tags found (-1.0 penalty)"
+
+ if should_print:
+ self._write_debug_output(completion, ground_truth, print_reason)
+
+ def _write_debug_output(
+ self, completion: str, ground_truth: str, print_reason: str
+ ) -> None:
+ """Write debug output to console and file."""
+ # Calculate individual function scores for debugging
+ format_reward = self._compute_format_reward(completion)
+ think_reward = self._compute_thinking_reward(completion)
+ answer_reward = self._compute_answer_reward(completion, ground_truth)
+ total_reward = format_reward + think_reward + answer_reward
+
+ # Prepare debug output
+ debug_output = self._format_debug_output(
+ completion,
+ ground_truth,
+ print_reason,
+ format_reward,
+ think_reward,
+ answer_reward,
+ total_reward,
+ )
+
+ # Print to console
+ for line in debug_output:
+ print(line)
+
+ # Write to file
+ with open(self.log_file, "a", encoding="utf-8") as f:
+ f.write("\n".join(debug_output))
+ f.write("\n")
+
+ def _format_debug_output(
+ self,
+ completion: str,
+ ground_truth: str,
+ print_reason: str,
+ format_reward: float,
+ think_reward: float,
+ answer_reward: float,
+ total_reward: float,
+ ) -> List[str]:
+ """Format debug output for logging."""
+ debug_output = []
+ debug_output.append("\n" + "=" * 60)
+ debug_output.append(
+ f"VERL GRPO SPOT CHECK: PROMPT AND COMPLETIONS "
+ f"(Step: {self.step_counter})"
+ )
+ debug_output.append("=" * 60)
+ debug_output.append(f"==Completion:==\n {completion}\n")
+ debug_output.append(f"==Ground Truth:==\n {ground_truth}")
+
+ # Extract answer for display
+ answer_match = re.search(
+ rf"{self.answer_start}\s*(.+?)\s*{self.answer_end}",
+ completion,
+ flags=re.DOTALL,
+ )
+ if answer_match:
+ extracted_answer = answer_match.group(1).strip()
+ debug_output.append(f"==Extracted Answer: '{extracted_answer}'")
+
+ debug_output.append(print_reason)
+ debug_output.append("==SCORE BREAKDOWN==")
+ debug_output.append(f" Format reward: {format_reward}")
+ debug_output.append(f" Think reward: {think_reward}")
+ debug_output.append(f" Answer reward: {answer_reward}")
+ debug_output.append(f" TOTAL REWARD: {total_reward}")
+ debug_output.append("=" * 60)
+
+ return debug_output
+
+
+class ReasoningGRPOVERLTrainer:
+ """VERL-based GRPO trainer for reasoning tasks."""
+
+ def __init__(
+ self,
+ model_size: str = "4B",
+ use_lora: bool = False,
+ wandb_enabled: bool = True,
+ max_steps: int = 500,
+ batch_size: int = 4,
+ learning_rate: float = 1e-5,
+ gradient_accumulation_steps: int = 16,
+ hf_token: Optional[str] = None,
+ ):
+ """
+ Initialize the VERL GRPO trainer.
+
+ Args:
+ model_size: Size of the model ("0.5B", "1.5B", "3B", "4B")
+ use_lora: Whether to use LoRA for efficient fine-tuning
+ wandb_enabled: Whether to enable wandb logging
+ max_steps: Maximum training steps
+ batch_size: Training batch size
+ learning_rate: Learning rate for training
+ gradient_accumulation_steps: Number of gradient accumulation steps
+ hf_token: Hugging Face token for accessing gated repositories
+ """
+ self.model_size = model_size
+ self.use_lora = use_lora
+ self.wandb_enabled = wandb_enabled
+ self.max_steps = max_steps
+ self.batch_size = batch_size
+ self.learning_rate = learning_rate
+ self.gradient_accumulation_steps = gradient_accumulation_steps
+ self.hf_token = hf_token
+ self.model_name = self._get_model_name()
+
+ # Setup workspace directories
+ self.workspace_dir = os.environ.get(
+ "WORKSPACE_DIR", os.path.expanduser("~/workspace")
+ )
+ if not os.access(self.workspace_dir, os.W_OK):
+ self.workspace_dir = os.path.expanduser("~/workspace")
+
+ self.models_dir = os.path.join(self.workspace_dir, "models")
+ self.data_dir = os.path.join(self.workspace_dir, "data")
+ self.cache_dir = os.path.join(self.workspace_dir, "cache")
+
+ # Create directories if they don't exist
+ os.makedirs(self.models_dir, exist_ok=True)
+ os.makedirs(self.data_dir, exist_ok=True)
+ os.makedirs(self.cache_dir, exist_ok=True)
+
+ # Set environment variables for HuggingFace
+ os.environ["HF_HOME"] = self.cache_dir
+ os.environ["HF_HUB_CACHE"] = self.models_dir
+ os.environ["HF_DATASETS_CACHE"] = self.data_dir
+
+ # Set Hugging Face token if provided
+ if self.hf_token:
+ os.environ["HUGGINGFACE_HUB_TOKEN"] = self.hf_token
+
+ def _get_model_name(self) -> str:
+ """Get the model name based on size."""
+ model_mapping = {
+ "4B": "Qwen/Qwen3-4B-Instruct-2507",
+ "3B": "meta-llama/Llama-3.2-3B-Instruct",
+ "1.5B": "Qwen/Qwen2-1.5B-Instruct",
+ "0.5B": "Qwen/Qwen2-0.5B-Instruct",
+ }
+
+ if self.model_size not in model_mapping:
+ raise ValueError(f"Invalid model size: {self.model_size}")
+
+ return model_mapping[self.model_size]
+
+ def _create_reasoning_prompt(self, question: str) -> str:
+ """Create the reasoning prompt template."""
+ return f"""
+ The following question requires reasoning.
+ In addition to provide your answer, you should also provide your
+ DETAILED thought process about how you arrive at your answer.
+ Put your thought process between tags and then put
+ your answer between tags.
+
+ The question is:
+ {question}
+ """
+
+ def prepare_dataset_for_verl(self) -> str:
+ """
+ Prepare dataset in VERL-compatible format and save as parquet.
+
+ Returns:
+ Path to the prepared parquet file
+ """
+ # Load dataset from HuggingFace
+ dataset = load_dataset("tech-tao/mini-reasoning-dataset", split="train")
+
+ # Transform dataset with reasoning prompt template
+ processed_data = []
+ for item in dataset:
+ processed_item = {
+ "prompt": self._create_reasoning_prompt(item["prompt"]),
+ "ground_truth": item["completion"],
+ }
+ processed_data.append(processed_item)
+
+ # Convert to DataFrame and save as parquet
+ df = pd.DataFrame(processed_data)
+ parquet_path = os.path.join(self.data_dir, "reasoning_dataset.parquet")
+ df.to_parquet(parquet_path, index=False)
+
+ print(f"📊 Dataset prepared and saved to: {parquet_path}")
+ print(f" Total samples: {len(processed_data)}")
+
+ return parquet_path
+
+ def create_verl_config(self) -> Config:
+ """Create VERL configuration for GRPO training."""
+ # Create output directory
+ model_name_short = self.model_name.split("/")[-1]
+ lora_suffix = "LoRA" if self.use_lora else "Full"
+ model_output_name = f"{model_name_short}-{lora_suffix}-VERL-GRPO"
+ output_dir = os.path.join(self.models_dir, model_output_name)
+
+ # Prepare dataset
+ dataset_path = self.prepare_dataset_for_verl()
+
+ # Create VERL configuration
+ config = {
+ "trainer": {
+ "type": "grpo",
+ "total_epochs": 1,
+ "save_interval": 100,
+ "logging_interval": 1,
+ "eval_interval": 50,
+ "n_gpus_per_node": 1,
+ "nnodes": 1,
+ },
+ "actor_rollout_ref": {
+ "actor": {
+ "strategy": "fsdp",
+ "model": {
+ "type": "causal_lm",
+ "model_name": self.model_name,
+ "use_lora": self.use_lora,
+ "lora_config": (
+ {
+ "r": 16,
+ "lora_alpha": 32,
+ "target_modules": [
+ "q_proj",
+ "v_proj",
+ "k_proj",
+ "o_proj",
+ ],
+ "lora_dropout": 0.01,
+ "bias": "none",
+ }
+ if self.use_lora
+ else None
+ ),
+ },
+ "optimizer": {
+ "lr": self.learning_rate,
+ "eps": 1e-5,
+ "weight_decay": 0.01,
+ },
+ "lr_scheduler": {
+ "type": "cosine",
+ "warmup_ratio": 0.1,
+ },
+ },
+ "rollout": {
+ "temperature": 1.0,
+ "top_p": 0.9,
+ "max_new_tokens": 512,
+ "num_generations": 8,
+ },
+ },
+ "critic": {
+ "strategy": "fsdp",
+ "model": {
+ "type": "causal_lm",
+ "model_name": self.model_name,
+ "use_lora": self.use_lora,
+ "lora_config": (
+ {
+ "r": 16,
+ "lora_alpha": 32,
+ "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
+ "lora_dropout": 0.01,
+ "bias": "none",
+ }
+ if self.use_lora
+ else None
+ ),
+ },
+ "optimizer": {
+ "lr": self.learning_rate,
+ "eps": 1e-5,
+ "weight_decay": 0.01,
+ },
+ "lr_scheduler": {
+ "type": "cosine",
+ "warmup_ratio": 0.1,
+ },
+ },
+ "data": {
+ "train_path": dataset_path,
+ "val_path": dataset_path, # Use same dataset for validation
+ "max_prompt_length": 768,
+ "max_response_length": 512,
+ "batch_size": self.batch_size,
+ "gradient_accumulation_steps": self.gradient_accumulation_steps,
+ },
+ "grpo": {
+ "cliprange": 0.2,
+ "cliprange_value": 0.2,
+ "gamma": 0.99,
+ "lam": 0.95,
+ "vf_coef": 0.1,
+ "ent_coef": 0.01,
+ "max_grad_norm": 1.0,
+ },
+ "output_dir": output_dir,
+ "seed": 42,
+ "fp16": True,
+ "bf16": False,
+ }
+
+ # Add wandb configuration if enabled
+ if self.wandb_enabled:
+ config["wandb"] = {
+ "project": "verl-reasoning-grpo",
+ "name": f"{model_name_short}-{lora_suffix}-VERL-GRPO",
+ "tags": ["reasoning", "grpo", "verl"],
+ }
+
+ return Config(config)
+
+ def setup_workers_and_resources(self, config: Config):
+ """Setup VERL workers and resource management."""
+ # Import worker classes based on strategy
+ if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
+ assert config.critic.strategy in {"fsdp", "fsdp2"}
+ from verl.workers.fsdp_workers import ActorRolloutRefWorker, CriticWorker
+ from verl.single_controller.ray import RayWorkerGroup
+
+ ray_worker_group_cls = RayWorkerGroup
+ else:
+ raise NotImplementedError(
+ f"Strategy {config.actor_rollout_ref.actor.strategy} not supported"
+ )
+
+ # Define role-worker mapping
+ role_worker_mapping = {
+ Role.ActorRollout: ActorRolloutRefWorker,
+ Role.Critic: CriticWorker,
+ Role.RefPolicy: ActorRolloutRefWorker,
+ }
+
+ # Define resource pools
+ global_pool_id = "global_pool"
+ resource_pool_spec = {
+ global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
+ }
+
+ mapping = {
+ Role.ActorRollout: global_pool_id,
+ Role.Critic: global_pool_id,
+ Role.RefPolicy: global_pool_id,
+ }
+
+ resource_pool_manager = ResourcePoolManager(
+ resource_pool_spec=resource_pool_spec, mapping=mapping
+ )
+
+ return role_worker_mapping, resource_pool_manager, ray_worker_group_cls
+
+ def print_directory_info(self) -> None:
+ """Print information about workspace directories."""
+ print("📁 VERL Workspace Configuration:")
+ print(f" Workspace Directory: {self.workspace_dir}")
+ print(f" Models Directory: {self.models_dir}")
+ print(f" Data Directory: {self.data_dir}")
+ print(f" Cache Directory: {self.cache_dir}")
+ print(f" Model: {self.model_name}")
+ print("-" * 50)
+
+ def train(self) -> None:
+ """Execute the VERL GRPO training process."""
+ # Print directory information
+ self.print_directory_info()
+
+ # Create VERL configuration
+ config = self.create_verl_config()
+
+ # Setup workers and resources
+ (role_worker_mapping, resource_pool_manager, ray_worker_group_cls) = (
+ self.setup_workers_and_resources(config)
+ )
+
+ # Initialize tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Create reward manager
+ reward_fn = ReasoningRewardManager(tokenizer=tokenizer, num_examine=1)
+ val_reward_fn = ReasoningRewardManager(tokenizer=tokenizer, num_examine=0)
+
+ # Initialize VERL GRPO trainer
+ trainer = RayGRPOTrainer(
+ config=config,
+ tokenizer=tokenizer,
+ role_worker_mapping=role_worker_mapping,
+ resource_pool_manager=resource_pool_manager,
+ ray_worker_group_cls=ray_worker_group_cls,
+ reward_fn=reward_fn,
+ val_reward_fn=val_reward_fn,
+ )
+
+ # Initialize workers and start training
+ print("🚀 Initializing VERL workers...")
+ trainer.init_workers()
+
+ print("🎯 Starting VERL GRPO training...")
+ trainer.fit()
+
+
+def parse_arguments() -> argparse.Namespace:
+ """Parse command-line arguments."""
+ parser = argparse.ArgumentParser(
+ description="VERL GRPO Training Script for Reasoning Tasks",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ parser.add_argument(
+ "--model-size",
+ type=str,
+ choices=["0.5B", "1.5B", "3B", "4B"],
+ default="4B",
+ help="Model size to use for training",
+ )
+
+ parser.add_argument(
+ "--use-lora",
+ action="store_true",
+ help="Enable LoRA for efficient fine-tuning",
+ )
+
+ parser.add_argument(
+ "--disable-wandb",
+ action="store_true",
+ help="Disable wandb logging",
+ )
+
+ parser.add_argument(
+ "--max-steps",
+ type=int,
+ default=500,
+ help="Maximum training steps",
+ )
+
+ parser.add_argument("--batch-size", type=int, default=4, help="Training batch size")
+
+ parser.add_argument(
+ "--learning-rate",
+ type=float,
+ default=1e-5,
+ help="Learning rate for training",
+ )
+
+ parser.add_argument(
+ "--gradient-accumulation-steps",
+ type=int,
+ default=16,
+ help="Number of gradient accumulation steps",
+ )
+
+ parser.add_argument(
+ "--hf-token",
+ type=str,
+ default=None,
+ help="Hugging Face token for accessing gated repositories",
+ )
+
+ return parser.parse_args()
+
+
+def main():
+ """Main entry point for the VERL GRPO training script."""
+ # Parse command-line arguments
+ args = parse_arguments()
+
+ print("🚀 Starting VERL GRPO training with:")
+ print(f" Model: {args.model_size}")
+ lora_status = "Enabled" if args.use_lora else "Disabled"
+ wandb_status = "Disabled" if args.disable_wandb else "Enabled"
+ hf_token_status = "Provided" if args.hf_token else "Not provided"
+ print(f" LoRA: {lora_status}")
+ print(f" Wandb: {wandb_status}")
+ print(f" HF Token: {hf_token_status}")
+ print(f" Max Steps: {args.max_steps}")
+ print(f" Batch Size: {args.batch_size}")
+ print(f" Learning Rate: {args.learning_rate}")
+ print(f" Gradient Accumulation Steps: {args.gradient_accumulation_steps}")
+ print("-" * 50)
+
+ # Create and run trainer
+ trainer = ReasoningGRPOVERLTrainer(
+ model_size=args.model_size,
+ use_lora=args.use_lora,
+ wandb_enabled=not args.disable_wandb,
+ max_steps=args.max_steps,
+ batch_size=args.batch_size,
+ learning_rate=args.learning_rate,
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ hf_token=args.hf_token,
+ )
+
+ trainer.train()
+
+
+if __name__ == "__main__":
+ main()