From 03920805a5a3af3a695c3e6fc2b33478f4d88d75 Mon Sep 17 00:00:00 2001
From: dmahan93
Date: Tue, 25 Jun 2024 12:08:27 -0500
Subject: [PATCH] - Bugfixes from upstreaming....

---
 megatron/data/data_utils.py          | 15 +++++++++------
 megatron/neox_arguments/arguments.py | 18 ++++++++++++------
 megatron/training.py                 | 18 ++++++++----------
 3 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index 2c548077d..58e7953a1 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -100,7 +100,10 @@ def build_the_dataset(
         neg_label_dataset = make_indexed_dataset(
             neg_label_prefix, data_impl, skip_warmup
         )
-        if pos_ref_prefix is not None:
+        if pos_ref_prefix is None:
+            pos_ref_dataset = None
+            neg_ref_dataset = None
+        else:
             pos_ref_dataset = make_indexed_dataset(
                 pos_ref_prefix, data_impl, skip_warmup
             )
@@ -303,7 +306,7 @@ def build_weighted_datasets(
             neox_args.valid_label_data_paths
             if neox_args.valid_label_data_paths
             else [],
-            neox_args.test_data_paths if neox_args.pos_train_data_paths else [],
+            neox_args.test_data_paths if neox_args.test_data_paths else [],
             neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
             neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
             neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
@@ -331,7 +334,7 @@ def build_weighted_datasets(
             else [],
         )
     ):
-        if train_path:
+        if train_path or pos_train_path:
             train_datasets.append(
                 build_the_dataset(
                     data_prefix=train_path,
@@ -353,7 +356,7 @@ def build_weighted_datasets(
                 )
             )
 
-        if valid_path:
+        if valid_path or pos_valid_path:
             valid_datasets.append(
                 build_the_dataset(
                     data_prefix=valid_path,
@@ -375,7 +378,7 @@ def build_weighted_datasets(
                 )
             )
 
-        if test_path:
+        if test_path or pos_test_path:
             test_datasets.append(
                 build_the_dataset(
                     data_prefix=test_path,
@@ -465,7 +468,7 @@ def build_train_valid_test_data_iterators(neox_args):
         test_iters * neox_args.train_batch_size,
     ]
 
-    if neox_args.train_data_paths:
+    if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
         # when individual train / valid / test data paths are provided
         # normalize weight values and get num samples for each dataset
         train_weights, train_num_samples = get_normalized_weights_and_num_samples(
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index 770ec50b4..a89aa04a6 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -794,7 +794,9 @@ def calculate_batch_parameters(
 
         # either none of the three parameters are provided or just gradient_accumulation_step is provided
         else:
-            assert False, "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
+            assert (
+                False
+            ), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
         return int(train_batch), int(micro_batch), int(grad_acc)
 
     @staticmethod
@@ -1098,8 +1100,8 @@ def calculate_derived(self):
         if "flash" in self.attention_config:
             _flash_version = packaging.version.Version(version("flash-attn"))
             if self.sliding_window_width is not None:
-                assert (
-                    _flash_version >= packaging.version.Version("2.3.0")
+                assert _flash_version >= packaging.version.Version(
+                    "2.3.0"
                 ), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
             if self.pos_emb == "alibi":
                 if not _flash_version >= packaging.version.Version("2.4.0.post1"):
@@ -1110,15 +1112,19 @@ def calculate_derived(self):
         # Adding equal dataset weights if none are provided
         if self.train_data_paths and (self.train_data_weights is None):
             self.train_data_weights = [1.0] * len(self.train_data_paths)
+        elif self.pos_train_data_paths and (self.train_data_weights is None):
+            self.train_data_weights = [1.0] * len(self.pos_train_data_paths)
         if self.valid_data_paths and (self.valid_data_weights is None):
             self.valid_data_weights = [1.0] * len(self.valid_data_paths)
+        elif self.pos_valid_data_paths and (self.valid_data_weights is None):
+            self.valid_data_weights = [1.0] * len(self.pos_valid_data_paths)
         if self.test_data_paths and (self.test_data_weights is None):
             self.test_data_weights = [1.0] * len(self.test_data_paths)
+        elif self.pos_test_data_paths and (self.test_data_weights is None):
+            self.test_data_weights = [1.0] * len(self.pos_test_data_paths)
 
         if self.train_label_data_paths:
-            err_str = (
-                "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`"
-            )
+            err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`"
             assert self.train_data_paths and not self.data_path, err_str
 
         # if a sample input file is provided, default text_gen_type type to input-file
diff --git a/megatron/training.py b/megatron/training.py
index 3c6a6b506..b578c4ad9 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -285,12 +285,12 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
     label_key = keys[1] if len(keys) > 1 else None
     # Unpack.
     tokens_ = data_b[token_key].long()
-    if "label" in data_b:
+    if label_key in data_b:
         label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous()
         labels = torch.where(
             data_b[label_key].long() >= 0,
             data_b[label_key].long(),
-            torch.zeros_like(data_b["label"].long()),
+            torch.zeros_like(data_b[label_key].long()),
         )[:, 1:].contiguous()
     else:
         label_mask = (tokens_.long() >= 0)[:, 1:].contiguous()
@@ -319,7 +319,7 @@ def get_batch(neox_args, data_iterator):
     elif neox_args.train_impl == "dpo":
         keys = (
             [["pos", "pos_label"], ["neg", "neg_label"]]
-            if neox_args.pos_label_data_paths
+            if neox_args.pos_train_label_data_paths
             else [["pos"], ["neg"]]
         )
         datatype = torch.int64
@@ -329,7 +329,7 @@ def get_batch(neox_args, data_iterator):
         data = next(data_iterator)
     else:
         data = None
-    if neox_args.train_type == "normal":
+    if neox_args.train_impl == "normal":
         return _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -337,7 +337,7 @@ def get_batch(neox_args, data_iterator):
             data=data,
             datatype=datatype,
         )
-    elif neox_args.train_type == "dpo":
+    elif neox_args.train_impl == "dpo":
         pos_tup = _get_batch(
             neox_args=neox_args,
             tokenizer=neox_args.tokenizer,
@@ -516,7 +516,7 @@ def forward_step(
         else:
             moe_loss = 0.0
         loss = main_loss + moe_loss
-    elif neox_args.train_type == "dpo":
+    elif neox_args.train_impl == "dpo":
         # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
         with torch.no_grad():
             # So we can gather token logps...
@@ -853,7 +853,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
         )
 
     """Setup model and optimizer."""
-    needs_reference_model = neox_args.train_type == "dpo"
+    needs_reference_model = neox_args.train_impl == "dpo"
     model = get_model(neox_args=neox_args, use_cache=use_cache)
     if needs_reference_model:
         reference_model = get_model(neox_args=neox_args, use_cache=use_cache)
@@ -933,7 +933,6 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
         neox_args.iteration = load_checkpoint(
             neox_args=neox_args,
             model=model,
-            reference_model=reference_model,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,
             iteration=iteration,
@@ -1071,8 +1070,7 @@ def train_step(
             save_snapshot(neox_args)
     # reduces metrics across machines for logging
    reduce_metrics = {
-        key: reduce_losses([metric_dicts[key]]).mean()
-        for key in metric_dicts.keys()
+        key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys()
     }
     reduce_metrics["lm_loss"] = reduce_losses(losses).mean()
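
For context on the _get_batch hunk: the old code checked the hard-coded key "label" while indexing the batch with the per-key label_key, so the labelled branch never fired for the DPO "pos_label"/"neg_label" keys. A minimal stand-alone sketch of the masking behaviour that branch implements (illustrative names, plain PyTorch, not the patch code itself):

import torch

def build_labels_and_mask(tokens, labels=None):
    # Mirrors the patched branch: negative ids mark positions to ignore in the loss.
    source = labels if labels is not None else tokens
    label_mask = (source >= 0)[:, 1:].contiguous()
    shifted_labels = torch.where(
        source >= 0, source, torch.zeros_like(source)
    )[:, 1:].contiguous()
    return shifted_labels, label_mask

tokens = torch.tensor([[5, 6, 7, 8]])
labels = torch.tensor([[-100, 6, -100, 8]])
shifted_labels, mask = build_labels_and_mask(tokens, labels)
# shifted_labels == [[6, 0, 8]], mask == [[True, False, True]]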
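
The train_impl == "dpo" branches assume a frozen reference model whose token log-probs are gathered under torch.no_grad(); the hunks above only rename train_type to train_impl and wire that reference model through setup. The loss itself lives in the referenced trainers.py; as a rough sketch, the standard DPO objective it points to looks like the following (sequence-level log-prob sums assumed precomputed; beta is an illustrative hyperparameter name, not necessarily the NeoX argument):

import torch
import torch.nn.functional as F

def dpo_loss(policy_pos_logps, policy_neg_logps, ref_pos_logps, ref_neg_logps, beta=0.1):
    # Log-ratios of policy vs. frozen reference for chosen (pos) and rejected (neg) sequences.
    pi_logratios = policy_pos_logps - policy_neg_logps
    ref_logratios = ref_pos_logps - ref_neg_logps
    # Push the policy's preference margin above the reference model's margin.
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()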
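
The final train_step hunk removes an extra pair of brackets: each metric_dicts[key] is already a list of per-microbatch scalar losses, so wrapping it in another list hands reduce_losses a list-of-lists. Assuming reduce_losses concatenates scalar loss tensors before averaging across ranks (as its other call sites suggest), a single-process stand-in shows the difference:

import torch

def reduce_losses_local(losses):
    # Single-process stand-in: concatenate scalar losses; the real helper also all-reduces.
    return torch.cat([loss.clone().detach().view(1) for loss in losses])

metric_values = [torch.tensor(0.5), torch.tensor(0.7)]  # per-microbatch scalars for one key
print(reduce_losses_local(metric_values).mean())        # tensor(0.6000) -- the fixed call
# reduce_losses_local([metric_values]) would fail (a list has no .clone()),
# which is why the extra brackets were dropped.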