2 changes: 1 addition & 1 deletion docs/docs/tutorials/rl_multihop/index.ipynb
@@ -260,7 +260,7 @@
" num_threads=16,\n",
" use_train_as_val=False,\n",
" num_steps_for_val=10,\n",
" train_kwargs=config.to_dict(),#for now maintain backward compatability can change this in later changes\n",
" config=config,#now using GRPOConfig directly\n",
" report_train_scores=False,\n",
" gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),\n",
")\n",
2 changes: 1 addition & 1 deletion docs/docs/tutorials/rl_papillon/index.ipynb
@@ -290,7 +290,7 @@
" num_threads=24,\n",
" use_train_as_val=False,\n",
" num_steps_for_val=10,\n",
" train_kwargs=config.to_dict(),#backwards compatability for now\n",
" config=config,#now using GRPOConfig directly\n",
" report_train_scores=False,\n",
" gpu_config=MultiGPUConfig(num_inference_gpus=2, num_training_gpus=2),\n",
")\n",
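Both tutorial cells now hand the GRPOConfig object straight to the optimizer instead of flattening it with .to_dict(). A minimal sketch of the updated cell, assuming the optimizer class is GRPO from this PR's grpo.py; my_metric and the config values are placeholders for what each notebook defines earlier:

from dspy.clients.utils_finetune import MultiGPUConfig
from dspy.teleprompt.grpo.grpo import GRPO
from dspy.teleprompt.grpo.grpo_config import GRPOConfig

def my_metric(example, prediction, trace=None):  # stand-in for the notebook's metric
    return 1.0

config = GRPOConfig(num_generations=8)  # illustrative; the notebooks build their own config

compiler = GRPO(
    metric=my_metric,
    num_threads=24,
    use_train_as_val=False,
    num_steps_for_val=10,
    config=config,  # GRPOConfig passed directly, no .to_dict()
    report_train_scores=False,
    gpu_config=MultiGPUConfig(num_inference_gpus=2, num_training_gpus=2),
)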
6 changes: 2 additions & 4 deletions dspy/clients/lm.py
@@ -227,16 +227,14 @@ def thread_function_wrapper():

return job

- def reinforce(
- self, train_kwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)
- ) -> ReinforceJob:
+ def reinforce(self, config, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)) -> ReinforceJob:
# TODO(GRPO Team): Should we return an initialized job here?
from dspy import settings as settings

err = f"Provider {self.provider} does not implement the reinforcement learning interface."
assert self.provider.reinforceable, err

- job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs, gpu_config=gpu_config)
+ job = self.provider.ReinforceJob(lm=self, config=config, gpu_config=gpu_config)
job.initialize()
return job

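For reference, a sketch of the renamed entry point from a caller's perspective; the model id and config values below are placeholders, not part of this PR:

import dspy
from dspy.clients.utils_finetune import MultiGPUConfig
from dspy.teleprompt.grpo.grpo_config import GRPOConfig

lm = dspy.LM("openai/arbor:example-model")  # placeholder model id
config = GRPOConfig(num_generations=4)  # illustrative settings

# reinforce() now forwards the GRPOConfig itself: it asserts the provider
# is reinforceable, constructs the provider's ReinforceJob with config=
# (previously train_kwargs=), and returns the job after initialize().
job = lm.reinforce(
    config=config,
    gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
)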
10 changes: 5 additions & 5 deletions dspy/clients/lm_local_arbor.py
@@ -44,13 +44,13 @@ def status(self) -> TrainingStatus:


class ArborReinforceJob(ReinforceJob):
- def __init__(self, lm: "LM", train_kwargs: GRPOConfig, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
+ def __init__(self, lm: "LM", config: GRPOConfig, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
# The teleprompter must ensure that this is set
- if not isinstance(train_kwargs, GRPOConfig):
- raise TypeError(f"Expected train_kwargs to be of type GRPOConfig, but got {type(train_kwargs)}")
+ if not isinstance(config, GRPOConfig):
+ raise TypeError(f"Expected config to be of type GRPOConfig, but got {type(config)}")

self.lm = lm
- self.train_kwargs: GRPOConfig = train_kwargs
+ self.config: GRPOConfig = config
self.provider_job_id = None
self.checkpoints = {}
self.last_checkpoint = None
@@ -65,7 +65,7 @@ def initialize(self):
gpu_config_type = "multi"

# Create data payload from GRPOConfig
- data = self.train_kwargs.to_dict()
+ data = self.config.to_dict()
data["model"] = finetune_model
data["gpu_config"] = {
"type": gpu_config_type,
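The net effect in ArborReinforceJob: a raw dict is rejected up front, and initialize() builds its request payload from config.to_dict(). A short sketch, where some_lm stands in for a real dspy.LM:

from dspy.clients.lm_local_arbor import ArborReinforceJob
from dspy.clients.utils_finetune import MultiGPUConfig
from dspy.teleprompt.grpo.grpo_config import GRPOConfig

gpu = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)

# Passing a raw dict now fails the isinstance check in __init__:
try:
    ArborReinforceJob(lm=some_lm, config={"num_generations": 2}, gpu_config=gpu)
except TypeError as err:
    print(err)  # Expected config to be of type GRPOConfig, but got <class 'dict'>

# The supported path:
job = ArborReinforceJob(lm=some_lm, config=GRPOConfig(num_generations=2), gpu_config=gpu)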
4 changes: 2 additions & 2 deletions dspy/clients/provider.py
@@ -36,9 +36,9 @@ def status(self):


class ReinforceJob:
- def __init__(self, lm: "LM", train_kwargs: dict[str, Any] | None = None, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
+ def __init__(self, lm: "LM", config: dict[str, Any] | None = None, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
self.lm = lm
- self.train_kwargs = train_kwargs or {}
+ self.config = config or {}
self.gpu_config = gpu_config
self.checkpoints = {}
self.last_checkpoint = None
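The base class keeps a permissive signature (config may be a dict or None, defaulting to {}), so provider implementations read self.config instead of self.train_kwargs. A hypothetical subclass, shown only to illustrate where the attribute lives:

from dspy.clients.provider import ReinforceJob

class MyReinforceJob(ReinforceJob):  # hypothetical provider job
    def initialize(self):
        # self.train_kwargs is gone; the settings live on self.config
        print("initializing with:", self.config)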
57 changes: 47 additions & 10 deletions dspy/teleprompt/grpo/grpo.py
@@ -1,12 +1,13 @@
import logging
import random
- from collections import Counter
+ from collections import Counter, defaultdict
from typing import Any, Callable, Literal

from dspy.adapters.base import Adapter
from dspy.adapters.chat_adapter import ChatAdapter
from dspy.clients.lm import LM
from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat
+ from dspy.teleprompt.grpo.grpo_config import GRPOConfig
from dspy.dsp.utils.settings import settings
from dspy.evaluate.evaluate import Evaluate
from dspy.primitives.example import Example
@@ -26,7 +27,7 @@ def __init__(
self,
metric: Callable | None = None,
multitask: bool = True,
- train_kwargs: dict[str, Any] | dict[LM, dict[str, Any]] | None = None,
+ config: GRPOConfig | dict[LM, GRPOConfig] | None = None,
adapter: Adapter | dict[LM, Adapter] | None = None,
exclude_demos: bool = False,
num_threads: int = 6,
@@ -43,7 +44,8 @@ def __init__(
variably_invoked_predictor_fill_strategy: Literal["randint"] | Literal["max"] | None = None,
gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
):
- super().__init__(train_kwargs=train_kwargs)
+ # Store GRPOConfig objects for internal use
+ self.grpo_configs: dict[LM, GRPOConfig] = self._convert_to_grpo_config_dict(config)
self.metric = metric
self.multitask = multitask
self.adapter: dict[LM, Adapter] = self.convert_to_lm_dict(adapter)
@@ -79,6 +81,43 @@ def __init__(
self.shuffled_trainset_ids = []
self.epoch = -1
self.id_freqs = Counter()

+ def _convert_to_grpo_config_dict(self, config) -> dict[LM, GRPOConfig]:
+ """
+ Convert config to GRPOConfig dict format for internal use.
+
+ For consistency, this returns a defaultdict so that a config can be looked up for any LM.
+ """
+ if config is None:
+ return defaultdict(lambda: GRPOConfig(num_generations=1))  # Default config
+
+ if isinstance(config, GRPOConfig):
+ return defaultdict(lambda: config)
+
+ if isinstance(config, dict):
+ # Check if it's a defaultdict first
+ if hasattr(config, "default_factory"):
+ # It's a defaultdict; validate what the default factory returns
+ default_value = config.default_factory()
+ if isinstance(default_value, GRPOConfig):
+ return config
+ else:
+ raise ValueError(f"defaultdict default factory must return GRPOConfig, got {type(default_value)}")
+
+ # Regular dict - check if all keys are LMs
+ if config and all(isinstance(k, LM) for k in config.keys()):
+ # LM dict format
+ result = {}
+ for lm, config_obj in config.items():
+ if isinstance(config_obj, GRPOConfig):
+ result[lm] = config_obj
+ else:
+ raise ValueError(f"All values in LM dict must be GRPOConfig instances, got {type(config_obj)}")
+ return result
+ else:
+ raise ValueError(f"Dictionary config must have LM keys, got keys of type {type(next(iter(config.keys()))) if config else 'empty dict'}")
+
+ raise ValueError(f"config must be GRPOConfig, dict[LM, GRPOConfig], or None, got {type(config)}")

def validate_trace_data_and_log_issues(
self,
@@ -317,12 +356,10 @@ def compile(
for t in teachers:
disable_lm_cache(program=t, lm_cache_dict=lm_cache_dict)

- # Update train_kwargs
+ # Update config - now using GRPOConfig objects
for pred in student.predictors():
- train_kwargs = self.train_kwargs[pred.lm]
- train_kwargs = {} if train_kwargs is None else train_kwargs
- train_kwargs["num_generations"] = self.num_rollouts_per_grpo_step
- self.train_kwargs[pred.lm] = train_kwargs
+ config = self.grpo_configs[pred.lm]
+ config.num_generations = self.num_rollouts_per_grpo_step

# We need to have a separate job for each unique LM x the data
# collection strategy. This properly handles all combinations of
@@ -333,8 +370,8 @@
data_key = None if self.multitask else pred_ind
job_key = (pred.lm, data_key)
if job_key not in grpo_training_jobs:
- train_kwargs = self.train_kwargs[pred.lm]
- job = pred.lm.reinforce(train_kwargs=train_kwargs, gpu_config=self.gpu_config)
+ config = self.grpo_configs[pred.lm]
+ job = pred.lm.reinforce(config=config, gpu_config=self.gpu_config)
grpo_training_jobs[job_key] = job

self.report_validation_metrics(
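_convert_to_grpo_config_dict normalizes every accepted form into a per-LM mapping. A sketch of the three valid inputs, assuming the optimizer class is GRPO and leaving the remaining constructor arguments at their defaults; the LM and config values are placeholders:

import dspy
from dspy.teleprompt.grpo.grpo import GRPO
from dspy.teleprompt.grpo.grpo_config import GRPOConfig

lm = dspy.LM("openai/arbor:example-model")  # placeholder

# 1. None: every LM falls back to GRPOConfig(num_generations=1).
grpo = GRPO(config=None)

# 2. A single GRPOConfig: shared across all LMs via a defaultdict.
grpo = GRPO(config=GRPOConfig(num_generations=4))

# 3. An explicit per-LM mapping; a defaultdict whose factory returns a
#    GRPOConfig also passes validation, and any non-GRPOConfig value raises ValueError.
grpo = GRPO(config={lm: GRPOConfig(num_generations=4)})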
16 changes: 13 additions & 3 deletions tests/clients/test_lm_local_arbor.py
@@ -94,15 +94,15 @@ def test_init_with_grpo_config(self, mock_lm, grpo_config, gpu_config):
job = ArborReinforceJob(mock_lm, grpo_config, gpu_config)

assert job.lm == mock_lm
- assert isinstance(job.train_kwargs, GRPOConfig)
+ assert isinstance(job.config, GRPOConfig)

assert job.provider_job_id is None
assert job.checkpoints == {}
assert job.last_checkpoint is None

def test_init_with_invalid_train_kwargs(self, mock_lm, gpu_config):
"""Test that init raises TypeError with non-GRPOConfig"""
- with pytest.raises(TypeError, match="Expected train_kwargs to be of type GRPOConfig"):
+ with pytest.raises(TypeError, match="Expected config to be of type GRPOConfig"):
ArborReinforceJob(mock_lm, {"invalid": "dict"}, gpu_config)

@patch('requests.post')
@@ -307,8 +307,12 @@ def test_is_terminal_training_status(self):

@patch('dspy.clients.lm_local_arbor.openai.fine_tuning.jobs.retrieve')
@patch.object(ArborProvider, '_get_arbor_base_api', return_value="http://localhost:8000/v1/")
- def test_get_training_status(self, mock_api, mock_retrieve):
+ @patch.object(ArborProvider, 'does_job_exist', return_value=True)
+ def test_get_training_status(self, mock_does_job_exist, mock_api, mock_retrieve):
"""Test getting training status."""
+ # Reset the mock to ensure clean state
+ mock_retrieve.reset_mock()
+
mock_job = Mock()
mock_job.status = "running"
mock_retrieve.return_value = mock_job
@@ -317,6 +321,7 @@ def test_get_training_status(self, mock_api, mock_retrieve):

assert status == TrainingStatus.running
mock_retrieve.assert_called_once_with("job-123")
+ mock_does_job_exist.assert_called_once_with("job-123", {})

def test_get_training_status_no_job(self):
"""Test getting status when no job exists."""
@@ -331,6 +336,11 @@ class TestArborWorkflows:
@patch.object(ArborProvider, '_get_arbor_base_api', return_value="http://localhost:8000/v1/")
def test_grpo_workflow(self, mock_api, mock_retrieve, mock_post, mock_lm, grpo_config, gpu_config):
"""Test a complete GRPO workflow."""
+ # Reset mocks to ensure clean state
+ mock_retrieve.reset_mock()
+ mock_post.reset_mock()
+
+ # Mock responses
init_response = Mock()
init_response.status_code = 200
init_response.json.return_value = {
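For context, a plausible shape for the fixtures these tests depend on; the real fixture bodies live elsewhere in the test module, so the values below are assumptions:

import pytest
from unittest.mock import Mock

from dspy.clients.utils_finetune import MultiGPUConfig
from dspy.teleprompt.grpo.grpo_config import GRPOConfig

@pytest.fixture
def mock_lm():
    lm = Mock()
    lm.model = "openai/arbor:example-model"  # placeholder id
    return lm

@pytest.fixture
def grpo_config():
    return GRPOConfig(num_generations=2)  # illustrative values

@pytest.fixture
def gpu_config():
    return MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)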