diff --git a/generate.py b/generate.py
index 743e350d0..e19ef2e0e 100755
--- a/generate.py
+++ b/generate.py
@@ -23,6 +23,7 @@
     generate_samples_from_prompt,
     generate_samples_unconditional,
     generate_samples_interactive,
+    precompute_logits,
 )
 
 
@@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None):
             top_p=neox_args.top_p,
         )
 
+    elif neox_args.text_gen_type == "precompute":
+        precompute_logits(neox_args=neox_args, model=model)
     else:
         raise ValueError(
             f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}"
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 7b1a60d46..3ce8b881a 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -1281,7 +1281,12 @@ class NeoXArgsTextgen(NeoXArgsTemplate):
     text_gen_type: str = None
     """
     How to generate text/sample the model.
-    Options: `unconditional`, `input-file`, `interactive`
+    Options: `unconditional`, `input-file`, `interactive`, `precompute`
+    """
+
+    precompute_model_name: str = None
+    """
+    Model name to use for saving precomputed logprobs
     """
 
     temperature: float = 0.0
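For context, the new mode rides on the existing `generate.py` entry point, whose signature `main(input_args=None, overwrite_values=None)` appears in the hunk above. A hedged sketch of driving it programmatically follows; in practice gpt-neox scripts are launched through the `deepy.py` launcher, and the model name shown is an illustrative assumption:

```python
# Hypothetical sketch, not part of the patch: select the new precompute mode
# by overriding the two arguments this diff introduces/extends.
from generate import main

main(
    overwrite_values={
        "text_gen_type": "precompute",           # new option added above
        "precompute_model_name": "pythia-1.4b",  # illustrative output suffix
    }
)
# Note: a real run still needs the usual distributed launch (e.g. via deepy.py)
# and a YAML config supplying model, checkpoint, and data paths.
```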
diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py
index 7b7a390ab..02926c2c3 100644
--- a/megatron/text_generation_utils.py
+++ b/megatron/text_generation_utils.py
@@ -23,12 +23,15 @@ import time
 from typing import List, Union
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 
 from megatron import print_rank_0
 from megatron import mpu
 from megatron.utils import get_ltor_masks_and_position_ids, is_mp_rank_0
+from megatron.data.indexed_dataset import make_builder, make_dataset
+from megatron.mpu.mappings import gather_from_model_parallel_region
 
 
 def get_batch(neox_args, context_tokens: torch.Tensor):
@@ -52,7 +55,9 @@ def get_batch(neox_args, context_tokens: torch.Tensor):
     return tokens, attention_mask, position_ids
 
 
-def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int):
+def pad_batch(
+    context_tokens: List[List[int]], pad_id: int, pad_len: int, truncate: bool = False
+):
     """
     pads context lengths in context_tokens with pad_id to equal neox_args.seq_length,
     and returns the padded batch and the new lengths.
@@ -60,17 +65,21 @@ def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int):
     context_tokens: list of lists of tokens
     pad_id: int, integer to use as padding token
     pad_len: int, context length to be padded; all batch items will be padded to the same length
+    truncate: bool, if True, truncate context tokens longer than pad_len down to pad_len
 
     returns: tuple of padded context tokens and a list of unpadded token count
     """
 
     context_lengths = []
-    for tokens in context_tokens:
+    for i, tokens in enumerate(context_tokens):
         context_length = len(tokens)
         if context_length < pad_len:
             tokens.extend([pad_id] * (pad_len - context_length))
         elif context_length > pad_len:
-            raise ValueError("context_length is bigger than to be padded length")
+            if not truncate:
+                raise ValueError("context_length is longer than pad_len")
+            context_tokens[i] = tokens[:pad_len]
+            context_length = pad_len
         context_lengths.append(context_length)
     return context_tokens, context_lengths
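To make the new `truncate` behavior concrete, a small standalone sketch with toy values:

```python
from megatron.text_generation_utils import pad_batch

batch = [[5, 6, 7, 8, 9], [1, 2]]
padded, lengths = pad_batch(batch, pad_id=0, pad_len=4, truncate=True)
# padded  -> [[5, 6, 7, 8], [1, 2, 0, 0]]
# lengths -> [4, 2]: truncated items report pad_len, shorter items their true length
```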
@@ -807,3 +816,180 @@ def generate_samples_interactive(
             print_rank_0("Generated Text: " + generated_text)
         if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
             _ = input("\n")
+
+
+def get_logp(logits, labels, force_fp32=False):
+    # Per-token log-probability of each label under the model:
+    # log-softmax over the vocab dimension, gathered at the label ids.
+    if force_fp32:
+        logits = logits.float()
+    logp = logits.log_softmax(dim=-1)
+    return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2)
+
+
+def precompute_logits(neox_args, model):
+    """
+    Precomputes logprobs for the training/testing/validation datasets
+
+    Saves them to the same directory as each dataset, with the model name appended
+
+    neox_args: NeoXArgs.
+    model: a Megatron model
+
+    """
+    if neox_args.precompute_model_name is None:
+        mdl_name = str(hash(neox_args.load))
+    else:
+        mdl_name = neox_args.precompute_model_name
+    print_rank_0("Precomputing logprobs...")
+    model.eval()
+    data_paths = list()
+    if neox_args.train_data_paths is not None:
+        for path in neox_args.train_data_paths:
+            data_paths.append(path)
+        for path in neox_args.test_data_paths:
+            data_paths.append(path)
+        for path in neox_args.valid_data_paths:
+            data_paths.append(path)
+    elif neox_args.pos_train_data_paths is not None:
+        # Pairwise preference data
+        for path in neox_args.pos_train_data_paths:
+            data_paths.append(path)
+        for path in neox_args.neg_train_data_paths:
+            data_paths.append(path)
+        for path in neox_args.pos_valid_data_paths:
+            data_paths.append(path)
+        for path in neox_args.neg_valid_data_paths:
+            data_paths.append(path)
+        for path in neox_args.pos_test_data_paths:
+            data_paths.append(path)
+        for path in neox_args.neg_test_data_paths:
+            data_paths.append(path)
+    for path in data_paths:
+        print_rank_0(f"Precomputing logits for {path}")
+        # Append the model name (or a hash of the checkpoint path) to the output path
+        out_path = path + f"_{mdl_name}"
+        if os.path.exists(out_path + ".idx"):
+            continue
+        dataset = make_dataset(path, neox_args.data_impl, not neox_args.mmap_warmup)
+        if is_mp_rank_0():
+            out_dataset = make_builder(out_path + ".bin", neox_args.data_impl)
+            out_dataset._dtype = np.float32
+        i = 0
+        while i < len(dataset):
+            start = time.time()
+            model.module.clear_cache()  # clear kv cache between batches
+            if is_mp_rank_0():
+                offset = (
+                    mpu.get_data_parallel_rank()
+                    * neox_args.train_micro_batch_size_per_gpu
+                )
+                # grab this data parallel rank's microbatch
+                context_tokens = [
+                    [int(x) for x in dataset.get(j % len(dataset)).tolist()]
+                    for j in range(
+                        i + offset,
+                        i + (neox_args.train_micro_batch_size_per_gpu + offset),
+                    )
+                ]
+                # pad batch in order to allow conversion to tensor
+                context_tokens, context_lengths = pad_batch(
+                    copy.deepcopy(context_tokens),
+                    pad_id=0,
+                    pad_len=neox_args.seq_length + 1,
+                    truncate=True,
+                )
+                label_tokens = [tokens[1:] for tokens in context_tokens]
+                context_tokens = [tokens[:-1] for tokens in context_tokens]
+            else:
+                context_tokens = [
+                    [0 for _ in range(neox_args.seq_length)]
+                    for _ in range(neox_args.train_micro_batch_size_per_gpu)
+                ]
+                label_tokens = [
+                    [0 for _ in range(neox_args.seq_length)]
+                    for _ in range(neox_args.train_micro_batch_size_per_gpu)
+                ]
+                context_lengths = [0] * neox_args.train_micro_batch_size_per_gpu
+            i += (
+                neox_args.train_micro_batch_size_per_gpu
+                * mpu.get_data_parallel_world_size()
+            )
+            # convert to tensor and broadcast
+            context_tokens = torch.cuda.LongTensor(context_tokens)
+            label_tokens = torch.cuda.LongTensor(label_tokens)
+            # Make sure context tokens + start tokens are the same across all ranks
+            token_generation_start_index = torch.cuda.LongTensor(context_lengths)
+            torch.distributed.broadcast(
+                context_tokens,
+                mpu.get_model_parallel_src_rank(),
+                group=mpu.get_model_parallel_group(),
+            )
+            torch.distributed.broadcast(
+                token_generation_start_index,
+                mpu.get_model_parallel_src_rank(),
+                group=mpu.get_model_parallel_group(),
+            )
+            torch.distributed.broadcast(
+                label_tokens,
+                mpu.get_model_parallel_src_rank(),
+                group=mpu.get_model_parallel_group(),
+            )
+            with torch.no_grad():
+                # get attention mask / position ids
+                context_tokens, attention_mask, position_ids = get_batch(
+                    neox_args, context_tokens
+                )
+                model_inputs = (
+                    context_tokens,
+                    position_ids,
+                    attention_mask,
+                )
+                maybe_tuple = forward_model(
+                    model, model_inputs, neox_args.is_pipe_parallel
+                )
+                if isinstance(maybe_tuple, tuple):
+                    logits, _ = maybe_tuple
+                else:
+                    logits = maybe_tuple
+                if logits is not None:  # if pipe parallel, not all ranks return logits
+                    logits = gather_from_model_parallel_region(logits)
+                    logp = get_logp(logits, label_tokens, True).squeeze()
+                if neox_args.is_pipe_parallel:
+                    # broadcast logprobs from the last pipeline stage to the other stages
+                    src_rank = model.grid.stage_to_global(model.num_stages - 1)
+                    logp = (
+                        logp
+                        if logits is not None
+                        else torch.zeros(
+                            neox_args.train_micro_batch_size_per_gpu,
+                            neox_args.seq_length,
+                            dtype=torch.float32,
+                        ).cuda()
+                    )
+                    torch.distributed.broadcast(
+                        tensor=logp,
+                        src=src_rank,
+                        group=mpu.get_pipe_parallel_group(),
+                    )
+                    logp = logp.squeeze()
+                logp_list = [
+                    torch.zeros_like(logp)
+                    for _ in range(mpu.get_data_parallel_world_size())
+                ]
+                torch.distributed.all_gather(
+                    logp_list, logp, group=mpu.get_data_parallel_group()
+                )
+                logp = torch.cat(logp_list, dim=0).cpu().numpy()
+            if (mpu.get_model_parallel_rank() == 0) and (
+                mpu.get_data_parallel_rank() == 0
+            ):
+                for j in range(logp.shape[0]):
+                    out_dataset.add_item(logp[j])
+                    out_dataset.end_document()
+                print_rank_0(f"Processed {i} / {len(dataset)} in {time.time() - start}")
+        if is_mp_rank_0():
+            out_dataset.finalize(
+                out_path + ".idx",
+            )
+        torch.distributed.barrier()
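To pin down what `get_logp` computes: given logits of shape `[batch, seq, vocab]` and labels of shape `[batch, seq]`, it returns the log-probability the model assigns to each label token, i.e. the negative of the unreduced per-token cross-entropy. A self-contained sketch with toy tensors:

```python
import torch
import torch.nn.functional as F

batch, seq, vocab = 2, 4, 8
logits = torch.randn(batch, seq, vocab)
labels = torch.randint(0, vocab, (batch, seq))

logp = logits.float().log_softmax(dim=-1)                           # [batch, seq, vocab]
token_logp = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(2)  # [batch, seq]

# Sanity check: identical to minus the unreduced cross-entropy loss.
nll = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction="none")
assert torch.allclose(token_logp.view(-1), -nll, atol=1e-5)
```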
diff --git a/megatron/training.py b/megatron/training.py
index b578c4ad9..ed01996e5 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -654,7 +654,12 @@ def get_model(neox_args, use_cache=False):
 
 def get_optimizer(model, neox_args, dummy=False):
     """Set up the optimizer."""
-    if neox_args.no_load_optim:
+    if neox_args.no_load_optim and neox_args.deepspeed:
+        # DeepSpeed needs some optimizer to initialize; fall back to a dummy Adam with lr 0.0
+        dummy = True
+        neox_args.optimizer = {"params": {"lr": 0.0}}
+        neox_args.optimizer_type = "adam"
+    elif neox_args.no_load_optim:
         return None, None
 
     if neox_args.optimizer is None:
@@ -808,7 +813,7 @@ def get_optimizer(model, neox_args, dummy=False):
 
 def get_learning_rate_scheduler(optimizer, neox_args):
     """Build the learning rate scheduler."""
-    if neox_args.no_load_optim:
+    if neox_args.no_load_optim and not neox_args.deepspeed:
         # TODO: this should be configured as a separate arg
         return None
     if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
@@ -873,13 +878,8 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
         ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None
     if neox_args.deepspeed:
         print_rank_0("DeepSpeed is enabled.")
-        if neox_args.no_load_optim:
-            assert optimizer is None
-            _model_params = None
-            _lr_scheduler = None
-        else:
-            _model_params = param_groups if optimizer is None else None
-            _lr_scheduler = lr_scheduler
+        _model_params = param_groups if optimizer is None else None
+        _lr_scheduler = lr_scheduler
 
         model, optimizer, _, lr_scheduler = deepspeed.initialize(
             model=model,
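End to end, each input dataset gains a sibling `<path>_<model_name>.bin`/`.idx` pair holding float32 per-token logprobs, one output document per input document. A hedged sketch of reading one back with the same indexed-dataset reader used above (the path, model name, and `data_impl` value are assumptions, not part of the patch):

```python
from megatron.data.indexed_dataset import make_dataset

# Assumes the run used data_impl="mmap" on data/train_text_document with
# precompute_model_name="pythia-1.4b" (illustrative names).
ds = make_dataset("data/train_text_document_pythia-1.4b", "mmap", skip_warmup=True)
logprobs = ds.get(0)  # float32 numpy array: per-token logprobs of document 0
print(logprobs.dtype, logprobs.shape)
```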