Conversion script bugfixes (#1218)
* update is_pipe_parallel logic; handle tied-embeddings case correctly

* Update NeoXArgs docs automatically

* revert PP to be consistent

* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
3 people committed Jun 7, 2024
1 parent 2382bd4 commit 4c426da
Showing 2 changed files with 48 additions and 19 deletions.
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 8451671
+    Default = 714b299
 
     current git hash of repository
 
65 changes: 47 additions & 18 deletions tools/ckpts/convert_neox_to_hf.py
@@ -580,30 +580,59 @@ def convert(
 
     # Load output embedding
     if not sequential:
-        loaded_tp_ranks = load_partitions(
-            input_checkpoint_path,
-            mp_partitions,
-            get_key(loaded_config, "num-layers") + 4,
-            sequential=sequential,
-        )
+        if get_key(loaded_config, "no-weight-tying", False):
+            # if we have trained input + output embedding layers without tied weights
+            loaded_tp_ranks = load_partitions(
+                input_checkpoint_path,
+                mp_partitions,
+                get_key(loaded_config, "num-layers") + 4,
+                sequential=sequential,
+            )
+        else:
+            # in this case, output embedding layer and input embedding layer are tied.
+            # load + save the input embed weights into the output embedding layer's place.
+            loaded_tp_ranks = load_partitions(
+                input_checkpoint_path,
+                mp_partitions,
+                layer_idx=0,
+                sequential=sequential,
+            )
     # output embedding / LM head
     if architecture == "neox":  # name of lm head / final linear proj varies
         lm_head = hf_model.embed_out
     else:
         lm_head = hf_model.lm_head
-    lm_head.load_state_dict(
-        {
-            "weight": torch.cat(
-                get_state(
-                    loaded_tp_ranks,
-                    "final_linear.weight",
-                    layer_idx=get_key(loaded_config, "num-layers") + 4,
-                    sequential=sequential,
-                ),
-                dim=0,
-            ),
-        }
-    )
+
+    if get_key(loaded_config, "no-weight-tying", False):
+        # save the (untied) final linear into LM head for HF
+        lm_head.load_state_dict(
+            {
+                "weight": torch.cat(
+                    get_state(
+                        loaded_tp_ranks,
+                        "final_linear.weight",
+                        layer_idx=get_key(loaded_config, "num-layers") + 4,
+                        sequential=sequential,
+                    ),
+                    dim=0,
+                ),
+            }
+        )
+    else:
+        # embedding layers are tied. transpose input layer and save
+        lm_head.load_state_dict(
+            {
+                "weight": torch.cat(
+                    get_state(
+                        loaded_tp_ranks,
+                        "word_embeddings.weight",
+                        layer_idx=0,
+                        sequential=sequential,
+                    ),
+                    dim=0,
+                ),
+            }
+        )
 
     del loaded_tp_ranks
 
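The else branches above are the substance of the tied-embeddings fix: when no-weight-tying is unset, the checkpoint has no separate final_linear.weight, so the converter now loads the input word_embeddings.weight (layer 0) and writes it into the Hugging Face LM head (embed_out for the NeoX architecture, lm_head otherwise). A quick post-conversion sanity check is to compare the converted model's input and output embedding matrices. The snippet below is a hedged sketch, not part of the commit, and the model directory name is hypothetical:

# Sketch: verify that a converted tied-embedding checkpoint shares its
# input and output embedding weights. Not part of the commit; the
# directory "./converted-hf-model" is a hypothetical conversion output
# produced by tools/ckpts/convert_neox_to_hf.py.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("./converted-hf-model")

embed_in = model.get_input_embeddings().weight    # word embeddings
lm_head = model.get_output_embeddings().weight    # embed_out / lm_head

# Expected True for a checkpoint trained with tied embeddings; for an
# untied checkpoint the exported final_linear.weight will generally differ.
print(torch.equal(embed_in, lm_head))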
