add muon code #168
base: main
Conversation
lgtm! just some small nits
if isinstance(g, DTensor):
    g, meta = to_local(g, keep_sharded=False)
This will result in every rank doing the same orthogonalization computation?
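For illustration, here is a minimal sketch of what a `to_local` helper like the one in this diff might do, assuming `keep_sharded=False` means gathering the full tensor onto every rank (the names and behavior are an assumption, not necessarily the PR's actual helper). Under that assumption, each rank ends up with an identical full matrix, so the orthogonalization that follows would indeed be repeated on every rank:

```python
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

def to_local(g: DTensor, keep_sharded: bool = False):
    """Hypothetical sketch: convert a DTensor gradient to a plain torch.Tensor,
    returning the metadata needed to rebuild a DTensor later."""
    meta = (g.device_mesh, g.placements)
    if keep_sharded:
        # keep only this rank's local shard
        return g.to_local(), meta
    # gather the full tensor onto every rank: each rank then holds an identical
    # copy, so any follow-up computation (e.g. Newton-Schulz) is duplicated per rank
    return g.full_tensor(), meta
```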
meta = None
if isinstance(g, DTensor):
    g, meta = to_local(g, keep_sharded=False)
    # gives NaNs when done with DTensor, instead of throwing a typical op-not-supported error; quite sneaky
that's not good 🥲
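For context, the orthogonalization that hits this is Muon's Newton-Schulz iteration. A sketch along the lines of the reference Muon implementation, operating on a plain local `torch.Tensor` (shown for illustration; the PR's exact code may differ):

```python
import torch

def zeropower_via_newtonschulz5(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    # Quintic Newton-Schulz iteration used by Muon to approximately
    # orthogonalize a 2D momentum/gradient matrix.
    assert G.ndim == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= (X.norm() + eps)  # scale so the spectral norm is at most ~1
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    if transposed:
        X = X.T
    return X
```

Run on a local tensor this is just a handful of matmuls; the report above is that routing the same matmuls through `DTensor` silently produced NaNs rather than an unsupported-op error.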
The code is mostly copy-pasted from this repo btw: https://github.com/ethansmith2000/fsdp_optimizers
How do you think it should be handled?
Personally, if the optimizer code is meant to run locally per rank without any collectives and you own the optimizer implementation, then converting `DTensor`s to local `torch.Tensor`s for all of the computation seems fine (and will have slightly lower eager-mode overhead due to avoiding `DTensor.__torch_dispatch__`).

The main value add of `DTensor` in that case is providing the sharding info/metadata on the tensor, which could be useful in the state dict, for example. For that, you would still want the optimizer states to be saved as `DTensor`s.
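A minimal sketch of that pattern, with a made-up `muon_like_update` step (the function name, state layout, and hyperparameters are illustrative, not the PR's API): do the math on plain local tensors, then re-wrap anything that needs sharding metadata as `DTensor` so the state dict still carries it.

```python
import torch
from torch.distributed.tensor import DTensor, distribute_tensor

@torch.no_grad()
def muon_like_update(p, grad, state, lr=0.02, momentum=0.95):
    # 1) Unwrap: remember mesh/placements, then work on plain torch.Tensors.
    meta = None
    if isinstance(grad, DTensor):
        meta = (grad.device_mesh, grad.placements)
        grad = grad.full_tensor()      # plain tensor; avoids DTensor dispatch per op

    buf = state.get("momentum_buffer")
    if buf is None:
        buf = torch.zeros_like(grad)
    elif isinstance(buf, DTensor):
        buf = buf.full_tensor()        # state was stored as DTensor; gather for local math
    buf = buf.mul(momentum).add(grad)

    # Local-only compute, no collectives (Newton-Schulz sketch from earlier in the thread).
    update = zeropower_via_newtonschulz5(buf).to(p.dtype)

    # 2) Re-wrap: shard the update so it can be applied to the DTensor parameter,
    #    and keep the momentum state as a DTensor so sharding info lands in the state dict.
    if meta is not None:
        device_mesh, placements = meta
        update = distribute_tensor(update, device_mesh, placements)
        buf = distribute_tensor(buf, device_mesh, placements)

    state["momentum_buffer"] = buf
    p.add_(update, alpha=-lr)
```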
I see, makes sense. Thanks!