Commit defa0a4

Initial commit of replay

AIproj committed Apr 12, 2024
1 parent 01657aa commit defa0a4

Showing 6 changed files with 524 additions and 23 deletions.
109 changes: 93 additions & 16 deletions megatron/data/data_utils.py
@@ -61,6 +61,9 @@ def build_the_dataset(
skip_warmup,
build_index_mappings=True,
label_prefix=None,
index_mapping_paths=None,
index_offset=0,
reshuffle_when_loading=True,
):
"""Build train/valid/test datasets."""

@@ -85,6 +88,9 @@ def build_the_dataset(
seed,
build_index_mappings=build_index_mappings,
label_dataset=label_dataset,
index_mapping_paths=index_mapping_paths,
index_offset=index_offset,
reshuffle_when_loading=reshuffle_when_loading,
)
return dataset

@@ -191,6 +197,32 @@ def get_normalized_weights_and_num_samples(
weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
return weights, weighted_num_samples

def get_normalized_weights_and_num_samples_with_replay(
weights: List[float], replay_weights: List[float], replay_fraction, num_samples: int
) -> Tuple[List[float], List[int]]:
# Normalize weights. `weights` are the weights of the training data and `replay_weights` are the
# weights of the replay data. The idea is to merge the weights provided for the training data and
# the replay data into a single array. Since the replay weights should contribute replay_fraction
# of the total, we scale the normalized replay weights by replay_fraction and the rest by (1 - replay_fraction).
weight_sum = sum(weights)
assert weight_sum > 0.0
weights = [(weight / weight_sum) * (1-replay_fraction) for weight in weights]

replay_weights_sum = sum(replay_weights)
assert replay_weights_sum > 0.0
replay_weights = [(replay_weight / replay_weights_sum) * replay_fraction for replay_weight in replay_weights]

# merge weights with the replay weights given the replay_fraction
weights = weights + replay_weights

# Add 0.5% (the 1.005 factor) so in case the blending dataset does
# not uniformly distribute the number of samples, we still have
# samples left to feed to the network.
weighted_num_samples = []
for weight in weights:
weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005)))
return weights, weighted_num_samples
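
A quick numeric illustration of the merge performed by get_normalized_weights_and_num_samples_with_replay, using hypothetical weights and sample counts (not taken from any real config):

import math

weights, replay_weights, replay_fraction = [2.0, 1.0], [1.0, 1.0], 0.25
num_samples = 1000

train_part = [w / sum(weights) * (1 - replay_fraction) for w in weights]           # [0.5, 0.25]
replay_part = [w / sum(replay_weights) * replay_fraction for w in replay_weights]  # [0.125, 0.125]
merged = train_part + replay_part                                                  # sums to 1.0
per_dataset = [int(math.ceil(num_samples * w * 1.005)) for w in merged]            # [503, 252, 126, 126]
print(merged, per_dataset)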


def build_weighted_datasets(
neox_args,
@@ -201,7 +233,21 @@ def build_weighted_datasets(
valid_weights,
test_weights,
build_index_mappings=True,
concatenate_train_replay_paths=False,
):

# The concatenate_train_replay_paths bool is necessary to avoid issues when this function gets called a second time.
if neox_args.is_replay_enabled and concatenate_train_replay_paths:
# Merge the replay data paths into the train data paths, but keep track of
# which paths in train_data_paths came from replay.
num_replay_data_paths = len(neox_args.replay_data_paths)
num_non_replay_data_paths = len(neox_args.train_data_paths)
neox_args.train_data_paths += neox_args.replay_data_paths
else:
num_replay_data_paths = 0
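
# Illustrative example for the replay-enabled case (hypothetical paths): with
# train_data_paths = ["pile", "code"] and replay_data_paths = ["old_corpus"], the
# concatenation above yields train_data_paths = ["pile", "code", "old_corpus"],
# num_replay_data_paths = 1 and num_non_replay_data_paths = 2, so in the loop below
# indices 0-1 build train_0/train_1 and index 2 builds replay_0 via the replay branch.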

assert not (neox_args.label_data_paths and neox_args.is_replay_enabled), "Simultaneous use of label data and replay is untested.\
Remove this assert at your own risk. You might want to add a replay_label_data_paths arg too if relevant."
# build individual datasets
train_datasets, valid_datasets, test_datasets = [], [], []
for i, (train_path, label_path, valid_path, test_path) in enumerate(
@@ -213,19 +259,39 @@ def build_weighted_datasets(
)
):
if train_path:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"train_{i}",
data_impl=neox_args.data_impl,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
if i < len(neox_args.train_data_paths) - num_replay_data_paths:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"train_{i}",
data_impl=neox_args.data_impl,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=build_index_mappings,
label_prefix=label_path,
)
)
)

# When dealing with a replay dataset, we need to pass the relevant neox_args fields so that existing idx files are loaded instead of built.
else:
i_replay = i - (len(neox_args.train_data_paths) - num_replay_data_paths)
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
name=f"replay_{i_replay}",
data_impl=neox_args.data_impl,
num_samples=train_num_samples[i],
seq_length=neox_args.seq_length,
seed=neox_args.replay_seed,
skip_warmup=(not neox_args.mmap_warmup),
build_index_mappings=False,
index_mapping_paths=neox_args.replay_data_to_idx_paths[train_path],
index_offset=neox_args.replay_idx_offsets[i_replay],
reshuffle_when_loading=neox_args.replay_reshuffle_idx,
)
)

if valid_path:
valid_datasets.append(
@@ -326,9 +392,15 @@ def build_train_valid_test_data_iterators(neox_args):
if neox_args.train_data_paths:
# when individual train / valid / test data paths are provided
# normalize weight values and get num samples for each dataset
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
neox_args.train_data_weights, train_val_test_num_samples[0]
)
if neox_args.is_replay_enabled:
train_weights, train_num_samples = get_normalized_weights_and_num_samples_with_replay(
neox_args.train_data_weights, neox_args.replay_data_weights,
neox_args.replay_fraction, train_val_test_num_samples[0]
)
else:
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
neox_args.train_data_weights, train_val_test_num_samples[0]
)
valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
neox_args.valid_data_weights, train_val_test_num_samples[1]
)
@@ -346,10 +418,13 @@ def build_train_valid_test_data_iterators(neox_args):
valid_weights,
test_weights,
build_index_mappings=not neox_args.weight_by_num_documents,
concatenate_train_replay_paths=True,
)

if neox_args.weight_by_num_documents:

assert not neox_args.is_replay_enabled, "Replay not tested in the case of autoweighting, remove this assert at your own risk.\
Something might break if the concatenation of the train and replay paths happened twice due to the second call of\
build_weighted_datasets, which is why that call passes concatenate_train_replay_paths=False."
# gets the number of documents in each datapath
get_num_docs_list = lambda datasets: [
dataset.indexed_dataset.sizes.shape[0] for dataset in datasets
@@ -394,6 +469,7 @@ def build_train_valid_test_data_iterators(neox_args):
train_weights,
valid_weights,
test_weights,
concatenate_train_replay_paths=False,
)

if train_datasets:
@@ -403,6 +479,7 @@ def build_train_valid_test_data_iterators(neox_args):
if test_datasets:
test_ds = BlendableDataset(test_datasets, test_weights)
else:
assert not neox_args.is_replay_enabled, "Replay not implemented in the case of autosplitting into train/val/test datasets."
# when just data_path is provided
# split dataset into train, valid and test from data_path
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
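
For reference, a sketch of the replay-specific neox_args fields consumed by the code above; the field names are the ones read in this diff, while the values and paths below are purely hypothetical:

# Hypothetical values only; the field names match what the replay code path reads.
replay_config = {
    "is_replay_enabled": True,
    "replay_fraction": 0.25,                        # fraction of sampling mass given to replay data
    "replay_data_paths": ["/data/replay/pile_00"],  # replay datasets, appended to train_data_paths
    "replay_data_weights": [1.0],                   # normalized to sum to replay_fraction
    "replay_seed": 1234,                            # seed used when reshuffling loaded indices
    "replay_reshuffle_idx": True,                   # reshuffle the loaded shuffle_idx
    "replay_idx_offsets": [0],                      # per-dataset offset into the loaded shuffle_idx
    "replay_data_to_idx_paths": {                   # maps each replay data path to its saved idx files
        "/data/replay/pile_00": {
            "doc_idx_path": "/data/replay/pile_00_train_0_indexmap_5781ns_2048sl_1234s_doc_idx.npy",
            "sample_idx_path": "/data/replay/pile_00_train_0_indexmap_5781ns_2048sl_1234s_sample_idx.npy",
            "shuffle_idx_path": "/data/replay/pile_00_train_0_indexmap_5781ns_2048sl_1234s_shuffle_idx.npy",
        }
    },
}
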
164 changes: 158 additions & 6 deletions megatron/data/gpt2_dataset.py
@@ -39,6 +39,9 @@ def __init__(
build_index_mappings=True,
use_shared_fs=True,
label_dataset=None,
index_mapping_paths=None,
index_offset=0,
reshuffle_when_loading=True,
):

self.name = name
@@ -48,6 +51,8 @@ def __init__(
# Checks
assert np.min(documents) >= 0
assert np.max(documents) < indexed_dataset.sizes.shape[0]
if not build_index_mappings:
assert index_mapping_paths, "If not building index mappings, the path to existing ones must be provided."

if build_index_mappings:
# Build index mappings.
@@ -61,13 +66,32 @@ def __init__(
seed,
use_shared_fs=use_shared_fs,
)
self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
self.sample_idx_len = self.sample_idx.shape[0] - 1

else:
# If not building the index mappings, we need to load them
self.doc_idx, self.sample_idx, self.shuffle_idx = _load_index_mappings(
self.name,
data_prefix,
documents,
self.indexed_dataset.sizes,
num_samples,
seq_length,
seed,
use_shared_fs=use_shared_fs,
index_mapping_paths=index_mapping_paths,
index_offset=index_offset,
reshuffle_when_loading=reshuffle_when_loading,
)


self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
self.sample_idx_len = self.sample_idx.shape[0] - 1

if self.shuffle_idx_len != self.sample_idx_len:
print(
f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})"
)

if self.shuffle_idx_len != self.sample_idx_len - 1:
print(
f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})"
)

def __len__(self):
return min(self.shuffle_idx_len, self.sample_idx_len)
@@ -242,6 +266,134 @@ def _build_index_mappings(
return doc_idx, sample_idx, shuffle_idx


# Warning: only implemented with replay in mind; some issues may arise when dealing with more than one epoch over the dataset.
def _load_index_mappings(
name,
data_prefix,
documents,
sizes,
num_samples,
seq_length,
seed,
index_mapping_paths,
use_shared_fs=True,
index_offset=0,
reshuffle_when_loading=True,
):
"""Build doc-idx, sample-idx, and shuffle-idx from ones loaded.
doc-idx: is an array (ordered) of documents to be used in training.
sample-idx: is the start document index and document offset for each
training sample.
shuffle-idx: maps the sample index into a random index into sample-idx.
"""
# Number of tokens in each epoch and number of required epochs.
tokens_per_epoch = _num_tokens(documents, sizes)
num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples)
# rng state
np_rng = np.random.RandomState(seed=seed)
is_replay = name.split("_")[0] == "replay"


# Filename of the index mappings.
_filename = data_prefix
_filename += "_{}_indexmap".format(name)
_filename += "_{}ns".format(num_samples)
_filename += "_{}sl".format(seq_length)
_filename += "_{}s".format(seed)
doc_idx_filename = _filename + "_doc_idx.npy"
sample_idx_filename = _filename + "_sample_idx.npy"
shuffle_idx_filename = _filename + "_shuffle_idx.npy"


if not use_shared_fs:
should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0
else:
should_process_dataset = torch.distributed.get_rank() == 0

# Load and post-process the shuffle-idx mapping on a single rank.
if should_process_dataset:
start_time = time.time()
print_rank_0(" > loading shuffle-idx mapping from {}".format(index_mapping_paths["shuffle_idx_path"]))
shuffle_idx = np.load(index_mapping_paths["shuffle_idx_path"], allow_pickle=True, mmap_mode="r")
print_rank_0(
" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
)

idx_prefix = index_mapping_paths["shuffle_idx_path"].split("_")[:-2]

## Restrict to the samples seen during original pretraining if this is replay.
## Careful: this is hard-coded around idx filenames looking like dataset_train_4_indexmap_5781ns_2048sl_1234s_doc_idx.npy

# Remove the 1.005 buffer that was added when estimating the number of samples in
# get_normalized_weights_and_num_samples(). Note that this may miss a few of the seen samples.
num_samples_originally_seen = np.int64(np.int64(idx_prefix[-3][:-2]) / 1.005)
seq_length_originally_seen = int(idx_prefix[-2][:-2])
assert seq_length_originally_seen == seq_length, "Current seq len {} does not match seq len the indices were built for ({}).".format(
seq_length,
seq_length_originally_seen,
)
# Useful if we later want to support extending the idx when more than one epoch is necessary, but for now we assert a single epoch.
num_epochs_in_replay = _num_epochs(tokens_per_epoch, seq_length, num_samples_originally_seen)
assert num_epochs_in_replay == 1, "Enough samples from replay dataset {} for more than one epoch; this is currently\
untested.".format(data_prefix)
if is_replay:
shuffle_idx = shuffle_idx[:num_samples_originally_seen]
# get a sufficient length if needed

# Apply the offset, but add the removed elements back at the end.
# TODO: do we want to shuffle the shuffle_idx[:index_offset] term again?
index_offset = index_offset % len(shuffle_idx)
if reshuffle_when_loading:
# NumPy can throw errors if the array isn't writeable, which it apparently is not when loaded with mmap_mode="r"
if not shuffle_idx.flags.writeable:
try:
shuffle_idx.setflags(write=True)
except Exception:
# Copy trick if we couldn't set the flag to writeable.
print("Loaded shuffle_idx at {} is not writeable, copying it as a workaround...".format("_".join(idx_prefix)))
temp = shuffle_idx.copy()
shuffle_idx = temp
assert shuffle_idx.flags.writeable, "Failed to make shuffle_idx writeable, the shuffling will not work."
print("Successfully copied!")
# For some reason this is faster than shuffling the copied array directly. I'm sure there is a good reason for it.
temp_random_idx = np.array(range(len(shuffle_idx)))
np_rng.shuffle(temp_random_idx)
shuffle_idx = shuffle_idx[temp_random_idx]
del temp_random_idx

shuffle_idx = np.concatenate([shuffle_idx[index_offset:], shuffle_idx[:index_offset]])
np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
print_rank_0(
" > elapsed time to build and save shuffle-idx mapping"
" (seconds): {:4f}".format(time.time() - start_time)
)

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_io_parallel_group()
)

# Load mappings.
start_time = time.time()
print_rank_0(" > loading doc-idx mapping from {}".format(index_mapping_paths["doc_idx_path"]))
doc_idx = np.load(index_mapping_paths["doc_idx_path"], allow_pickle=True, mmap_mode="r")
print_rank_0(" > loading sample-idx mapping from {}".format(index_mapping_paths["sample_idx_path"]))
sample_idx = np.load(index_mapping_paths["sample_idx_path"], allow_pickle=True, mmap_mode="r")
print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename))
shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
print_rank_0(
" loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
)
print_rank_0(" total number of samples: {}".format(shuffle_idx.shape[0] + 1))
print_rank_0(" total number of epochs: {}".format(num_epochs))

return doc_idx, sample_idx, shuffle_idx
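
A self-contained sketch of the shuffle-idx handling above, using a hypothetical idx filename (matching the pattern noted in the comments) and a toy array in place of the loaded .npy file:

import numpy as np

# Hypothetical idx filename following the hard-coded pattern noted above.
shuffle_idx_path = "dataset_train_4_indexmap_5781ns_2048sl_1234s_shuffle_idx.npy"

# Recover the number of samples originally seen: drop "_shuffle_idx.npy",
# take the "...ns" field and undo the 1.005 over-allocation buffer.
idx_prefix = shuffle_idx_path.split("_")[:-2]                                   # ends with "5781ns", "2048sl", "1234s"
num_samples_originally_seen = np.int64(np.int64(idx_prefix[-3][:-2]) / 1.005)   # 5752

# Toy shuffle index standing in for the loaded .npy array, with a small offset.
shuffle_idx = np.arange(8)
index_offset = 3 % len(shuffle_idx)

shuffle_idx = shuffle_idx[:num_samples_originally_seen]  # keep only samples seen in pretraining (a no-op for this toy array)
np_rng = np.random.RandomState(seed=1234)
perm = np.arange(len(shuffle_idx))
np_rng.shuffle(perm)
shuffle_idx = shuffle_idx[perm]                          # the reshuffle_when_loading=True path
shuffle_idx = np.concatenate(                            # rotate so training resumes at index_offset
    [shuffle_idx[index_offset:], shuffle_idx[:index_offset]]
)
print(num_samples_originally_seen, shuffle_idx)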


def _num_tokens(documents, sizes):
"""Total number of tokens in the dataset."""
return np.sum(sizes[documents])