Commit 0392080
- Bugfixes from upstreaming....
dmahan93 committed Jun 25, 2024
1 parent eed3643 commit 0392080
Showing 3 changed files with 29 additions and 22 deletions.
15 changes: 9 additions & 6 deletions megatron/data/data_utils.py
@@ -100,7 +100,10 @@ def build_the_dataset(
neg_label_dataset = make_indexed_dataset(
neg_label_prefix, data_impl, skip_warmup
)
- if pos_ref_prefix is not None:
+ if pos_ref_prefix is None:
+     pos_ref_dataset = None
+     neg_ref_dataset = None
+ else:
pos_ref_dataset = make_indexed_dataset(
pos_ref_prefix, data_impl, skip_warmup
)
@@ -303,7 +306,7 @@ def build_weighted_datasets(
neox_args.valid_label_data_paths
if neox_args.valid_label_data_paths
else [],
- neox_args.test_data_paths if neox_args.pos_train_data_paths else [],
+ neox_args.test_data_paths if neox_args.test_data_paths else [],
neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [],
neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [],
@@ -331,7 +334,7 @@ def build_weighted_datasets(
else [],
)
):
- if train_path:
+ if train_path or pos_train_path:
train_datasets.append(
build_the_dataset(
data_prefix=train_path,
@@ -353,7 +356,7 @@
)
)

- if valid_path:
+ if valid_path or pos_valid_path:
valid_datasets.append(
build_the_dataset(
data_prefix=valid_path,
@@ -375,7 +378,7 @@
)
)

- if test_path:
+ if test_path or pos_test_path:
test_datasets.append(
build_the_dataset(
data_prefix=test_path,
@@ -465,7 +468,7 @@ def build_train_valid_test_data_iterators(neox_args):
test_iters * neox_args.train_batch_size,
]

- if neox_args.train_data_paths:
+ if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths):
# when individual train / valid / test data paths are provided
# normalize weight values and get num samples for each dataset
train_weights, train_num_samples = get_normalized_weights_and_num_samples(
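A minimal, self-contained sketch (not the NeoX code; the loop body and the zip padding are assumptions) of why the `or pos_*_path` fallbacks matter: when only DPO-style positive/negative paths are configured, the plain `train_path` slot is empty, so the old `if train_path:` guard would skip dataset construction entirely.

from itertools import zip_longest

# Hypothetical configuration: DPO paths only, no plain train_data_paths.
train_paths = []
pos_train_paths = ["data/chat_pos_document"]

for train_path, pos_train_path in zip_longest(train_paths, pos_train_paths):
    if train_path or pos_train_path:  # new guard still fires for DPO-only runs
        print("building dataset for", train_path or pos_train_path)
    # the old `if train_path:` guard would have built nothing here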
18 changes: 12 additions & 6 deletions megatron/neox_arguments/arguments.py
@@ -794,7 +794,9 @@ def calculate_batch_parameters(

# either none of the three parameters are provided or just gradient_accumulation_step is provided
else:
assert False, "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
assert (
False
), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
return int(train_batch), int(micro_batch), int(grad_acc)

@staticmethod
@@ -1098,8 +1100,8 @@ def calculate_derived(self):
if "flash" in self.attention_config:
_flash_version = packaging.version.Version(version("flash-attn"))
if self.sliding_window_width is not None:
- assert (
-     _flash_version >= packaging.version.Version("2.3.0")
+ assert _flash_version >= packaging.version.Version(
+     "2.3.0"
), f"Flash-Attention version ({str(_flash_version)}) must be >= 2.3.0 to support sliding window attention."
if self.pos_emb == "alibi":
if not _flash_version >= packaging.version.Version("2.4.0.post1"):
@@ -1110,15 +1112,19 @@ def calculate_derived(self):
# Adding equal dataset weights if none are provided
if self.train_data_paths and (self.train_data_weights is None):
self.train_data_weights = [1.0] * len(self.train_data_paths)
+ elif self.pos_train_data_paths and (self.train_data_weights is None):
+     self.train_data_weights = [1.0] * len(self.pos_train_data_paths)
if self.valid_data_paths and (self.valid_data_weights is None):
self.valid_data_weights = [1.0] * len(self.valid_data_paths)
+ elif self.pos_valid_data_paths and (self.valid_data_weights is None):
+     self.valid_data_weights = [1.0] * len(self.pos_valid_data_paths)
if self.test_data_paths and (self.test_data_weights is None):
self.test_data_weights = [1.0] * len(self.test_data_paths)
+ elif self.pos_test_data_paths and (self.test_data_weights is None):
+     self.test_data_weights = [1.0] * len(self.pos_test_data_paths)

if self.train_label_data_paths:
- err_str = (
-     "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`"
- )
+ err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`"
assert self.train_data_paths and not self.data_path, err_str

# if a sample input file is provided, default text_gen_type type to input-file
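A small sketch of the equal-weight fallback added in calculate_derived (the paths and values here are made up): with only pos_* datasets configured and no explicit weights, every dataset gets weight 1.0, which a later normalization step turns into an equal share.

# Hypothetical DPO-only setup: three positive-preference datasets, no weights.
pos_train_data_paths = ["data/helpful_pos", "data/harmless_pos", "data/code_pos"]
train_data_weights = None

if train_data_weights is None:
    train_data_weights = [1.0] * len(pos_train_data_paths)

total = sum(train_data_weights)
print([w / total for w in train_data_weights])  # [0.333..., 0.333..., 0.333...]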
18 changes: 8 additions & 10 deletions megatron/training.py
@@ -285,12 +285,12 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
label_key = keys[1] if len(keys) > 1 else None
# Unpack.
tokens_ = data_b[token_key].long()
if "label" in data_b:
if label_key in data_b:
label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous()
labels = torch.where(
data_b[label_key].long() >= 0,
data_b[label_key].long(),
torch.zeros_like(data_b["label"].long()),
torch.zeros_like(data_b[label_key].long()),
)[:, 1:].contiguous()
else:
label_mask = (tokens_.long() >= 0)[:, 1:].contiguous()
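A toy illustration of the key fix above (shapes and values are arbitrary): with DPO-style keys the broadcast batch holds "pos"/"pos_label" entries, so the hard-coded check for a literal "label" key never matched and the code fell through to building the mask from the tokens.

import torch

keys = ["pos", "pos_label"]
token_key = keys[0]
label_key = keys[1] if len(keys) > 1 else None

data_b = {
    "pos": torch.randint(0, 100, (2, 8)),
    "pos_label": torch.randint(-1, 100, (2, 8)),
}

print("label" in data_b)    # False: the old check silently skipped the labels
print(label_key in data_b)  # True: the fixed check uses the actual key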
@@ -319,7 +319,7 @@ def get_batch(neox_args, data_iterator):
elif neox_args.train_impl == "dpo":
keys = (
[["pos", "pos_label"], ["neg", "neg_label"]]
- if neox_args.pos_label_data_paths
+ if neox_args.pos_train_label_data_paths
else [["pos"], ["neg"]]
)
datatype = torch.int64
@@ -329,15 +329,15 @@
data = next(data_iterator)
else:
data = None
if neox_args.train_type == "normal":
if neox_args.train_impl == "normal":
return _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
keys=keys,
data=data,
datatype=datatype,
)
elif neox_args.train_type == "dpo":
elif neox_args.train_impl == "dpo":
pos_tup = _get_batch(
neox_args=neox_args,
tokenizer=neox_args.tokenizer,
@@ -516,7 +516,7 @@ def forward_step(
else:
moe_loss = 0.0
loss = main_loss + moe_loss
elif neox_args.train_type == "dpo":
elif neox_args.train_impl == "dpo":
# Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90
with torch.no_grad():
# So we can gather token logps...
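For context, a minimal sketch of the DPO objective from the repository referenced in the comment above (this is not the NeoX forward_step code; the tensor names and the beta default are assumptions): the loss rewards a larger policy-vs-reference log-probability margin on the chosen ("pos") sequence than on the rejected ("neg") one.

import torch
import torch.nn.functional as F

def dpo_loss(pos_logps, neg_logps, ref_pos_logps, ref_neg_logps, beta=0.1):
    # per-sequence summed token log-probs under the policy and a frozen reference
    margin = (pos_logps - ref_pos_logps) - (neg_logps - ref_neg_logps)
    return -F.logsigmoid(beta * margin).mean()

# toy batch of two preference pairs
loss = dpo_loss(torch.tensor([-5.0, -6.0]), torch.tensor([-7.0, -6.5]),
                torch.tensor([-5.5, -6.2]), torch.tensor([-6.8, -6.4]))
print(loss.item())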
@@ -853,7 +853,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
)

"""Setup model and optimizer."""
- needs_reference_model = neox_args.train_type == "dpo"
+ needs_reference_model = neox_args.train_impl == "dpo"
model = get_model(neox_args=neox_args, use_cache=use_cache)
if needs_reference_model:
reference_model = get_model(neox_args=neox_args, use_cache=use_cache)
@@ -933,7 +933,6 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
neox_args.iteration = load_checkpoint(
neox_args=neox_args,
model=model,
- reference_model=reference_model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
iteration=iteration,
@@ -1071,8 +1070,7 @@ def train_step(
save_snapshot(neox_args)
# reduces metrics across machines for logging
reduce_metrics = {
- key: reduce_losses([metric_dicts[key]]).mean()
- for key in metric_dicts.keys()
+ key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys()
}
reduce_metrics["lm_loss"] = reduce_losses(losses).mean()

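A single-process stand-in for the metric reduction above (the real reduce_losses helper also all-reduces across data-parallel ranks, so this is only an approximation): it takes a flat list of scalar loss tensors, which is why the extra list wrapping around metric_dicts[key] was dropped.

import torch

def reduce_losses(losses):
    # stand-in only: flatten a list of scalar tensors; no distributed all-reduce here
    return torch.stack([l.detach().float().reshape(()) for l in losses])

metric_dicts = {"pos_loss": [torch.tensor(0.7), torch.tensor(0.5)]}
reduce_metrics = {
    key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys()
}
print(reduce_metrics)  # {'pos_loss': tensor(0.6000)}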
