add muon code #168

Open · wants to merge 8 commits into main
Conversation

samsja (Collaborator) commented on Dec 4, 2024

No description provided.

Jackmin801 (Member) left a comment:

lgtm! just some small nits

Two resolved (outdated) review comments on src/zeroband/train.py.
Comment on lines +175 to +176:

    if isinstance(g, DTensor):
        g, meta = to_local(g, keep_sharded=False)
A member commented:
This will result in every rank doing the same orthogonalization computation?
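For context, here is a minimal sketch of what a helper like `to_local(g, keep_sharded=False)` plausibly does; this is an assumption about its behavior, not code from the PR. Under that reading, `keep_sharded=False` gathers the sharded gradient to a replicated placement before dropping to a plain tensor, so every rank ends up holding the full matrix and then repeats the same orthogonalization on it:

```python
# Hypothetical sketch of a to_local helper; the PR uses its own version and its
# behavior is assumed here, not confirmed by this thread.
from torch.distributed.tensor import DTensor, Replicate  # torch.distributed._tensor on older PyTorch

def to_local(g: DTensor, keep_sharded: bool = False):
    # Keep enough metadata to rebuild the DTensor after the local computation.
    meta = (g.device_mesh, g.placements)
    if keep_sharded:
        # Only this rank's shard: cheap, but not the full matrix.
        return g.to_local(), meta
    # keep_sharded=False: gather the full tensor onto every rank (replicated placement).
    # Each rank would then perform the identical orthogonalization on this full matrix.
    g_full = g.redistribute(g.device_mesh, [Replicate()] * g.device_mesh.ndim)
    return g_full.to_local(), meta
```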

A second comment quoted the same block with its added in-code note:

    meta = None
    if isinstance(g, DTensor):
        g, meta = to_local(g, keep_sharded=False)
        # gives NaNs when done with DTensor, instead of throwing a typical op-not-supported error, quite sneaky
that's not good 🥲
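For reference, the orthogonalization in question is Muon's Newton-Schulz iteration. Below is a sketch of the commonly published quintic variant, with coefficients from Keller Jordan's reference Muon implementation; the exact variant in this PR may differ. It operates on a plain torch.Tensor, since running it through DTensor is what reportedly produces the silent NaNs:

```python
import torch

def newton_schulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Approximately orthogonalize G with a quintic Newton-Schulz iteration,
    # run in bfloat16 on a plain local tensor (not a DTensor).
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.bfloat16()
    X = X / (X.norm() + eps)  # normalize so the iteration converges
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
```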

samsja (Collaborator, Author) commented on Dec 5, 2024:

The code is mostly copy-pasted from this repo btw: https://github.com/ethansmith2000/fsdp_optimizers

How do you think it should be handled?

Reply:

Personally, if the optimizer code is meant to run locally per rank without any collectives and you own the optimizer implementation, then converting DTensors to local torch.Tensors for all of the computation seems fine (and will have slightly lower eager-mode overhead due to avoiding DTensor.__torch_dispatch__).

The main value add of DTensor in that case is providing the sharding info/metadata on the tensor, which could be useful in the state dict for example. For that, you would still want the optimizer states to be saved as DTensors.
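A minimal sketch of that suggestion, assuming a plain momentum update stands in for the actual Muon step (the `step_param` helper below is illustrative, not the PR's code; the import path may be `torch.distributed._tensor` on older PyTorch): do the math on local tensors, then wrap the state back into a DTensor so the sharding metadata is preserved for the state dict.

```python
import torch
from torch.distributed.tensor import DTensor

def step_param(p: torch.Tensor, state: dict, lr: float = 0.02, momentum: float = 0.95) -> None:
    # All math happens on plain local tensors: no collectives, no DTensor dispatch overhead.
    g = p.grad
    if g is None:
        return
    g_local = g.to_local() if isinstance(g, DTensor) else g
    p_local = p.to_local() if isinstance(p, DTensor) else p

    buf = state.get("momentum_buffer")
    buf_local = buf.to_local() if isinstance(buf, DTensor) else buf
    if buf_local is None:
        buf_local = torch.zeros_like(g_local)
    buf_local.mul_(momentum).add_(g_local)

    # (Muon would orthogonalize buf_local here before applying it.)
    p_local.add_(buf_local, alpha=-lr)

    # Store the optimizer state as a DTensor so sharding info survives into the state dict.
    if isinstance(p, DTensor):
        state["momentum_buffer"] = DTensor.from_local(buf_local, p.device_mesh, p.placements)
    else:
        state["momentum_buffer"] = buf_local
```

The wrap/unwrap at the boundary is cheap relative to the optimizer math, and it keeps the checkpointed state shard-aware.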

samsja (Collaborator, Author):

I see, makes sense. Thanks!
