Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,13 @@ jobs:
fi
fi

- name: Clean checkpoint dir
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
if: always()
run: |
docker compose exec trinity-node-1 rm -rf /mnt/checkpoints/*
continue-on-error: true

- name: Upload test results
if: env.tests_run == 'true' || failure()
uses: actions/upload-artifact@v4
Expand Down
7 changes: 4 additions & 3 deletions benchmark/config/countdown-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ buffer:
experience_buffer:
name: experience_buffer
storage_type: queue
use_priority_queue: true
replay_buffer_kwargs:
replay_buffer:
enable: true
priority_fn: linear_decay
decay: 0.1
priority_fn_args:
decay: 0.1
explorer:
runner_per_model: 8
max_timeout: 900
Expand Down
7 changes: 4 additions & 3 deletions benchmark/config/gsm8k-template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ buffer:
experience_buffer:
name: experience_buffer
storage_type: queue
use_priority_queue: true
replay_buffer_kwargs:
replay_buffer:
enable: true
priority_fn: linear_decay
decay: 0.1
priority_fn_args:
decay: 0.1
explorer:
runner_per_model: 8
max_timeout: 900
Expand Down
15 changes: 7 additions & 8 deletions docs/sphinx_doc/source/tutorial/example_mix_algo.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,21 @@ class MixSampleStrategy(SampleStrategy):
expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)

# experience buffer
usual_buffer_config = copy.deepcopy(buffer_config)
usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size
self.usual_exp_buffer = get_buffer_reader(
buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore
)
usual_buffer_config = copy.deepcopy(buffer_config.trainer_input.experience_buffer)
usual_buffer_config.batch_size = tot_batch_size - expert_batch_size
self.usual_exp_buffer = get_buffer_reader(usual_buffer_config)

if buffer_config.trainer_input.auxiliary_buffers is None:
raise ValueError(
"`buffer_config.trainer_input.auxiliary_buffers` is required in MIX algorithm"
)

# expert experience buffer
expert_buffer_config = copy.deepcopy(buffer_config)
expert_buffer_config.train_batch_size = expert_batch_size
expert_buffer_config = copy.deepcopy(
buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name]
)
expert_buffer_config.batch_size = expert_batch_size
self.expert_exp_buffer = get_buffer_reader(
buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name],
expert_buffer_config,
)

Expand Down
8 changes: 6 additions & 2 deletions docs/sphinx_doc/source/tutorial/example_step_wise.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ In general multi-step scenarios, each run may generate various number of experie

- `buffer.train_batch_size`: The number of experiences to be sampled from the buffer for training, which can be different from the number of generated experiences in each explore step.

- `buffer.trainer_input.use_priority_queue = true`: Using `PriorityQueue` allows the model to use the experiences with higher priority, which prefers newly-generated experiences by default.
- `buffer.trainer_input.experience_buffer.replay_buffer`: Using `PriorityQueue` allows the model to consume experiences with higher priority first; by default, newly generated experiences are preferred.

- `synchronizer.sync_style = dynamic_by_explorer`: The explorer determines when to synchronize the model weights with the trainer.

Expand Down Expand Up @@ -126,7 +126,11 @@ buffer:
experience_buffer:
name: alfworld_buffer
storage_type: queue
use_priority_queue: true
replay_buffer:
enable: true
priority_fn: linear_decay
priority_fn_args:
decay: 0.1
explorer:
max_repeat_times_per_runner: 1
runner_per_model: 32
Expand Down
15 changes: 8 additions & 7 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,12 @@ The configuration for each task dataset is defined as follows:
- `name`: Name of the dataset. This name will be used as the Ray actor's name, so it must be unique.
- `storage_type`: How the dataset is stored. Options: `file`, `queue`, `sql`.
- `file`: The dataset is stored in `jsonl`/`parquet` files. The data file organization is required to meet the huggingface standard. *We recommend using this storage type for most cases.*
- `queue`: The dataset is stored in a queue. The queue is a simple FIFO queue that stores the task dataset. *Do not use this storage type for task dataset unless you know what you are doing.*
- `sql`: The dataset is stored in a SQL database. *This type is unstable and will be optimized in the future versions.*
- `path`: The path to the task dataset.
- For `file` storage type, the path points to the directory that contains the task dataset files.
- For `queue` storage type, the path is optional. You can back up the data in the queue by specifying a sqlite database path here.
- For `sql` storage type, the path points to the sqlite database file.
- `subset_name`: The subset name of the task dataset. Default is `None`.
- `split`: The split of the task dataset. Default is `train`.
- `subset_name`: The subset name of the task dataset, corresponding to the `name` parameter of the huggingface datasets `load_dataset` function. Default is `None`.
- `split`: The split of the task dataset, corresponding to the `split` parameter of the huggingface datasets `load_dataset` function. Default is `train`.
- `repeat_times`: The number of rollouts generated for a task. If not set, it will be automatically set to `algorithm.repeat_times` for `taskset`, and `1` for `eval_tasksets`.
- `rollout_args`: The parameters for rollout.
- `temperature`: The temperature for sampling.
Expand Down Expand Up @@ -324,7 +322,7 @@ buffer:
- For `queue` storage type, this field is optional. You can specify a SQLite database or JSON file path here to back up the queue data.
- For `file` storage type, the path points to the directory containing the dataset files.
- For `sql` storage type, the path points to the SQLite database file.
- `format`: Defines keys for prompts and responses in the dataset.
- `format`: Mainly for SFT and DPO algorithm datasets, used to format the extracted data.
- `prompt_type`: Specifies the type of prompts in the dataset. We support `plaintext`, `messages` for now.
- `plaintext`: The prompt is in string format.
- `messages`: The prompt is organized as a message list.
Expand All @@ -339,8 +337,11 @@ buffer:
- `enable_concatenated_multi_turn`: Enables concatenated multi-turn SFT data preprocessing. Only applies to `messages` and only takes effect with the SFT algorithm.
- `chat_template`: Specifies the chat template in string format. If not provided, use `model.custom_chat_template`.
- `max_read_timeout`: The maximum waiting time (in seconds) to read new experience data. If exceeded, an incomplete batch will be returned directly. Only takes effect when `storage_type` is `queue`. Default is 1800 seconds (30 minutes).
- `use_priority_queue`: Only take effect when `storage_type` is `queue`. If set to `True`, the queue will be a priority queue, which allows for prioritizing certain experiences over others. Default is `False`.
- `reuse_cooldown_time`: Only take effect when `storage_type` is `queue` and `use_priority_queue` is `True`. If set, it specifies the cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences can not be reused.
- `replay_buffer`: Only takes effect when `storage_type` is `queue`. Used to configure the replay buffer for experience reuse.
- `enable`: Whether to enable the replay buffer. Default is `false`.
- `reuse_cooldown_time`: Cooldown time (in seconds) for reusing experiences. If not specified, the default value is `None`, meaning experiences cannot be reused.
- `priority_fn`: Experience priority function used to determine the order of experience reuse. Currently supports `linear_decay` and `linear_decay_use_count_control_randomization`.
- `priority_fn_args`: A dictionary of arguments passed to the priority function, specific parameters depend on the selected priority function.
- `auxiliary_buffers`: Optional buffers used for trainer. It is a dictionary where each key is the buffer name and the value is the buffer configuration. Each buffer configuration is similar to the `experience_buffer`.

---
Expand Down
15 changes: 7 additions & 8 deletions docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,21 @@ class MixSampleStrategy(SampleStrategy):
expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)

# experience buffer
usual_buffer_config = copy.deepcopy(buffer_config)
usual_buffer_config.train_batch_size = tot_batch_size - expert_batch_size
self.usual_exp_buffer = get_buffer_reader(
buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore
)
usual_buffer_config = copy.deepcopy(buffer_config.trainer_input.experience_buffer)
usual_buffer_config.batch_size = tot_batch_size - expert_batch_size
self.usual_exp_buffer = get_buffer_reader(usual_buffer_config)

if buffer_config.trainer_input.auxiliary_buffers is None:
raise ValueError(
"`buffer_config.trainer_input.auxiliary_buffers` is required in MIX algorithm"
)

# expert experience buffer
expert_buffer_config = copy.deepcopy(buffer_config)
expert_buffer_config.train_batch_size = expert_batch_size
expert_buffer_config = copy.deepcopy(
buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name]
)
expert_buffer_config.batch_size = expert_batch_size
self.expert_exp_buffer = get_buffer_reader(
buffer_config.trainer_input.auxiliary_buffers[self.sft_dataset_name],
expert_buffer_config,
)

Expand Down
8 changes: 5 additions & 3 deletions docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):

- `buffer.train_batch_size`:从 buffer 中采样用于训练的 experience 数量,可以与每次探索生成的 experience 数量不同。

- `buffer.trainer_input.use_priority_queue = true`:使用 `PriorityQueue` 可使模型优先使用高优先级的 experience (默认为使用更新产生的 experience)。
- `buffer.trainer_input.experience_buffer.replay_buffer`:使用 `PriorityQueue` 可使模型优先使用高优先级的 experience (默认为使用更新产生的 experience)。

- `synchronizer.sync_style = dynamic_by_explorer`:由 explorer 决定何时与 trainer 同步模型权重。

Expand Down Expand Up @@ -124,7 +124,8 @@ buffer:
experience_buffer:
name: alfworld_buffer
storage_type: queue
use_priority_queue: true
replay_buffer:
enable: true
Comment thread
pan-x-c marked this conversation as resolved.
explorer:
max_repeat_times_per_runner: 1
runner_per_model: 16
Expand Down Expand Up @@ -154,11 +155,12 @@ trainer:
ulysses_sequence_parallel_size: 1
```


下面,我们提供运行 ALFWorld 任务的命令。

## 示例:多步 ALFWorld

### 环境准备

要安装 ALFWorld 环境,可按照以下说明操作。

1. 使用 pip 安装:`pip install alfworld[full]`
Expand Down
15 changes: 8 additions & 7 deletions docs/sphinx_doc/source_zh/tutorial/trinity_configs.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,12 @@ buffer:
- `name`: 数据集名称。该名称将用作 Ray actor 的名称,因此必须唯一。
- `storage_type`: 数据集的存储方式。选项:`file`、`queue`、`sql`。
- `file`: 数据集存储在 `jsonl`/`parquet` 文件中。数据文件组织需符合 HuggingFace 标准。*建议大多数情况下使用此存储类型。*
- `queue`: 数据集存储在队列中。队列是一个简单的 FIFO 队列,用于存储任务数据集。*除非你明确了解其用途,否则不要为此类数据集使用此类型。*
- `sql`: 数据集存储在 SQL 数据库中。*此类型尚不稳定,将在未来版本中优化。*
- `path`: 任务数据集的路径。
- 对于 `file` 类型,路径指向包含任务数据集文件的目录。
- 对于 `queue` 类型,路径为可选。可通过在此指定 sqlite 数据库路径来备份队列数据。
- 对于 `sql` 类型,路径指向 sqlite 数据库文件。
- `subset_name`: 任务数据集的子集名称。默认为 `None`。
- `split`: 任务数据集的划分。默认为 `train`。
- `subset_name`: 任务数据集的子集名称,对应 huggingface datasets `load_dataset` 函数中的 `name` 参数。默认为 `None`。
- `split`: 任务数据集的划分。对应 huggingface datasets `load_dataset` 函数中的 `split` 参数。默认为 `train`。
- `repeat_times`: 为一个任务生成的 rollout 数量。若未设置,则自动设为 `algorithm.repeat_times`(`taskset`)或 `1`(`eval_tasksets`)。
- `rollout_args`: rollout 参数。
- `temperature`: 采样温度。
Expand Down Expand Up @@ -321,7 +319,7 @@ buffer:
- 对于 `queue` 类型,此字段可选。可在此指定 SQLite 数据库或 JSON 文件路径以备份队列数据。
- 对于 `file` 类型,路径指向包含数据集文件的目录。
- 对于 `sql` 类型,路径指向 SQLite 数据库文件。
- `format`: 定义数据集中 prompt 和 response 的键。
- `format`: 主要针对 SFT 和 DPO 算法的数据集,用于规范化提取的数据。
- `prompt_type`: 指定数据集中 prompt 的类型。目前支持 `plaintext`、`messages`。
- `plaintext`: prompt 为 string 格式。
- `messages`: prompt 为消息列表。
Expand All @@ -336,8 +334,11 @@ buffer:
- `enable_concatenated_multi_turn`: 启用拼接的多轮 SFT 数据预处理。仅适用于 `messages`,且仅在 SFT 算法中生效。
- `chat_template`: 以字符串形式指定 chat template。若未提供,则使用 `model.custom_chat_template`。
- `max_read_timeout`: 读取新 experience 数据的最大等待时间(秒)。若超时,则直接返回不完整批次。仅当 `storage_type` 为 `queue` 时生效。默认为 1800 秒(30 分钟)。
- `use_priority_queue`: 仅当 `storage_type` 为 `queue` 时生效。若设为 `True`,队列为优先级队列,允许优先处理某些 experience。默认为 `False`。
- `reuse_cooldown_time`: 仅当 `storage_type` 为 `queue` 且 `use_priority_queue` 为 `True` 时生效。若设置,指定 experience 重用的冷却时间(秒)。若未指定,默认为 `None`,表示 experience 不可被重复使用。
- `replay_buffer`: 仅当 `storage_type` 为 `queue` 时生效。用于配置 experience 重用的回放缓冲区。
- `enable`: 是否启用回放缓冲区。默认为 `false`。
Comment thread
pan-x-c marked this conversation as resolved.
Outdated
- `reuse_cooldown_time`: experience 重用的冷却时间(秒)。若未指定,默认为 `None`,表示 experience 不可被重复使用。
- `priority_fn`: experience 优先级函数,用于确定 experience 的重用顺序。目前支持 `linear_decay` 和 `linear_decay_use_count_control_randomization`。
- `priority_fn_args`: 传递给优先级函数的参数字典,具体参数取决于所选的优先级函数。
- `auxiliary_buffers`: trainer 使用的可选缓冲区。为字典结构,每个键为 buffer 名称,值为 buffer 配置。每个 buffer 配置与 `experience_buffer` 类似。

---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ buffer:
experience_buffer:
name: experience_buffer
storage_type: queue
use_priority_queue: true
replay_buffer:
enable: true
Comment thread
pan-x-c marked this conversation as resolved.
explorer:
eval_interval: 10
max_repeat_times_per_runner: 1
Expand Down
3 changes: 2 additions & 1 deletion examples/grpo_alfworld_general_multi_step/alfworld.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ buffer:
experience_buffer:
name: alfworld_buffer
storage_type: queue
use_priority_queue: true
replay_buffer:
enable: true
Comment thread
pan-x-c marked this conversation as resolved.
explorer:
max_repeat_times_per_runner: 1
runner_per_model: 8
Expand Down
3 changes: 2 additions & 1 deletion examples/grpo_email_search/email_search.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ buffer:
experience_buffer:
name: experience_buffer
storage_type: queue
use_priority_queue: true
replay_buffer:
enable: true
Comment thread
pan-x-c marked this conversation as resolved.
explorer:
eval_interval: 10
max_repeat_times_per_runner: 1
Expand Down
3 changes: 2 additions & 1 deletion examples/grpo_rubric_as_reward/rubric.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ buffer:
experience_buffer:
name: experience_buffer
storage_type: queue
use_priority_queue: true
replay_buffer:
enable: true
Comment thread
pan-x-c marked this conversation as resolved.
explorer:
eval_interval: 10
max_timeout: 3600
Expand Down
14 changes: 10 additions & 4 deletions tests/buffer/experience_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from tests.tools import RayUnittestBaseAysnc, get_template_config
from trinity.buffer import get_buffer_reader
from trinity.buffer.pipelines.experience_pipeline import ExperiencePipeline
from trinity.common.config import ExperiencePipelineConfig, OperatorConfig
from trinity.common.config import (
ExperienceBufferConfig,
ExperiencePipelineConfig,
OperatorConfig,
)
from trinity.common.constants import SELECTOR_METRIC
from trinity.common.experience import EID, Experience

Expand Down Expand Up @@ -52,9 +56,11 @@ async def test_experience_pipeline(self):
config.algorithm.advantage_fn = (
"grpo" # grpo will add an operator at the end of the pipeline
)
config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
name="pipeline_test_experience_buffer",
max_read_timeout=3,
)
config.check_and_update()
config.buffer.trainer_input.experience_buffer.name = "pipeline_test_experience_buffer"
config.buffer.trainer_input.experience_buffer.max_read_timeout = 3

pipeline = (
ray.remote(ExperiencePipeline)
Expand All @@ -71,7 +77,7 @@ async def test_experience_pipeline(self):
) # first experience of each task will be filtered out by the reward filter

# tests
reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer, config.buffer)
reader = get_buffer_reader(config.buffer.trainer_input.experience_buffer)
exps = await reader.read_async(batch_size=task_num * (repeat_times - 1))
self.assertEqual(len(exps), task_num * (repeat_times - 1))
with self.assertRaises(TimeoutError):
Expand Down
22 changes: 11 additions & 11 deletions tests/buffer/experience_storage_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tests.tools import RayUnittestBaseAysnc
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import BufferConfig, StorageConfig
from trinity.common.config import ExperienceBufferConfig
from trinity.common.constants import StorageType
from trinity.common.experience import EID, Experience

Expand All @@ -23,24 +23,22 @@ def setUp(self):
self.put_batch_size = 2
self.train_batch_size = 4

self.config = BufferConfig(
train_batch_size=self.train_batch_size,
)
if os.path.exists(DB_PATH):
os.remove(DB_PATH)

@parameterized.expand([("sft",), ("dpo",)])
async def test_sql_storage(self, schema_type):
meta = StorageConfig(
config = ExperienceBufferConfig(
name="test_storage",
schema_type=schema_type,
storage_type=StorageType.SQL,
max_read_timeout=3,
path=f"sqlite:///{DB_PATH}",
batch_size=self.train_batch_size,
)

writer = SQLWriter(meta, self.config)
reader = SQLReader(meta, self.config)
config = config.to_storage_config()
writer = SQLWriter(config)
reader = SQLReader(config)
self.assertEqual(await writer.acquire(), 1)
exps = [
Experience(
Expand Down Expand Up @@ -90,15 +88,17 @@ def thread_read(reader, result_queue):
self.assertRaises(StopIteration, reader.read, batch_size=1)

async def test_sql_experience_buffer(self):
meta = StorageConfig(
config = ExperienceBufferConfig(
name="test_storage",
schema_type="experience",
storage_type=StorageType.SQL,
max_read_timeout=3,
path=f"sqlite:///{DB_PATH}",
batch_size=self.train_batch_size,
)
writer = SQLWriter(meta, self.config)
reader = SQLReader(meta, self.config)
config = config.to_storage_config()
writer = SQLWriter(config)
reader = SQLReader(config)
self.assertEqual(await writer.acquire(), 1)
for idx in range(self.total_num // self.put_batch_size):
exps = [
Expand Down
Loading