Switch InterleaveDataset to use weights (e.g., 2.0, 0.5, etc.) #140

Merged
18 commits, merged on Oct 26, 2024
74 changes: 22 additions & 52 deletions ultravox/data/datasets.py
@@ -3,7 +3,6 @@
 import dataclasses
 import enum
 import io
-import itertools
 import logging
 import os
 import tempfile
@@ -550,77 +549,48 @@ def __len__(self):
         return self._length


-class StopStrategy(str, enum.Enum):
-    FIRST_EXHAUSTED = "FIRST_EXHAUSTED"
-    LAST_EXHAUSTED = "LAST_EXHAUSTED"
-    NEVER_STOP = "NEVER_STOP"
-
-
 class InterleaveDataset(SizedIterableDataset):
-    """Interleaves multiple IterableDataset objects based on normalized weights."""
+    """Interleaves multiple SizedIterableDataset objects based on normalized weights."""

     def __init__(
         self,
         datasets: Sequence[SizedIterableDataset],
         weights: Optional[Sequence[float]] = None,
-        stop_strategy: StopStrategy = StopStrategy.LAST_EXHAUSTED,
-        seed: Optional[int] = 42,
-        static: bool = False,
     ) -> None:
         """
         Args:
             datasets: A list of SizedIterableDataset objects.
-            weights: A list of weights for each dataset.
-            stop_strategy: Strategy for stopping iteration.
-            seed: Optional seed for reproducibility.
-            static: If true, the datasets are interleaved in a static order with equal weights.
+            weights: An optional list of dataset weights, i.e., the number of times it should be repeated.
         """
         self._datasets = datasets
-        self._rng = np.random.default_rng(seed)
-        self._static = static
-        self._stop_strategy = stop_strategy

-        if weights is None:
+        if weights is not None:
+            assert len(weights) == len(datasets)
+        else:
             weights = [1.0] * len(datasets)
-        total_weight = sum(weights)
-        self._normalized_probs = [w / total_weight for w in weights]
+        self._weighted_samples = [int(w * len(d)) for w, d in zip(weights, datasets)]
+        self._total_samples = sum(self._weighted_samples)

     def __iter__(self):
-        # If no datasets are provided, return an empty iterator
-        if not self._datasets:
-            return
-
-        iters = [iter(ds) for ds in self._datasets]
-        exhausted = [False] * len(iters)
-
-        if self._static:
-            static_iter = itertools.cycle(range(len(self._datasets)))
-
-        while True:
-            if self._static:
-                iter_index = next(static_iter)
-            else:
-                iter_index = self._rng.choice(len(iters), p=self._normalized_probs)
-
+        ds_iters = [iter(ds) for ds in self._datasets]
+        ds_pos = [0] * len(ds_iters)
+        # Find the iterator that is least far along and vend from it.
+        for i in range(self._total_samples):
+            min_fraction = 1.0
+            for j in range(len(ds_iters)):
+                iter_fraction = ds_pos[j] / self._weighted_samples[j]
+                if iter_fraction < min_fraction:
+                    min_fraction = iter_fraction
+                    iter_index = j
             try:
-                yield next(iters[iter_index])
+                yield next(ds_iters[iter_index])
             except StopIteration:
-                exhausted[iter_index] = True
-
-                # Check if stopping condition is met
-                if self._stop_strategy == StopStrategy.FIRST_EXHAUSTED or (
-                    self._stop_strategy == StopStrategy.LAST_EXHAUSTED
-                    and all(exhausted)
-                ):
-                    break
-
-                # Recreate the iterator if stopping condition is not met and yield the next sample
-                iters[iter_index] = iter(self._datasets[iter_index])
-                yield next(iters[iter_index])
+                ds_iters[iter_index] = iter(self._datasets[iter_index])
+                yield next(ds_iters[iter_index])
+            ds_pos[iter_index] += 1

     def __len__(self):
-        # TODO: Implement the length method for different stop strategies
-        return sum(len(ds) for ds in self._datasets)
+        return self._total_samples


 class Dataproc(SizedIterableDataset):
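The rewritten class trades the old stop-strategy machinery for a deterministic quota scheme: each dataset is allotted int(weight * len(dataset)) samples, and each step vends from whichever iterator has consumed the smallest fraction of its allotment, restarting any iterator that runs dry. A minimal standalone sketch of that selection rule (the function and variable names below are illustrative, not part of the PR):

def interleave_order(lengths, weights):
    """Return the dataset index chosen at each step, mirroring the
    least-consumed-fraction rule in the new InterleaveDataset.__iter__.

    Assumes every quota is positive; a zero quota would divide by zero,
    which the PR's tests avoid by using either all-positive or all-zero
    weights (in the all-zero case the loop runs zero times)."""
    quotas = [int(w * n) for w, n in zip(weights, lengths)]
    pos = [0] * len(lengths)  # samples vended from each dataset so far
    order = []
    for _ in range(sum(quotas)):
        # Ties resolve to the lowest index, matching the strict "<" in the PR.
        index = min(range(len(lengths)), key=lambda j: pos[j] / quotas[j])
        order.append(index)
        pos[index] += 1
    return order

print(interleave_order([4, 2], [0.5, 2.0]))  # -> [0, 1, 1, 0, 1, 1]

Applied to the datasets in test_interleaved_specific_weights below, that order yields [0, 10, 11, 1, 10, 11]: the two-element dataset restarts once to fill its quota of four.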
86 changes: 19 additions & 67 deletions ultravox/data/datasets_test.py
@@ -1,4 +1,3 @@
-import itertools
 from typing import Optional, Union

 import datasets as hf_datasets
@@ -16,7 +15,7 @@ class FakeSizedIterableDataset(datasets.SizedIterableDataset):

     def __init__(self, n, start=0, length=0):
         self.data = range(start, start + n)
-        self._length = length
+        self._length = length or n

     def __iter__(self):
         for sample in self.data:
@@ -95,90 +94,43 @@ def test_dataproc():
     assert list(s) == [0, -1, -2, -3, -4]


-def test_interleaved_first_exhausted():
-    ds1 = FakeSizedIterableDataset(5)
-    s = datasets.InterleaveDataset([ds1])
-    assert list(s) == [0, 1, 2, 3, 4]
-    ds2 = FakeSizedIterableDataset(9)
-    ds3 = FakeSizedIterableDataset(3)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2, ds3],
-        stop_strategy=datasets.StopStrategy.FIRST_EXHAUSTED,
-        static=True,
-    )
-    # static=True disables random sampling of datasets, so the order is deterministic
-    # stop_strategy=first_exhausted will stop interleave when the first dataset is exhausted
-    assert list(s) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3]
+def test_interleaved_empty():
+    s = datasets.InterleaveDataset([])
+    assert list(s) == []


-def test_interleaved_last_exhausted():
+def test_interleaved_single_set():
     ds1 = FakeSizedIterableDataset(4)
-    ds2 = FakeSizedIterableDataset(2, start=10)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-        stop_strategy=datasets.StopStrategy.LAST_EXHAUSTED,
-        static=True,
-    )
-    # static=True disables random sampling of datasets, so the order is deterministic
-    # stop_strategy=last_exhausted will stop interleave when the last dataset is exhausted
-    assert list(s) == [0, 10, 1, 11, 2, 10, 3, 11]
+    s = datasets.InterleaveDataset([ds1])
+    assert list(s) == [0, 1, 2, 3]


-def test_interleaved_never_stop():
+def test_interleaved_normal_weights():
     ds1 = FakeSizedIterableDataset(4)
+    ds2 = FakeSizedIterableDataset(8, start=10)
+    ds3 = FakeSizedIterableDataset(2, start=100)
+    s = datasets.InterleaveDataset([ds1, ds2, ds3])
+    assert list(s) == [0, 10, 100, 11, 1, 12, 13, 2, 14, 101, 15, 3, 16, 17]
+
+
+def test_interleaved_specific_weights():
+    ds1 = FakeSizedIterableDataset(4)
     ds2 = FakeSizedIterableDataset(2, start=10)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-        stop_strategy=datasets.StopStrategy.NEVER_STOP,
-        static=True,
-    )
-    # static=True disables random sampling of datasets, so the order is deterministic
-    # stop_strategy=never_stop will continue interleaving forever
-    assert list(itertools.islice(s, 12)) == [0, 10, 1, 11, 2, 10, 3, 11, 0, 10, 1, 11]
+    s = datasets.InterleaveDataset([ds1, ds2], [0.5, 2.0])
+    assert list(s) == [0, 10, 11, 1, 10, 11]


-def test_interleaved_random():
+def test_interleaved_zero_weights():
     ds1 = FakeSizedIterableDataset(4)
     ds2 = FakeSizedIterableDataset(2, start=10)
-    s = datasets.InterleaveDataset(
-        [ds1, ds2],
-        [10.0, 1.0],
-    )
-    # stop_strategy=last_exhausted will stop interleaving when the last dataset is exhausted (attempted after exhaustion)
-    assert list(s) == [
-        0,
-        1,
-        2,
-        3,
-        0,
-        10,
-        1,
-        2,
-        3,
-        0,
-        1,
-        11,
-        2,
-        3,
-        0,
-        1,
-        2,
-        3,
-        0,
-        1,
-        2,
-        3,
-    ]
+    s = datasets.InterleaveDataset([ds1, ds2], [0.0, 0.0])
+    assert list(s) == []


 def test_interleaved_with_multiprocessing():
     ds = FakeSizedIterableDataset(5)
     s = datasets.InterleaveDataset([ds])

     dl = data.DataLoader(s, num_workers=1, batch_size=5)

     batch = next(iter(dl))
     assert torch.allclose(batch, torch.tensor([0, 1, 2, 3, 4]))
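The expected values in these tests follow directly from the quota arithmetic; a quick illustrative check (not part of the test file):

# Default weights of 1.0: each dataset contributes exactly len(ds) samples.
lengths = [4, 8, 2]  # ds1, ds2, ds3 in test_interleaved_normal_weights
assert sum(int(1.0 * n) for n in lengths) == 14  # the expected list has 14 entries

# Weights [0.5, 2.0] over lengths [4, 2]: quotas of 2 and 4, 6 samples total.
assert [int(0.5 * 4), int(2.0 * 2)] == [2, 4]

# Zero weights produce zero quotas, so iteration ends immediately.
assert int(0.0 * 4) + int(0.0 * 2) == 0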
5 changes: 0 additions & 5 deletions ultravox/training/config_base.py
@@ -50,17 +50,12 @@ def get_val_sets(self) -> List[DatasetOptions]:
     do_train: bool = True
     do_eval: bool = True

-    # In InterleaveDataset, when to stop interleave: choose from last_exhausted (default), first_exhausted, or never_stop
-    stop_strategy: datasets.StopStrategy = datasets.StopStrategy.LAST_EXHAUSTED
     data_dir: Optional[str] = None
     mds: bool = False
     num_samples: Optional[int] = None
     val_num_samples: int = 100
     eval_num_samples: int = 100
     eval_max_new_tokens: Optional[int] = None
     eval_num_procs: int = 8
     eval_text_only: bool = False
     num_prompts: int = 1
     # number of data loader workers
     num_workers: int = 8 if torch.cuda.is_available() else 1
     train_on_inputs: bool = False
1 change: 0 additions & 1 deletion ultravox/training/configs/meta_config.yaml
@@ -6,7 +6,6 @@ train_sets:
 val_sets:
   - name: gigaspeech
     weight: 0.01
-stop_strategy: "LAST_EXHAUSTED"

 train_on_inputs: False
 shuffle_data: True
7 changes: 1 addition & 6 deletions ultravox/training/train.py
@@ -36,7 +36,6 @@ def prepare_dataset(
     data_args: datasets.VoiceDatasetArgs,
     processor: ultravox_processing.UltravoxProcessor,
     train_on_inputs: bool,
-    stop_strategy: datasets.StopStrategy,
     num_samples: Optional[int] = None,
     include_alt_fields: bool = False,  # whether to generate tensors for text-only input (e.g., used for KD training)
 ) -> datasets.SizedIterableDataset:
@@ -50,9 +49,7 @@
         len(ds) > 1
     ), f"Dataset {ds} has length {len(ds)} which is too short for epoch training"

-    interleave = datasets.InterleaveDataset(
-        data_sets, data_weights, stop_strategy=stop_strategy
-    )
+    interleave = datasets.InterleaveDataset(data_sets, data_weights)
     ds_with_proc = data_processing.UltravoxDataproc(
         interleave,
         processor=processor,
@@ -207,7 +204,6 @@ def train(args: config_base.TrainConfig):
         train_args=args,
         data_opts=args.get_train_sets(),
         train_on_inputs=args.train_on_inputs,
-        stop_strategy=args.stop_strategy,
         processor=processor,
         num_samples=args.num_samples,
         data_args=datasets.VoiceDatasetArgs(
@@ -228,7 +224,6 @@
         train_args=args,
         data_opts=[val_opt],
         train_on_inputs=args.train_on_inputs,
-        stop_strategy=args.stop_strategy,
         processor=processor,
         num_samples=args.val_num_samples,
         data_args=val_ds_args,
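End to end, a dataset's weight now directly scales how many of its samples appear per pass, and len() is exact, which is why the trainer no longer needs a stop strategy. A minimal usage sketch with a toy stand-in modeled on the tests' FakeSizedIterableDataset (ToySet and the example datasets are illustrative, not from the PR):

from ultravox.data import datasets

class ToySet(datasets.SizedIterableDataset):
    """Illustrative stand-in mirroring FakeSizedIterableDataset in the tests."""

    def __init__(self, n, start=0):
        self._data = range(start, start + n)

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

# A weight of 2.0 repeats the first dataset twice per pass; 0.5 uses half of the second.
speech_ds, text_ds = ToySet(2, start=10), ToySet(4)
interleaved = datasets.InterleaveDataset([speech_ds, text_ds], weights=[2.0, 0.5])
assert len(interleaved) == 6  # int(2.0 * 2) + int(0.5 * 4)
assert list(interleaved) == [10, 0, 11, 10, 1, 11]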