Fix FP8 torchao default config with padding and FSDP2 all-gather support #3831
Conversation
SunMarc left a comment:
Thanks a lot for the changes, really appreciate it! Left a few minor comments to make it better!
The configuration for the FP8 training. If `None`, a default config will be created with sensible
defaults for most use cases:
- `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm`
  operations to prevent runtime errors.
- `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth
  savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.
You can override these defaults by providing your own `Float8LinearConfig` instance.
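As an aside, here is a minimal sketch of what overriding those defaults could look like (hedged: it uses torchao's `Float8LinearConfig` and the two fields quoted above, not the exact code from this PR):

from torchao.float8 import Float8LinearConfig

# The 50% figure above follows from FP8 storing 1 byte/param vs. BF16's 2 bytes.
config = Float8LinearConfig(
    pad_inner_dim=True,                  # pad inner dims to multiples of 16 for torch._scaled_mm
    enable_fsdp_float8_all_gather=True,  # cast params to FP8 before the FSDP2 all-gather
)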
Nice, maybe we can also allow users to easily change that with an env var, and update the `cluster.py` file, which is responsible for the behavior of `accelerate config`? Here's a PR that should help with the changes: #2983
import os

env_prefix = "ACCELERATE_FP8_"
enable_fsdp_float8_all_gather = os.environ.get(env_prefix + "ENABLE_FSDP_FLOAT_ALL_GATHER", True)
pad_inner_dim = os.environ.get(env_prefix + "PAD_INNER_DIM", True)
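One caveat worth noting (an editorial aside, not from the thread): `os.environ.get` returns strings, so the `True` defaults above need explicit boolean parsing. A minimal sketch with a hypothetical `_env_flag` helper:

import os

def _env_flag(name: str, default: bool) -> bool:
    # Environment variables are strings; map common truthy spellings to bool.
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes", "on")

env_prefix = "ACCELERATE_FP8_"
enable_fsdp_float8_all_gather = _env_flag(env_prefix + "ENABLE_FSDP_FLOAT_ALL_GATHER", True)
pad_inner_dim = _env_flag(env_prefix + "PAD_INNER_DIM", True)

accelerate also ships a `parse_flag_from_env` utility for this pattern, which the final implementation may prefer.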
Makes sense, thanks for the reference. Will add that
Updated @SunMarc
Hi @SunMarc, let me know if there's anything else needed.
Nope, everything looks good!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from d8f018f to 937b6ea.
SunMarc left a comment:
That's really nice, thanks for fixing this!
@SunMarc Sorry, had to fix a style issue. Can you re-approve?
What does this PR do?
Makes the default torchao FP8 config use `enable_fsdp_float8_all_gather=True` and `pad_inner_dim=True` when the user does not provide one.
Fixes #3830
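To illustrate the intended behavior, here is a hedged sketch (not code from this PR; it assumes accelerate's `AORecipeKwargs` handler accepts a torchao `Float8LinearConfig` through its `config` field):

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig

# With no explicit config, the padded, FSDP2-friendly defaults described above apply.
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs()])

# Users can still override the defaults with their own config.
custom = Float8LinearConfig(pad_inner_dim=False)
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[AORecipeKwargs(config=custom)])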
Before submitting
- This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- Did you write any new necessary tests?
Ran `pytest tests/test_fp8.py -v` successfully.

Who can review?