Commit 5f7c105

[data, trainer] feat: add support for limiting samples from dataset
For example, in `RLHFDataset` the `filter_overlong_prompts` step can be very expensive on large datasets, so it is useful to cap the number of samples before running it. The same option is added to the other dataset types for consistency.
1 parent 4da0d3d commit 5f7c105

23 files changed: +235 −24 lines
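
As a rough illustration of what the new `*_max_samples` options are meant to do (a hedged sketch, not verl's actual implementation; the helper name `limit_samples` is hypothetical): with a Hugging Face `datasets.Dataset`, one would randomly keep at most `max_samples` rows before running the expensive overlong-prompt filter.

import datasets

def limit_samples(dataset: datasets.Dataset, max_samples: int, seed: int = 42) -> datasets.Dataset:
    # Hypothetical helper: keep at most `max_samples` random rows; -1 means keep everything.
    if max_samples < 0 or max_samples >= len(dataset):
        return dataset
    return dataset.shuffle(seed=seed).select(range(max_samples))

# Subsample first, then run the costly filtering step on the much smaller set.
ds = datasets.Dataset.from_dict({"prompt": [f"question {i}" for i in range(1000)]})
ds = limit_samples(ds, max_samples=5)
ds = ds.filter(lambda row: len(row["prompt"]) <= 512)  # stand-in for filter_overlong_prompts
assert len(ds) <= 5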

docs/examples/config.rst

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,8 @@ Data
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   max_prompt_length: 512
   max_response_length: 512
@@ -41,6 +43,10 @@ Data
   HDFS path to local path.
 - ``data.val_files``: Validation parquet. Can be a list or a single
   file.
+- ``data.train_max_samples``: Maximum number of samples to use from the
+  training dataset. Set to -1 to use the full dataset.
+- ``data.val_max_samples``: Maximum number of samples to use from the
+  validation dataset. Set to -1 to use the full dataset.
 - ``data.prompt_key``: The field in the dataset where the prompt is
   located. Default is 'prompt'.
 - ``data.max_prompt_length``: Maximum prompt length. All prompts will be
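
Note that every call site in this commit reads the new keys with a default of -1, so existing configs that do not define `train_max_samples` or `val_max_samples` keep using the full dataset. A minimal sketch of that lookup pattern (the config contents here are illustrative only):

from omegaconf import OmegaConf

# Illustrative data config; a real verl config carries many more fields.
data_config = OmegaConf.create({"train_files": "~/data/rlhf/gsm8k/train.parquet"})

# Same pattern as the updated call sites: a missing key falls back to -1 (full dataset).
train_max_samples = data_config.get("train_max_samples", -1)
assert train_max_samples == -1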

examples/split_placement/config/ppo_trainer_split.yaml

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   max_prompt_length: 512
   max_response_length: 512

recipe/entropy/main_entropy.py

Lines changed: 12 additions & 3 deletions
@@ -162,8 +162,16 @@ def run(self, config):
 
         from verl.utils.dataset.rl_dataset import collate_fn
 
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1)
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
         trainer = RayEntropyTrainer(
             config=config,
@@ -183,7 +191,7 @@ def run(self, config):
         trainer.fit()
 
 
-def create_rl_dataset(data_paths, data_config, tokenizer, processor):
+def create_rl_dataset(data_paths, data_config, tokenizer, processor, max_samples: int = -1):
     """Create a dataset.
 
     Arguments:
@@ -216,6 +224,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor):
         tokenizer=tokenizer,
         processor=processor,
         config=data_config,
+        max_samples=max_samples,
     )
 
     return dataset

recipe/one_step_off_policy/main_ppo.py

Lines changed: 10 additions & 2 deletions
@@ -212,8 +212,16 @@ def run(self, config):
         from verl.utils.dataset.rl_dataset import collate_fn
 
         # Create training and validation datasets.
-        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
-        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_dataset = create_rl_dataset(
+            config.data.train_files,
+            config.data,
+            tokenizer,
+            processor,
+            max_samples=config.data.get("train_max_samples", -1),
+        )
+        val_dataset = create_rl_dataset(
+            config.data.val_files, config.data, tokenizer, processor, max_samples=config.data.get("val_max_samples", -1)
+        )
         train_sampler = create_rl_sampler(config.data, train_dataset)
 
         # Initialize the PPO trainer.

recipe/spin/spin_trainer.py

Lines changed: 10 additions & 2 deletions
@@ -393,11 +393,19 @@ def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampl
 
         if train_dataset is None:
             train_dataset = create_rl_dataset(
-                self.config.data.train_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.train_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("train_max_samples", -1),
             )
         if val_dataset is None:
             val_dataset = create_rl_dataset(
-                self.config.data.val_files, self.config.data, self.tokenizer, self.processor
+                self.config.data.val_files,
+                self.config.data,
+                self.tokenizer,
+                self.processor,
+                max_samples=self.config.data.get("val_max_samples", -1),
             )
         self.train_dataset, self.val_dataset = train_dataset, val_dataset

tests/special_e2e/sft/test_sp_loss_match.py

Lines changed: 6 additions & 2 deletions
@@ -112,8 +112,12 @@ def create_trainer(config):
 
     local_model_path = copy_to_local(src=config.model.partial_pretrain, verbose=True)
     tokenizer = hf_tokenizer(local_model_path, trust_remote_code=config.model.trust_remote_code)
-    train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
-    val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
+    train_dataset = create_sft_dataset(
+        config.data.train_files, config.data, tokenizer, max_samples=config.data.get("train_max_samples", -1)
+    )
+    val_dataset = create_sft_dataset(
+        config.data.val_files, config.data, tokenizer, max_samples=config.data.get("val_max_samples", -1)
+    )
 
     return FSDPSFTTrainer(
         config=config,

tests/trainer/config/legacy_ppo_megatron_trainer.yaml

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@ data:
   tokenizer: null
   train_files: ~/data/rlhf/gsm8k/train.parquet
   val_files: ~/data/rlhf/gsm8k/test.parquet
+  train_max_samples: -1 # set to -1 to use full dataset
+  val_max_samples: -1 # set to -1 to use full dataset
   prompt_key: prompt
   reward_fn_key: data_source
   max_prompt_length: 512

tests/trainer/config/legacy_ppo_trainer.yaml

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,16 @@ data:
   # Validation parquet. Can be a list or a single file.
   val_files: ~/data/rlhf/gsm8k/test.parquet
 
+  # Maximum number of samples to use.
+  # Set to -1 to use the full dataset; otherwise, randomly
+  # select the specified number of samples from the train dataset.
+  train_max_samples: -1
+
+  # Maximum number of samples to use.
+  # Set to -1 to use the full dataset; otherwise, randomly
+  # select the specified number of samples from the val dataset.
+  val_max_samples: -1
+
   # The field in the dataset where the prompt is located. Default is 'prompt'.
   prompt_key: prompt

tests/utils/dataset/test_rl_dataset_on_cpu.py

Lines changed: 19 additions & 0 deletions
@@ -66,6 +66,25 @@ def test_rl_dataset():
     print(f"\n\noutput: {output}")
 
 
+def test_rl_dataset_with_max_samples():
+    from verl.utils import hf_tokenizer
+    from verl.utils.dataset.rl_dataset import RLHFDataset
+
+    tokenizer = hf_tokenizer("deepseek-ai/deepseek-coder-1.3b-instruct")
+    local_path = get_gsm8k_data()
+    config = OmegaConf.create(
+        {
+            "prompt_key": "prompt",
+            "max_prompt_length": 256,
+            "filter_overlong_prompts": True,
+            "filter_overlong_prompts_workers": 2,
+            "max_samples": 5,
+        }
+    )
+    dataset = RLHFDataset(data_files=local_path, tokenizer=tokenizer, config=config, max_samples=5)
+    assert len(dataset) == 5
+
+
 def test_image_rl_data():
     from verl.utils import hf_processor, hf_tokenizer
     from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

tests/utils/dataset/test_sft_dataset_on_cpu.py

Lines changed: 23 additions & 0 deletions
@@ -72,3 +72,26 @@ def test_sft_dataset():
     output = tokenizer.batch_decode([data])[0]
     assert len(output) > 1
     assert isinstance(output, str)
+
+
+def test_sft_dataset_with_max_samples():
+    tokenizer = hf_tokenizer("deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct")
+    local_path = get_gsm8k_data()
+    from omegaconf import OmegaConf
+
+    dataset = SFTDataset(
+        parquet_files=local_path,
+        tokenizer=tokenizer,
+        config=OmegaConf.create(
+            {
+                "prompt_key": "extra_info",
+                "prompt_dict_keys": ["question"],
+                "response_key": "extra_info",
+                "response_dict_keys": ["answer"],
+                "max_length": 512,
+            }
+        ),
+        max_samples=5,
+    )
+
+    assert len(dataset) == 5
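
The SFT test above passes `max_samples` directly to `SFTDataset`. Since `SFTDataset` is built from parquet files, one plausible way it could honor the cap internally is a random row sample on the loaded DataFrame; the sketch below is an assumption for illustration (the helper `subsample_dataframe` is hypothetical, not verl code).

import pandas as pd

def subsample_dataframe(df: pd.DataFrame, max_samples: int, seed: int = 42) -> pd.DataFrame:
    # Hypothetical sketch: keep at most `max_samples` random rows; -1 keeps the full frame.
    if max_samples < 0 or max_samples >= len(df):
        return df
    return df.sample(n=max_samples, random_state=seed).reset_index(drop=True)

df = pd.DataFrame({"question": [f"q{i}" for i in range(100)], "answer": ["a"] * 100})
assert len(subsample_dataframe(df, max_samples=5)) == 5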
