
Commit cc214e2

Author: lilinchuan (committed)
[RFC] Recipe for an agent-lightning-like RL training pipeline #3434
1 parent 634bd93 commit cc214e2

29 files changed: +2297 -0 lines changed
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright 2025 Individual Contributor: linxxx3 ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# Copyright 2025 Individual Contributor: linxxx3 ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any


class AgentClientBase(ABC):
    """Agent client base class."""

    def __init__(self, server_address: str, **kwargs):
        if server_address.startswith("http"):
            self.server_address_full = server_address
        else:
            self.server_address_full = f"http://{server_address}"

    @abstractmethod
    async def chat(self, trace_id: str, sampling_params: dict[str, Any], **kwargs) -> Any:
        """Custom chat function.

        Note: use an async HTTP client such as aiohttp in this function, to avoid blocking the event loop.

        Args:
            trace_id: trace id for collecting the trajectory
            sampling_params: sampling parameters, e.g., temperature, top_p, max_tokens, etc.
            **kwargs: non-tensor fields of a data sample from RLHFDataset
        """
        raise NotImplementedError
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
# Copyright 2025 Individual Contributor: linxxx3 ([email protected])
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
import os
import random
from typing import Any, cast
from uuid import uuid4

import hydra
import ray
from omegaconf import DictConfig, OmegaConf

from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput

from .trajectory import Trajectory

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class LightningAgentLoop(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, processor, **kwargs):
        if cls._class_initialized:
            return
        cls._validate_config(config)
        cls.agent_client_config = OmegaConf.load(config.lightning_trainer.agent_client_config_path)
        logger.info(f"LightningAgentLoop using agent_server_addr: {config.lightning_trainer.agent_server_addr}")
        cls.max_turns = config.actor_rollout_ref.rollout.multi_turn.max_assistant_turns
        cls._class_initialized = True

    @classmethod
    def _validate_config(cls, config: DictConfig):
        assert config.get("lightning_trainer") is not None, "config.lightning_trainer is required"
        assert config.lightning_trainer.model_name
        assert config.lightning_trainer.agent_server_addr
        assert config.lightning_trainer.agent_client_config_path

    async def run(self, sampling_params: dict, **kwargs) -> AgentLoopOutput:
        model_name = self.config.lightning_trainer.model_name
        client = hydra.utils.instantiate(
            self.agent_client_config,
            server_address=self.config.lightning_trainer.agent_server_addr,
        )

        async def _wait_random(min_seconds: int = 0, max_seconds: int = 3):
            wait_time = random.uniform(min_seconds, max_seconds)
            await asyncio.sleep(wait_time)

        trace_id = str(uuid4())
        resp = None
        try:
            await _wait_random()  # avoid a large number of simultaneous requests
            logger.debug(f"AgentClient sending request {trace_id=}, {sampling_params=}")
            resp = await client.chat(
                trace_id=trace_id, sampling_params=sampling_params, max_turns=self.max_turns, **kwargs
            )
            logger.debug(f"AgentClient final response {trace_id=}: {resp}")
        except Exception as e:
            import traceback

            # client.chat should not raise exceptions
            logger.error(f"Error in client.chat, should not happen: {e}")
            traceback.print_exc()

        llm_router = ray.get_actor("LLMRouter")  # get the LLMRouter handle by name
        assert llm_router is not None, "LLMRouter actor not found"
        trajectory = await llm_router.retrieve_trajectory.remote(model_name=model_name, trace_id=trace_id)
        logger.debug(f"Retrieved trajectory for {trace_id=}: {trajectory}")

        output = None
        if trajectory is None:
            logger.error(f"Trajectory not found for model: {model_name}, trace_id: {trace_id}")
        try:
            trajectory = cast(Trajectory, trajectory)
            output = _trajectory_to_agent_loop_output(trajectory, resp)
        except Exception as e:
            logger.error(f"Invalid trajectory for model: {model_name}, trace_id: {trace_id}, error: {e}")
        if output is None:
            output = _create_empty_agent_loop_output(
                trace_id=trace_id,
                model_name=model_name,
                prompt_length=self.config.actor_rollout_ref.rollout.prompt_length,
                response_length=self.config.actor_rollout_ref.rollout.response_length,
                pad_token_id=self.tokenizer.pad_token_id,
                final_response=resp,
            )

        ## maybe compute score here
        ## fill in output.reward_score and output.extra_fields["reward_extra_info"]
        return self._postprocess(output)

    def _postprocess(self, output: AgentLoopOutput) -> AgentLoopOutput:
        max_response_length = self.config.actor_rollout_ref.rollout.response_length

        output.response_ids = output.response_ids[:max_response_length]
        output.response_mask = output.response_mask[:max_response_length]
        assert len(output.response_ids) == len(output.response_mask)

        if output.response_logprobs:
            output.response_logprobs = output.response_logprobs[:max_response_length]
            assert len(output.response_ids) == len(output.response_logprobs)

        return output


def _trajectory_to_agent_loop_output(trajectory: Trajectory, final_response: Any) -> AgentLoopOutput:
    last_item = trajectory.get_last_item()
    if last_item is None:
        raise ValueError(f"Trajectory is empty, model: {trajectory.model_name}, trace_id: {trajectory.trace_id}")

    ## TODO: metrics
    output = AgentLoopOutput(
        prompt_ids=last_item.prompt_ids,
        response_ids=last_item.response_ids,
        response_mask=last_item.response_mask,
        response_logprobs=None,
        reward_score=None,
        num_turns=len(trajectory.items),
        metrics={},
        extra_fields={
            "model_name": trajectory.model_name,
            "trace_id": trajectory.trace_id,
            "final_response": final_response,
        },
    )
    return output


def _create_empty_agent_loop_output(
    trace_id: str, model_name: str, prompt_length: int, response_length: int, pad_token_id: int, final_response: Any
) -> AgentLoopOutput:
    """Create an empty AgentLoopOutput, with padded response_ids and response_mask."""
    return AgentLoopOutput(
        prompt_ids=[pad_token_id] * prompt_length,
        response_ids=[pad_token_id] * response_length,
        response_mask=[0] * response_length,
        response_logprobs=None,
        reward_score=None,
        num_turns=0,
        metrics={},
        extra_fields={
            "model_name": model_name,
            "trace_id": trace_id,
            "final_response": final_response,
        },
    )
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

# config for the rollout
rollout:
  agent:
    # custom agent loop class, should be a subclass of AgentLoopBase
    agent_loop_config_path: null

# custom config for the agent-lightning-like trainer
lightning_trainer:

  # model name used in the agent server
  model_name: Default

  # custom agent client class, should be a subclass of AgentClientBase
  agent_client_config_path: null

  # standalone custom agent server address, in the format "ip:port" or "https://ip:port"
  agent_server_addr: null

  # health-check URL path of the agent server
  health_check_url: /health

  health_check_timeout: 60

  # request header name for the trace_id, used by llm_router to manage trajectories.
  # the header should be included in the request chain from agent client to agent server to llm_router
  request_header_trace_id: trace_id

  # tool call parser, used by llm_router to extract tool calls from model responses;
  # inherits from rollout.multi_turn.format by default
  tool_call_parser: ${oc.select:actor_rollout_ref.rollout.multi_turn.format,hermes}

  # reasoning parser, used by llm_router to extract reasoning_content from model responses
  reasoning_parser: null
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
# Agent-Lightning-like RL training Example

Agent-Lightning-like is an RL training recipe inspired by Agent Lightning (https://arxiv.org/abs/2508.03680). You can train almost **ANY** agent by writing a few lines of code. More importantly, the agent can run as a service in an independent Python environment or even on a separate machine. That makes training simpler, especially when you have a complex agent system.

Here is a tiny example that demonstrates how to use this recipe. The example uses the OpenAI Agents SDK, but the recipe does not restrict which framework you use to write the agent.

## Prepare agent server

Wrap the agent as an HTTP service if you don't have one. As an example, `agent_server.py` demonstrates how to set up a `/chat` API endpoint with an integrated `calc_gsm8k_reward` tool.

We need to inject two things into the agent: the LLM service URL and additional request headers.

The LLM service URL is provided after veRL training starts. The agent obtains it by calling `get_llm_server_address`, defined in `recipe/agent_lightning_like/notify.py`:

```python
# model_provider.py
DEFAULT_MODEL_NAME = "Default"


class CustomModelProvider(ModelProvider):
    def get_model(self, model_name: str | None) -> Model:
        model_configs = get_model_configs()
        model_name = model_name or DEFAULT_MODEL_NAME
        if model_name not in model_configs:
            raise ValueError(f"Model {model_name} not found in model configs: {model_configs.keys()}")
        config = model_configs[model_name]
        base_url = config["base_url"]
        api_key = config.get("api_key", "")
        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
        return OpenAIChatCompletionsModel(model=model_name, openai_client=client)


def get_model_configs():
    """Demo: get model configurations from LLM_SERVER_NOTIFY_FILE."""
    server_address = get_llm_server_address()
    base_url = f"http://{server_address}/v1"
    model_configs = {
        DEFAULT_MODEL_NAME: {
            "base_url": base_url,
            "api_key": "",
        },
    }
    return model_configs
```

The `LLM_SERVER_NOTIFY_FILE` environment variable sets the file that passes the LLM endpoint from the trainer to the agent server.
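The notify helper itself is not shown in this snippet; below is a minimal sketch of what `get_llm_server_address` could look like, assuming the trainer writes the address to the notify file as plain text (the actual `recipe/agent_lightning_like/notify.py` may differ):

```python
# Hypothetical sketch of notify.py; the real implementation may differ.
import os


def get_llm_server_address() -> str:
    """Read the LLM server address (e.g. "10.0.0.1:8000") published by the trainer."""
    notify_file = os.environ["LLM_SERVER_NOTIFY_FILE"]  # file written by the trainer
    with open(notify_file) as f:
        return f.read().strip()
```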
An additional request header named "trace_id" is included in the request context, and we need to pass it on to the LLM server. We do this by setting `extra_headers` in `model_settings`.

```python
# agent_server.py

@app.post("/chat")
async def chat(request: Annotated[ChatRequest, fastapi.Body()]):
    """A demo chat function."""
    context = request.context
    model_provider = CustomModelProvider()
    extra_headers = request.extra_headers or {}
    extra_headers.update({HEADER_TRACE_ID: context.trace_id})
    model_settings = ModelSettings(
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
        extra_headers=extra_headers,  # inject trace_id here
        extra_body=request.extra_body or {},
    )
    agent = Agent[UserContext](
        name="Assistant",
        instructions=request.system_prompt or "You are a helpful assistant.",
        tools=[calc_gsm8k_reward],
    )
    # ......
```

## Write agent client

The trainer uses a client to send prompts to the agent server, which starts the rollout. The client must implement an async `chat` method, like the demo `agent_client.py`. The `chat` method is expected not to raise exceptions.

```python
# agent_client.py

class AgentClient(AgentClientBase):

    async def chat(self, trace_id: str, sampling_params: dict[str, Any], **kwargs) -> Any:
        # kwargs include "max_turns" and non-tensor fields of a data sample from RLHFDataset
        # ...
        # asynchronously send the request to the agent server
```
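For reference, here is a minimal sketch of what a complete `chat` implementation might look like, using `aiohttp` as suggested by `AgentClientBase`. The `/chat` path, payload fields, timeout, and import path are assumptions for illustration, not the demo's exact code; errors are swallowed so the training loop never sees an exception:

```python
# Hypothetical agent client sketch; endpoint path and payload fields are assumptions.
import logging
from typing import Any

import aiohttp

from recipe.agent_lightning_like.agent_client import AgentClientBase  # assumed import path

logger = logging.getLogger(__name__)


class MyAgentClient(AgentClientBase):
    async def chat(self, trace_id: str, sampling_params: dict[str, Any], **kwargs) -> Any:
        payload = {
            "trace_id": trace_id,
            "temperature": sampling_params.get("temperature"),
            "top_p": sampling_params.get("top_p"),
            "max_tokens": sampling_params.get("max_tokens"),
            **kwargs,  # e.g. max_turns and non-tensor dataset fields such as the raw question
        }
        try:
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=600)) as session:
                async with session.post(f"{self.server_address_full}/chat", json=payload) as resp:
                    resp.raise_for_status()
                    return await resp.json()
        except Exception as e:
            # never propagate: the agent loop expects chat() not to raise
            logger.error(f"Agent server request failed for {trace_id=}: {e}")
            return None
```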
## Prepare dataset

Let's prepare two small datasets for training and evaluation:

```bash
python examples/data_preprocess/gsm8k_tool_agent_loop.py
```

We use a simple `CustomDataset` class defined in `dataset.py` to align the "agent_name" field in the generated dataset with the name we define in `agent_loop.yaml`.

```python
# dataset.py
from verl.utils.dataset import RLHFDataset


class CustomDataset(RLHFDataset):
    """A custom dataset for the agent-lightning-like example."""

    def __getitem__(self, item):
        row_dict = super().__getitem__(item)
        row_dict["agent_name"] = "lightning_demo"  # must match the name in agent_loop.yaml
        row_dict.pop("tools_kwargs", None)  # drop tools_kwargs if present; tools are defined on the agent server side
        return row_dict
```

## Training

If you train your own agent, prepare these YAML config files: `agent_loop.yaml`, `agent_client.yaml`, and `recipe/agent_lightning_like/config/lightning_ppo_trainer.yaml`, then write a start script.
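As a rough sketch of the shapes these files take (the class paths and the `lightning_demo` name below are illustrative assumptions, not verbatim copies of the recipe's files): `agent_loop.yaml` registers your agent loop class under the name referenced by the dataset's "agent_name" field, and `agent_client.yaml` is loaded with OmegaConf and passed to `hydra.utils.instantiate`, so its `_target_` should point at your `AgentClientBase` subclass.

```yaml
# agent_loop.yaml: register the custom agent loop under the name used by the dataset
- name: lightning_demo
  _target_: recipe.agent_lightning_like.lightning_agent_loop.LightningAgentLoop  # assumed module path
---
# agent_client.yaml: instantiated by the agent loop with server_address injected
_target_: recipe.agent_lightning_like.example.agent_client.AgentClient  # assumed module path
```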
Run the demo example:

```bash
bash recipe/agent_lightning_like/example/run_qwen2.5_7b.sh 2>&1 | tee run.log
```

You will probably need an 8-GPU node for this example, or choose a smaller model.

The validation score is expected to reach about 93.6/100 after training for one epoch.

## Testing

There are some CI tests in `recipe/agent_lightning_like/test`.

Run a test:

```bash
PYTHONPATH=$(pwd) pytest -s recipe/agent_lightning_like/test/test_xxx.py
```
