state_dict_factory: llama checkpoint - support SWIGLU #5601

Open: nelyahu wants to merge 3 commits into master
Conversation

@nelyahu (Contributor) commented on Jun 2, 2024

DeepSpeed supports loading a checkpoint for inference with a different DP/TP/PP configuration. This requires splitting or merging parameters based on their TP attributes. Currently, this is done using model-specific parameter names, which is not good practice and should be changed.

This commit makes the changes required to support the MDS LLaMA model. There are two changes:

  • Support for lm_head.weight
  • Support for mlp.h_to_4h.weight for SWIGLU

SWIGLU requires different handling; however, there is no metadata that identifies mlp.h_to_4h.weight as a SWIGLU weight, so for now we detect it with a hack.
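For context, here is a minimal sketch of why a fused SWIGLU weight cannot be split or merged like an ordinary column-parallel weight. It assumes the Megatron-style layout where each TP shard stacks its gate and up projections along dim 0; the function names are illustrative, not the actual state_dict_factory API:

    import torch

    def merge_swiglu_h_to_4h(shards):
        # Each TP shard is assumed to be cat([gate_i, up_i], dim=0), so a
        # plain torch.cat(shards, dim=0) would interleave gate and up
        # blocks. Split each shard in half first, then regroup.
        gates, ups = zip(*(s.chunk(2, dim=0) for s in shards))
        return torch.cat(list(gates) + list(ups), dim=0)

    def split_swiglu_h_to_4h(weight, tp_size):
        # Inverse operation: re-shard a merged weight for a new TP degree,
        # keeping each gate_i paired with its up_i inside the shard.
        gate, up = weight.chunk(2, dim=0)
        return [torch.cat([g, u], dim=0)
                for g, u in zip(gate.chunk(tp_size, dim=0),
                                up.chunk(tp_size, dim=0))]

An ordinary column-parallel weight round-trips through a plain chunk/cat along dim 0, which is why name-based handling suffices everywhere except this layer.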

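The PR does not show the detection hack itself. Absent any metadata, one conceivable heuristic (hypothetical, not taken from this PR) keys off the weight's shape, since LLaMA-style SwiGLU MLPs do not use the classic 4x expansion:

    def looks_like_swiglu_h_to_4h(name, shape, hidden_size):
        # Hypothetical heuristic, not the PR's actual hack: a classic GELU
        # MLP stores h_to_4h as (4*hidden, hidden), while a fused SwiGLU
        # weight stacks gate+up into (2*ffn_dim, hidden), and in
        # LLaMA-style models 2*ffn_dim != 4*hidden.
        return "h_to_4h" in name and shape[0] != 4 * hidden_size

Any such shape check is fragile, which is presumably why the author calls it a hack and why the reviewers point to metadata-driven alternatives below.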
@tjruwase tjruwase requested review from samadejacobs and tohtana and removed request for mrwyattii June 21, 2024 22:27
Review comment on:

        new_client_sd[key] = torch.cat(value_list, axis=0)
    elif "mlp.dense_h_to_4h.weight" in key:
@tjruwase (Contributor) commented:
@nelyahu, this logic is very old and hacky. Could it not be replaced with Universal Checkpointing?

@tohtana (Contributor) left a comment:

As Tunji commented, we now have a more flexible solution introduced in #5390. Can you try it? The examples in #5390 might help.

@tohtana tohtana self-assigned this Sep 4, 2024
@tohtana (Contributor) commented on Sep 6, 2024

@nelyahu Do we have any update?
