
Commit

add precompute logprobs...
dmahan93 committed Jun 25, 2024
1 parent 0392080 commit 361f459
Showing 4 changed files with 207 additions and 13 deletions.
3 changes: 3 additions & 0 deletions generate.py
@@ -23,6 +23,7 @@
generate_samples_from_prompt,
generate_samples_unconditional,
generate_samples_interactive,
precompute_logits,
)


@@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None):
top_p=neox_args.top_p,
)

elif neox_args.text_gen_type == "precompute":
precompute_logits(neox_args=neox_args, model=model)
else:
raise ValueError(
f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}"
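For orientation, a minimal sketch (not part of the commit) of how the new branch could be reached programmatically. The `overwrite_values` parameter is visible in `main`'s signature above, but the value passed here is purely illustrative; in a normal run `text_gen_type` would come from the YAML config handed to the launcher.

# Hypothetical driver; assumes the rest of the NeoX configuration (checkpoint
# to load, data paths, parallelism settings) is supplied the usual way.
from generate import main

main(overwrite_values={"text_gen_type": "precompute"})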
7 changes: 6 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -1281,7 +1281,12 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
text_gen_type: str = None
"""
How to generate text/sample the model.
Options: `unconditional`, `input-file`, `interactive`
Options: `unconditional`, `input-file`, `interactive`, `precompute`
"""

precompute_model_name: str = None
"""
Model name to use for saving precomputed logprobs
"""

temperature: float = 0.0
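To make the new argument concrete: `precompute_model_name` is only used as a filename suffix by `precompute_logits` (shown further down), with a hash of the checkpoint path as the fallback. A small sketch of that naming logic, with made-up paths:

# Mirrors the naming in precompute_logits; all values here are hypothetical.
load_dir = "checkpoints/sft-run-1"
precompute_model_name = "sft-run-1"   # or None to fall back to str(hash(load_dir))
data_path = "data/train_document"

mdl_name = precompute_model_name if precompute_model_name is not None else str(hash(load_dir))
out_path = f"{data_path}_{mdl_name}"
print(out_path)  # data/train_document_sft-run-1 -> written out as .bin / .idx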
192 changes: 189 additions & 3 deletions megatron/text_generation_utils.py
@@ -23,12 +23,15 @@
import time
from typing import List, Union

import numpy as np
import torch
import torch.nn.functional as F

from megatron import print_rank_0
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, is_mp_rank_0
from megatron.data.indexed_dataset import make_builder, make_dataset
from megatron.mpu.mappings import gather_from_model_parallel_region


def get_batch(neox_args, context_tokens: torch.Tensor):
@@ -52,25 +55,31 @@ def get_batch(neox_args, context_tokens: torch.Tensor):
return tokens, attention_mask, position_ids


def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int):
def pad_batch(
context_tokens: List[List[int]], pad_id: int, pad_len: int, truncate: bool = False
):
"""
pads context lengths in context_tokens with pad_id to equal neox_args.seq_length,
and returns the padded batch and the new lengths.
context_tokens: list of lists of tokens
pad_id: int, integer to use as padding token
pad_len: int, context length to be padded; all batch items will be padded to the same length
truncate: bool, if True, truncate context tokens to pad_len if they are longer than pad_len
returns: tuple of padded context tokens and a list of unpadded token count
"""

context_lengths = []
for tokens in context_tokens:
for i, tokens in enumerate(context_tokens):
context_length = len(tokens)
if context_length < pad_len:
tokens.extend([pad_id] * (pad_len - context_length))
elif context_length > pad_len:
raise ValueError("context_length is bigger than to be padded length")
if not truncate:
raise ValueError("context_length is bigger than to be padded length")
context_tokens[i] = tokens[:pad_len]
context_length = pad_len
context_lengths.append(context_length)
return context_tokens, context_lengths
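The docstring above describes the new `truncate` flag; a toy call (values invented for illustration) makes the behaviour concrete:

# Short rows are padded with pad_id; over-long rows are now clipped instead of raising.
batch = [[1, 2], [3, 4, 5, 6]]
padded, lengths = pad_batch(batch, pad_id=0, pad_len=3, truncate=True)
# padded  -> [[1, 2, 0], [3, 4, 5]]
# lengths -> [2, 3]  (original length for the padded row, pad_len for the truncated one)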

@@ -807,3 +816,180 @@ def generate_samples_interactive(
print_rank_0("Generated Text: " + generated_text)
if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
_ = input("\n<press enter to continue>")


def get_logp(logits, labels, force_fp32=False):
if force_fp32:
logits = logits.float()
logp = logits.log_softmax(dim=-1)
return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2)
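`get_logp` is the core of the precompute path: it converts `[batch, seq, vocab]` logits into the log-probability assigned to each label token. A standalone toy check of the same log_softmax-plus-gather computation, with arbitrary shapes:

import torch

logits = torch.randn(2, 8, 50304)         # [batch, seq, vocab]
labels = torch.randint(0, 50304, (2, 8))  # [batch, seq]

logp = logits.float().log_softmax(dim=-1)
per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2)
print(per_token_logp.shape)  # torch.Size([2, 8])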


def precompute_logits(neox_args, model):
"""
Precomputes logprobs from training/testing/validation datasets
Saves it to the same directory as the dataset with the model name appended to it
neox_args: NeoXArgs.
model: a Megatron model
"""
if neox_args.precompute_model_name is None:
mdl_name = str(hash(neox_args.load))
else:
mdl_name = neox_args.precompute_model_name
print_rank_0("Precomputing logprobs...")
model.eval()
data_paths = list()
if neox_args.train_data_paths is not None:
for path in neox_args.train_data_paths:
data_paths.append(path)
for path in neox_args.test_data_paths:
data_paths.append(path)
for path in neox_args.valid_data_paths:
data_paths.append(path)
elif neox_args.pos_train_data_paths is not None:
# Pairwise data...
for path in neox_args.pos_train_data_paths:
data_paths.append(path)
for path in neox_args.neg_train_data_paths:
data_paths.append(path)
for path in neox_args.pos_valid_data_paths:
data_paths.append(path)
for path in neox_args.neg_valid_data_paths:
data_paths.append(path)
for path in neox_args.pos_test_data_paths:
data_paths.append(path)
for path in neox_args.neg_test_data_paths:
data_paths.append(path)
for path in data_paths:
print_rank_0(f"Precomputing logits for {path}")
# Add hash to path...
out_path = path + f"_{mdl_name}"
if os.path.exists(out_path + ".idx"):
continue
dataset = make_dataset(path, neox_args.data_impl, not neox_args.mmap_warmup)
if is_mp_rank_0():
out_dataset = make_builder(out_path + ".bin", neox_args.data_impl)
out_dataset._dtype = np.float32
i = 0
while i < len(dataset):
start = time.time()
model.module.clear_cache() # clear kv cache between batches
if is_mp_rank_0():
offset = (
mpu.get_data_parallel_rank()
* neox_args.train_micro_batch_size_per_gpu
)
context_tokens = [
[int(x) for x in dataset.get(j % len(dataset)).tolist()]
for j in range(
i + offset,
i + (neox_args.train_micro_batch_size_per_gpu + offset),
)
]
# grab microbatch
# pad batch in order to allow conversion to tensor
context_tokens, context_lengths = pad_batch(
copy.deepcopy(context_tokens),
pad_id=0,
pad_len=neox_args.seq_length + 1,
truncate=True,
)
# print(context_tokens)
label_tokens = [tokens[1:] for tokens in context_tokens]
context_tokens = [tokens[:-1] for tokens in context_tokens]
else:
context_tokens = [
[0 for _ in range(neox_args.seq_length)]
for _ in range(neox_args.batch_size)
]
label_tokens = [
[0 for _ in range(neox_args.seq_length)]
for _ in range(neox_args.batch_size)
]
context_lengths = [0 for _ in range(neox_args.batch_size)]
i += (
neox_args.train_micro_batch_size_per_gpu
* mpu.get_data_parallel_world_size()
)
# print(context_tokens)
# convert to tensor and broadcast
context_tokens = torch.cuda.LongTensor(context_tokens)
label_tokens = torch.cuda.LongTensor(label_tokens)
# Make sure context tokens + start tokens are the same across all ranks
token_generation_start_index = torch.cuda.LongTensor(context_lengths)
torch.distributed.broadcast(
context_tokens,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group(),
)
torch.distributed.broadcast(
token_generation_start_index,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group(),
)
torch.distributed.broadcast(
label_tokens,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group(),
)
# context_tokens = context_tokens[:, :chop_len].contiguous()
# label_tokens = label_tokens[:, :chop_len].contiguous()
with torch.no_grad():
# get attention mask / position ids
context_tokens, attention_mask, position_ids = get_batch(
neox_args, context_tokens
)
model_inputs = (
context_tokens,
position_ids,
attention_mask,
)
maybe_tuple = forward_model(
model, model_inputs, neox_args.is_pipe_parallel
)
if isinstance(maybe_tuple, tuple):
logits, _ = maybe_tuple
else:
logits = maybe_tuple
if logits is not None: # if pipe parallel, not all ranks return logits
logits = gather_from_model_parallel_region(logits)
logp = get_logp(logits, label_tokens, True).squeeze()
if neox_args.is_pipe_parallel:
# broadcast generated tokens to pipe parallel group
src_rank = model.grid.stage_to_global(model.num_stages - 1)
logp = (
logp
if logits is not None
else torch.zeros(
neox_args.batch_size, dtype=torch.float32
).cuda()
)
torch.distributed.broadcast(
tensor=logp,
src=src_rank,
group=mpu.get_pipe_parallel_group(),
)
logp = logp.squeeze()
logp_list = [
torch.zeros_like(logp)
for _ in range(mpu.get_data_parallel_world_size())
]
torch.distributed.all_gather(
logp_list, logp, group=mpu.get_data_parallel_group()
)
logp = torch.cat(logp_list, dim=0).cpu().numpy()
if (mpu.get_model_parallel_rank() == 0) and (
mpu.get_data_parallel_rank() == 0
):
for j in range(logp.shape[0]):
out_dataset.add_item(logp[j])
out_dataset.end_document()
print_rank_0(f"Processed {i} / {len(dataset)} in {time.time() - start}")
if is_mp_rank_0():
out_dataset.finalize(
out_path + ".idx",
)
torch.distributed.barrier()
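Because the output is written with the same indexed-dataset machinery used for the inputs, it can be read back the same way. A rough sketch, not from the commit; the path and the "mmap" data_impl are assumptions:

from megatron.data.indexed_dataset import make_dataset

# Each document in the output holds the float32 per-token log-probs for the
# corresponding document in the source dataset, gathered across data-parallel ranks.
precomputed = make_dataset("data/train_document_sft-run-1", "mmap", True)  # path, data_impl, skip_warmup
logp_doc0 = precomputed.get(0)
print(logp_doc0.dtype, logp_doc0.shape)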
18 changes: 9 additions & 9 deletions megatron/training.py
@@ -654,7 +654,12 @@ def get_model(neox_args, use_cache=False):

def get_optimizer(model, neox_args, dummy=False):
"""Set up the optimizer."""
if neox_args.no_load_optim:
if neox_args.no_load_optim and neox_args.deepspeed:
# Required to have something so...
dummy = True
neox_args.optimizer = {"params": {"lr": 0.0}}
neox_args.optimizer_type = "adam"
elif neox_args.no_load_optim:
return None, None

if neox_args.optimizer is None:
@@ -808,7 +813,7 @@

def get_learning_rate_scheduler(optimizer, neox_args):
"""Build the learning rate scheduler."""
if neox_args.no_load_optim:
if (neox_args.no_load_optim) and not neox_args.deepspeed:
# TODO: this should be configured as a separate arg
return None
if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
@@ -873,13 +878,8 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None
if neox_args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
if neox_args.no_load_optim:
assert optimizer is None
_model_params = None
_lr_scheduler = None
else:
_model_params = param_groups if optimizer is None else None
_lr_scheduler = lr_scheduler
_model_params = param_groups if optimizer is None else None
_lr_scheduler = lr_scheduler

model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
