Adding replay into GPT-NeoX #1200

Open · wants to merge 8 commits into main
118 changes: 117 additions & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

- **git_hash**: str

Default = 11a5537
Default = 621ab25

current git hash of repository

@@ -1378,6 +1378,122 @@ Training Arguments



- **replay_config**: dict

Default = None

Dictionary storing the replay config.



- **is_replay_enabled**: bool

Default = False

Enables the replay logic. Replay has to be handled separately from the general "train_data_paths" logic, as replay
requires reusing the same idx files to know what data was seen when the dataset was originally trained on.
If one instead attempts to do replay by simply putting the datasets to be replayed in train_data_paths rather than in the replay params:
- If the exact same dataset files are used as when the dataset was first seen, and the number of iterations on the replay buffer
  corresponds to as many epochs on the replay dataset as in the original training, the data will be seen in exactly the same order as
  the first time, provided the seed and sequence length are the same.
- For similar reasons, replaying multiple times on the same dataset (e.g. across multiple tasks) with the same number of epochs
  on the replay dataset will lead to seeing the same data in the same order.
- If a different dataset is used for replay (e.g. a different shard of the Pile), the shuffling will produce completely different
  indices, so a potentially significant proportion of the data may never have been seen if the original training on the replay dataset
  did not cover all of it, e.g. when training on 300B tokens of the GPT2-tokenised Pile (which contains a few dozen billion more
  tokens) and then sharding the full dataset into smaller ones.



- **replay_idx_paths_prefixes**: list

Default = None

List of path prefixes used to retrieve the replay dataset idx files. Those idx files should have been generated when originally training
on the dataset now being used for replay; their filenames encode the number of samples potentially seen during pretraining, the sequence
length and the seed. The exact files (shuffle_idx, sample_idx and doc_idx) are derived automatically from the prefix, and so are the data paths.
The *_idx files matter because they record what data in the dataset was seen during training. If those files are missing, you can
regenerate them by relaunching the same training script (most importantly, the same config) originally used to pretrain on the given dataset.
You can add an exit(0) statement in training.py in pretrain() after the call to build_train_valid_test_data_iterators(neox_args=neox_args).
It is crucial to use the same dataset shard, sequence length, number of iterations, seed, and potentially batch size, or the indices
generated may not be the same.
For a single replay data source, the value passed looks like ["data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"] and
the files at the following paths (constructed during execution from the prefix) must exist:
"data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_doc_idx.npy"
"data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_sample_idx.npy"
"data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s_shuffle_idx.npy"
"data/mydataset/train/mydataset"

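As a quick pre-launch sanity check, a minimal sketch (using the example prefix above; illustrative, not code from this PR) that verifies the three idx files exist:

```python
import os

prefix = "data/mydataset/train/mydataset_train_4_indexmap_456789ns_2048sl_1234s"

# The three index files that must already exist for this prefix.
for suffix in ("_doc_idx.npy", "_sample_idx.npy", "_shuffle_idx.npy"):
    path = prefix + suffix
    assert os.path.exists(path), f"missing replay index file: {path}"
```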


- **replay_data_to_idx_paths**: dict

Default = None

As indicated above, this is built automatically from replay_idx_paths_prefixes by appending "_doc_idx.npy", "_sample_idx.npy" and
"_shuffle_idx.npy" to each prefix. The result is a dict of dicts with the data paths as keys, each mapping to the relevant
doc_idx, sample_idx and shuffle_idx file paths. Note that these files must exist at the relevant paths.



- **replay_data_paths**: list

Default = None

As indicated above, this is built automatically from replay_idx_paths_prefixes by stripping the idx-file information from each prefix,
retaining only the path to the dataset itself.
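
Putting the two derivations together, a rough sketch of how these values could be computed from replay_idx_paths_prefixes; the regex and the dict keys are assumptions based on the description above, not the exact code in this PR:

```python
import re

def derive_replay_paths(replay_idx_paths_prefixes):
    """Approximate the automatic derivation described above (illustrative only)."""
    replay_data_to_idx_paths = {}
    replay_data_paths = []
    for prefix in replay_idx_paths_prefixes:
        # Strip the "_train_<N>_indexmap_<ns>ns_<sl>sl_<s>s" suffix to recover the dataset path.
        data_path = re.sub(r"_train_\d+_indexmap_\d+ns_\d+sl_\d+s$", "", prefix)
        replay_data_to_idx_paths[data_path] = {
            "doc_idx": prefix + "_doc_idx.npy",
            "sample_idx": prefix + "_sample_idx.npy",
            "shuffle_idx": prefix + "_shuffle_idx.npy",
        }
        replay_data_paths.append(data_path)
    return replay_data_to_idx_paths, replay_data_paths
```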



- **replay_data_weights**: list

Default = None

List of 'weights' that decide how often to sample from each replay dataset when building the replay buffer.



- **replay_idx_offsets**: list

Default = None

List of indices that decide where to start in the list of indices seen during pretraining on each replay dataset when building
the replay buffer. For example, when the original training on a dataset saw 10000 samples, this allows starting from the
RESHUFFLED index replay_idx_offsets[i] for replay dataset i.
If not set, the starting indices will be sampled uniformly for each replay dataset.



- **replay_fraction**: float

Default = 0.05

Fraction of a batch dedicated to doing replay. For example, 0.1 means that in a batch of 100, 19 samples will come from the replay
buffer. Note that this means that if we train on 100B tokens, we will have only used 90B tokens from the datasets specified in
train_data_paths.

Review comment (Member): "For example, 0.1 means that in a batch of 100, 19 samples will come from the replay buffer." Is this a typo? Why wouldn't it be 10 samples?



- **replay_reshuffle_idx**: bool

Default = True

If set to False, the index files loaded from the original pretraining run on a dataset will reproduce the exact same sequence of samples
that was seen when first training on that dataset. If True, the indices are reshuffled to prevent that.



- **replay_seed**: int

Default = 1234

Seed used to reshuffle the indices that were accessed when originally training on a dataset and that are reused for replay. This is
useful when replay is performed on a dataset more than once with the same number of passes: if the same seed were used each time, the
replay buffers would be exactly the same in each case.

Review comment (@StellaAthena, Member, Apr 15, 2024): It seems important that the replay seed isn't the same as the general data seed from your other comments. If that's correct, let's use a different default.



- **weight_by_num_documents**: bool

Default = False
109 changes: 93 additions & 16 deletions megatron/data/data_utils.py
@@ -61,6 +61,9 @@ def build_the_dataset(
skip_warmup,
build_index_mappings=True,
label_prefix=None,
index_mapping_paths=None,
index_offset=0,
reshuffle_when_loading=True,
):
"""Build train/valid/test datasets."""

@@ -85,6 +88,9 @@ def build_the_dataset(
seed,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
index_mapping_paths=index_mapping_paths,
index_offset=index_offset,
reshuffle_when_loading=reshuffle_when_loading,
)
return dataset

@@ -191,6 +197,32 @@ def get_normalized_weights_and_num_samples(
weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
return weights, weighted_num_samples

def get_normalized_weights_and_num_samples_with_replay(
weights: List[float], replay_weights: List[float], replay_fraction, num_samples: int
) -> Tuple[List[float], List[int]]:
# Normalize weights. weights correspond to the weights of the training data and replay_weights correspond
# to the weights of the replay data. The idea is that we will merge the weights provided for training data
# and replay data into the same array. We know that replay_weights should contribute replay_fraction of all
# weights, so we also need to normalise the replay weights by replay_fraction and the rest by (1-replay_fraction).
weight_sum = sum(weights)
assert weight_sum > 0.0
weights = [(weight / weight_sum) * (1-replay_fraction) for weight in weights]

replay_weights_sum = sum(replay_weights)
assert replay_weights_sum > 0.0
replay_weights = [(replay_weight / replay_weights_sum) * replay_fraction for replay_weight in replay_weights]

# merge weights with the replay weights given the replay_fraction
weights = weights + replay_weights

# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
weighted_num_samples = []
for weight in weights:
weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
return weights, weighted_num_samples
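
# Illustrative usage (a sketch): two train datasets weighted 3:1 plus one replay dataset,
# with replay_fraction=0.1 and num_samples=1000:
#   weights, nums = get_normalized_weights_and_num_samples_with_replay(
#       [3.0, 1.0], [1.0], 0.1, 1000
#   )
#   # weights -> approximately [0.675, 0.225, 0.1]
#   # nums    -> [679, 227, 101]  (each inflated by the 1.005 factor and rounded up)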


def build_weighted_datasets(
neox_args,
@@ -201,7 +233,21 @@ def build_weighted_datasets(
valid_weights,
test_weights,
build_index_mappings=True,
concatenate_train_replay_paths=False,
):

# The concatenate_train_replay_paths bool is necessary to avoid issues when this function gets called a second time.
if neox_args.is_replay_enabled and concatenate_train_replay_paths:
# Merge replay data paths into train data paths logic, but need to keep track of
# what paths in train_data_paths came from replay
num_replay_data_paths = len(neox_args.replay_data_paths)
num_non_replay_data_paths = len(neox_args.train_data_paths)
neox_args.train_data_paths += neox_args.replay_data_paths
else:
num_replay_data_paths = 0
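# Illustrative example (sketch): with train_data_paths = ["A", "B"] and replay_data_paths = ["R"],
# train_data_paths becomes ["A", "B", "R"] and num_replay_data_paths = 1, so in the dataset loop
# below the entry at i = 2 is handled as replay dataset i_replay = 0.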

assert not (neox_args.label_data_paths and neox_args.is_replay_enabled), "Simultaneous use of label data and replay is untested.\
Remove assert at your own risk. You might want to add a replay_label_data_paths arg too if relevant."
# build individual datasets
train_datasets, valid_datasets, test_datasets = [], [], []
for i, (train_path, label_path, valid_path, test_path) in enumerate(
@@ -213,19 +259,39 @@
)
):
if train_path:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"train_{i}",
data_impl=neox_args.data_impl,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
if i < len(neox_args.train_data_paths) - num_replay_data_paths:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"train_{i}",
data_impl=neox_args.data_impl,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
)
)
)

# When dealing with a replay dataset, the replay idx file paths are passed so they are loaded instead of built.
else:
i_replay = i - (len(neox_args.train_data_paths) - num_replay_data_paths)
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"replay_{i_replay}",
data_impl=neox_args.data_impl,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.replay_seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=False,
index_mapping_paths=neox_args.replay_data_to_idx_paths[train_path],
index_offset=neox_args.replay_idx_offsets[i_replay],
reshuffle_when_loading=neox_args.replay_reshuffle_idx,
)
)

if valid_path:
valid_datasets.append(
@@ -326,9 +392,15 @@ def build_train_valid_test_data_iterators(neox_args):
if neox_args.train_data_paths:
# when individual train / valid / test data paths are provided
# normalize weight values and get num samples for each dataset
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
neox_args.train_data_weights, train_val_test_num_samples[0]
)
if neox_args.is_replay_enabled:
train_weights, train_num_samples = get_normalized_weights_and_num_samples_with_replay(
neox_args.train_data_weights, neox_args.replay_data_weights,
neox_args.replay_fraction, train_val_test_num_samples[0]
)
else:
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
neox_args.train_data_weights, train_val_test_num_samples[0]
)
valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
neox_args.valid_data_weights, train_val_test_num_samples[1]
)
@@ -346,10 +418,13 @@
valid_weights,
test_weights,
build_index_mappings=not neox_args.weight_by_num_documents,
concatenate_train_replay_paths=True,
)

if neox_args.weight_by_num_documents:

assert not neox_args.is_replay_enabled, "Replay not tested in the case of autoweighting, remove assert at your own risk.\
Something might break with the concatenation of the train and replay paths happening twice due to a second call\
of build_weighted_datasets, which is why that call below uses concatenate_train_replay_paths=False."
# gets the number of documents in each datapath
get_num_docs_list = lambda datasets: [
dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
@@ -394,6 +469,7 @@
train_weights,
valid_weights,
test_weights,
concatenate_train_replay_paths=False,
)

if train_datasets:
@@ -403,6 +479,7 @@
if test_datasets:
test_ds = BlendableDataset(test_datasets, test_weights)
else:
assert not neox_args.is_replay_enabled, "Replay not implemented in the case of autosplitting into train/val/test datasets."
# when just data_path is provided
# split dataset into train, valid and test from data_path
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(