1 change: 1 addition & 0 deletions examples/agentscope_frozenlake/agent.py
@@ -26,6 +26,7 @@ def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
formatter=OpenAIChatFormatter(),
max_iters=2,
)
self.agent.set_console_output_enabled(False)
self.response_structure = FrozenLakeAction
self.current_step = 0
self.last_action = None
92 changes: 92 additions & 0 deletions examples/agentscope_frozenlake/multi_step_padding.yaml
@@ -0,0 +1,92 @@
project: "FrozenLake"
name: "Qwen25-3B-padding"
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: multi_step_grpo
repeat_times: 8
kl_loss_fn: "low_var_kl"
kl_loss_fn_args:
kl_coef: 0
advantage_fn_args:
epsilon: 1e-6
optimizer:
lr: 1e-6
policy_loss_fn_args:
clip_range_low: 0.2
clip_range_high: 0.28
data_processor:
experience_pipeline:
operators: # NOTE: the multi_step_padding operator added in this PR
- name: multi_step_padding
args:
max_steps: 10
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}
max_response_tokens: 2048
min_response_tokens: 0
max_model_len: 25600
temperature: 0.7
cluster:
node_num: 1
gpu_per_node: 8
buffer:
total_epochs: 5
batch_size: 64
train_batch_size: 5120 # NOTE: 64 * 8 * 10 = batch_size * repeat_times * max_steps
explorer_input:
taskset:
name: frozenlake
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH}
split: train
workflow_args:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
eval_tasksets:
- name: frozenlake
storage_type: file
path: ${oc.env:TRINITY_TASKSET_PATH}
split: test
workflow_args:
env_max_steps: 8
agent_max_steps: 10
is_slippery: false
repeat_times: 4
rollout_args:
top_p: 0.8
top_k: 20
default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'
trainer_input:
experience_buffer:
name: frozenlake_experience_buffer
storage_type: queue
max_read_timeout: 7200
explorer:
eval_on_startup: true
eval_interval: 20
runner_per_model: 6
max_repeat_times_per_runner: 4
rollout_model:
engine_num: 4
tensor_parallel_size: 1
enable_chunked_prefill: true
enforce_eager: false
enable_openai_api: true
enable_log_requests: false
enable_history: true
enable_auto_tool_choice: true
tool_call_parser: hermes
enable_thinking: true
dtype: bfloat16
seed: 42
trainer:
save_interval: 50
use_dynamic_bsz: true
grad_clip: 1.0
ulysses_sequence_parallel_size: 2
synchronizer:
sync_method: nccl
sync_style: fixed
sync_interval: 1
sync_timeout: 1200
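
The NOTE on train_batch_size above encodes the invariant this config relies on: after padding, every task contributes exactly repeat_times * max_steps experiences, so explorer output divides evenly into trainer batches. A quick sanity check of that arithmetic (plain Python; the variable names simply mirror the config keys):

batch_size = 64        # buffer.batch_size: tasks per explore step
repeat_times = 8       # algorithm.repeat_times: rollouts per task
max_steps = 10         # data_processor operator arg: steps after padding

train_batch_size = batch_size * repeat_times * max_steps
assert train_batch_size == 5120  # matches buffer.train_batch_size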
1 change: 1 addition & 0 deletions trinity/buffer/operators/__init__.py
@@ -10,6 +10,7 @@
"pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator",
"data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator",
"invalid_reward_filter": "trinity.buffer.operators.filters.reward_filter.InvalidRewardFilter",
"multi_step_padding": "trinity.buffer.operators.multi_step_operator.MultiStepPadding",
},
)

105 changes: 105 additions & 0 deletions trinity/buffer/operators/multi_step_operator.py
@@ -0,0 +1,105 @@
from typing import List, Tuple

import torch

from trinity.buffer.operators import ExperienceOperator
from trinity.common.experience import EID, Experience, group_by
from trinity.utils.log import get_logger

logger = get_logger(__name__)


class MultiStepPadding(ExperienceOperator):
"""
Pad the experiences of each run up to the maximum number of steps.

Note: This operator assumes that the reward is already calculated and stored in the Experience object.
"""

def __init__(self, max_steps: int = 0):
self.max_steps = max_steps

def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]:
"""Padding each rollout to the max step."""
logger.debug(f"Processing {len(exps)} experiences")
total_num_placeholder_exps = 0
all_exps = []

task_exps = group_by(exps, "task")
for _, task_exp in task_exps.items():
run_exps = group_by(task_exp, "run")
for _, exps_same_run in run_exps.items():
if len(exps_same_run) == 0:
continue
num_placeholder_exps = 0
if len(exps_same_run) < self.max_steps:
num_placeholder_exps = self.max_steps - len(exps_same_run)
# Calculate average response length to keep metrics unchanged
assert all(
exp.tokens is not None for exp in exps_same_run
), "Tokens are not provided"
response_lengths = [
len(exp.tokens) - exp.prompt_length for exp in exps_same_run # type: ignore
]
avg_response_length = int(sum(response_lengths) / len(response_lengths))
# Ensure at least 1 to avoid zero-length response
avg_response_length = max(avg_response_length, 1)

# Use the first experience as a template
template_exp = exps_same_run[0]
prompt_length = template_exp.prompt_length

# Create tokens with average response length
# Keep the prompt part, pad the response part to average length
prompt_tokens = template_exp.tokens[:prompt_length] # type: ignore
# Use the last token of prompt as padding token for response part
pad_token = prompt_tokens[-1] if len(prompt_tokens) > 0 else 0
response_tokens = torch.full(
(avg_response_length,),
pad_token,
dtype=template_exp.tokens.dtype, # type: ignore
)
avg_tokens = torch.cat([prompt_tokens, response_tokens])
avg_logprobs = (
torch.zeros(avg_response_length, dtype=torch.float32)
if template_exp.logprobs is not None
else None
)
assert all(
exp.reward is not None for exp in exps_same_run
), "Rewards are not provided"
rewards = [exp.reward for exp in exps_same_run if exp.reward is not None]
avg_reward = sum(rewards) / len(rewards)

template_eid = template_exp.eid

empty_experiences = [
Experience(
eid=EID(
batch=template_eid.batch,
task=template_eid.task,
run=template_eid.run,
step=-1,
), # -1 means placeholder
tokens=avg_tokens,
logprobs=avg_logprobs,
prompt_length=prompt_length,
action_mask=torch.zeros(avg_response_length, dtype=torch.bool),
truncate_status="prompt_truncated", # TODO: merge with the following
info={"status": "placeholder"}, # TODO: use another field
reward=avg_reward,
)
for _ in range(num_placeholder_exps)
]
logger.debug(f"Adding {num_placeholder_exps} placeholder experiences")
# Prepend the placeholders: the advantage is computed from the last experience, which must remain last
exps_same_run = empty_experiences + exps_same_run
all_exps.extend(exps_same_run)
else:
all_exps.extend(exps_same_run)
total_num_placeholder_exps += num_placeholder_exps
metrics = {"total_num_placeholder_exps": total_num_placeholder_exps}
logger.debug(f"After padding: {len(all_exps)}")
return all_exps, metrics
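
A minimal usage sketch for MultiStepPadding (illustrative only: the Experience and EID arguments shown are just the fields the operator actually reads — eid, tokens, prompt_length, logprobs, reward — and the concrete values are assumptions, not a real rollout):

import torch

from trinity.buffer.operators.multi_step_operator import MultiStepPadding
from trinity.common.experience import EID, Experience

# One run that terminated after 2 of 4 steps: 3 prompt tokens + 2 response tokens each.
exps = [
    Experience(
        eid=EID(batch=0, task="task-0", run=0, step=step),
        tokens=torch.tensor([11, 12, 13, 21, 22]),
        prompt_length=3,
        logprobs=torch.zeros(2),
        reward=1.0,
    )
    for step in range(2)
]

padded, metrics = MultiStepPadding(max_steps=4).process(exps)
# Two placeholders are prepended, each reusing the prompt and carrying the run's mean reward.
print(len(padded), metrics)  # 4 {'total_num_placeholder_exps': 2}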
15 changes: 15 additions & 0 deletions trinity/common/experience.py
@@ -366,6 +366,9 @@ def gather(
else:
teacher_logprobs = None

# gather statuses
statuses = gather_statuses(experiences)

exps = Experiences(
eids=eids,
tokens=tokens,
@@ -379,6 +382,7 @@
logprobs=logprobs,
multi_modal_inputs=multi_modal_inputs,
teacher_logprobs=teacher_logprobs,
statuses=statuses,
)
if custom_fields is not None:
for custom_field in custom_fields:
@@ -465,6 +469,7 @@ class Experiences:
prompt_length: int
logprobs: Optional[Tensor] # [batch_size, response_length]
multi_modal_inputs: Optional[Any]
statuses: Optional[Tensor] = None # [batch_size]; 1 for effective, 0 for placeholder
custom_fields: List[str] = field(
default_factory=list
) # Custom fields to include in the gathered experiences
@@ -605,6 +610,16 @@ def gather_multi_modal_inputs(experiences) -> Dict[str, Tensor]:
return {key: [exp.multi_modal_inputs[key] for exp in experiences] for key in keys}


def gather_statuses(experiences) -> Tensor:
statuses = []
for exp in experiences:
if exp.info.get("status", None) == "placeholder":
statuses.append(0)
else:
statuses.append(1)
return torch.tensor(statuses, dtype=torch.bool)


def group_by(
experiences: List[Experience], id_type: Literal["task", "run", "step"]
) -> Dict[str, List[Experience]]:
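gather_statuses only inspects exp.info, so its behavior can be checked in isolation (the SimpleNamespace objects below are hypothetical stand-ins for real Experience instances):

from types import SimpleNamespace

from trinity.common.experience import gather_statuses

exps = [
    SimpleNamespace(info={"status": "placeholder"}),  # padded experience -> False
    SimpleNamespace(info={}),                         # effective experience -> True
]
print(gather_statuses(exps))  # tensor([False,  True])
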
21 changes: 20 additions & 1 deletion trinity/trainer/verl/utils.py
@@ -17,15 +17,35 @@
gather_action_masks,
gather_attention_masks,
gather_response_attrs,
gather_statuses,
gather_token_ids,
split_dpo_experience_to_single_turn,
)


def print_effective_experience_stats(experiences: List[Experience], logger: Logger) -> None:
"""Gather effective experience count and the corresponding reweight factor."""
statuses = gather_statuses(experiences)
effective_experiences = torch.sum(statuses).item()
batch_size = len(experiences)
if effective_experiences == 0:
effective_weight = 1.0
logger.info("No effective experiences found, using default weight 1.0")
else:
effective_weight = float(batch_size / effective_experiences)
logger.info(
f"Effective experiences: {effective_experiences}, batch size: {batch_size}, effective_weight: {effective_weight}"
)


def to_data_proto(
experiences: List[Experience], pad_token_id: int, processor: ProcessorMixin, logger: Logger
) -> DataProto: # noqa: C901
"""Convert List[Experience] to verl DataProto."""

print_effective_experience_stats(experiences, logger)

assert len(experiences) > 0, "No experiences provided."
if experiences[0].experience_type == "dpo":
experiences = split_dpo_experience_to_single_turn(experiences)
@@ -49,7 +69,6 @@ def to_data_proto(
"attention_mask": attention_mask,
"response_mask": gather_action_masks(experiences, max_response_length),
}

have_reward = all(exp.reward is not None for exp in experiences)
have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences)
if have_reward or have_token_level_reward:
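
For the FrozenLake config above, the reweight factor logged by print_effective_experience_stats works out as follows (a worked example with an assumed placeholder count, not output from a real run):

batch_size = 5120              # 64 tasks * 8 repeats * 10 padded steps
effective_experiences = 4608   # assume 512 placeholders in this batch
effective_weight = batch_size / effective_experiences
print(round(effective_weight, 3))  # 1.111 -- placeholders dilute each batch by this factor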