From 4c426da8b6149e2313bc6e00584531f004cfe457 Mon Sep 17 00:00:00 2001
From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
Date: Thu, 6 Jun 2024 21:37:48 -0400
Subject: [PATCH] Conversion script bugfixes (#1218)

* update is_pipe_parallel logic ; handle tied-embeddings case correctly

* Update NeoXArgs docs automatically

* revert PP to be consistent

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions
Co-authored-by: Quentin Anthony
---
 configs/neox_arguments.md         |  2 +-
 tools/ckpts/convert_neox_to_hf.py | 65 ++++++++++++++++++++++---------
 2 files changed, 48 insertions(+), 19 deletions(-)

diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index 7a56e361e..c884afd97 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 8451671
+    Default = 714b299
 
     current git hash of repository
 
diff --git a/tools/ckpts/convert_neox_to_hf.py b/tools/ckpts/convert_neox_to_hf.py
index 35812383e..f4e0ccf9f 100644
--- a/tools/ckpts/convert_neox_to_hf.py
+++ b/tools/ckpts/convert_neox_to_hf.py
@@ -580,30 +580,59 @@ def convert(
 
     # Load output embedding
     if not sequential:
-        loaded_tp_ranks = load_partitions(
-            input_checkpoint_path,
-            mp_partitions,
-            get_key(loaded_config, "num-layers") + 4,
-            sequential=sequential,
-        )
+        if get_key(loaded_config, "no-weight-tying", False):
+            # if we have trained input + output embedding layers without tied weights
+            loaded_tp_ranks = load_partitions(
+                input_checkpoint_path,
+                mp_partitions,
+                get_key(loaded_config, "num-layers") + 4,
+                sequential=sequential,
+            )
+        else:
+            # in this case, output embedding layer and input embedding layer are tied.
+            # load + save the input embed weights into the output embedding layer's place.
+            loaded_tp_ranks = load_partitions(
+                input_checkpoint_path,
+                mp_partitions,
+                layer_idx=0,
+                sequential=sequential,
+            )
     # output embedding / LM head
     if architecture == "neox":  # name of lm head / final linear proj varies
         lm_head = hf_model.embed_out
     else:
         lm_head = hf_model.lm_head
-    lm_head.load_state_dict(
-        {
-            "weight": torch.cat(
-                get_state(
-                    loaded_tp_ranks,
-                    "final_linear.weight",
-                    layer_idx=get_key(loaded_config, "num-layers") + 4,
-                    sequential=sequential,
+
+    if get_key(loaded_config, "no-weight-tying", False):
+        # save the (untied) final linear into LM head for HF
+        lm_head.load_state_dict(
+            {
+                "weight": torch.cat(
+                    get_state(
+                        loaded_tp_ranks,
+                        "final_linear.weight",
+                        layer_idx=get_key(loaded_config, "num-layers") + 4,
+                        sequential=sequential,
+                    ),
+                    dim=0,
                 ),
-                dim=0,
-            ),
-        }
-    )
+            }
+        )
+    else:
+        # embedding layers are tied. transpose input layer and save
+        lm_head.load_state_dict(
+            {
+                "weight": torch.cat(
+                    get_state(
+                        loaded_tp_ranks,
+                        "word_embeddings.weight",
+                        layer_idx=0,
+                        sequential=sequential,
+                    ),
+                    dim=0,
+                ),
+            }
+        )
 
     del loaded_tp_ranks
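
Note on the change above: the patch branches on the no-weight-tying setting twice, once to pick which sequential layer file to read (layer 0 when embeddings are tied, num-layers + 4 for an untied final linear) and once to pick which tensor name is loaded into the Hugging Face LM head. The standalone Python sketch below restates just that decision for readers outside the diff context; resolve_lm_head_source is a hypothetical helper, and the plain dict stands in for the real NeoX config and the get_key / load_partitions / get_state machinery in tools/ckpts/convert_neox_to_hf.py.

    from typing import Tuple


    def resolve_lm_head_source(loaded_config: dict, num_layers: int) -> Tuple[str, int]:
        """Pick (state-dict key, sequential layer index) for the HF LM head weights."""
        if loaded_config.get("no-weight-tying", False):
            # Untied: the output projection was trained separately and lives in
            # its own layer file, num_layers + 4 in the NeoX sequential layout.
            return "final_linear.weight", num_layers + 4
        # Tied: reuse the input embedding (layer 0) as the LM head weights.
        return "word_embeddings.weight", 0


    if __name__ == "__main__":
        # Untied 24-layer model -> ('final_linear.weight', 28)
        print(resolve_lm_head_source({"no-weight-tying": True}, num_layers=24))
        # Tied embeddings (the case this patch fixes) -> ('word_embeddings.weight', 0)
        print(resolve_lm_head_source({}, num_layers=24))

In the real script the chosen key and layer index are what get passed to load_partitions and get_state before lm_head.load_state_dict; the sketch only captures which tensor ends up in the LM head under each setting.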