
Conversation

@shimizust (Contributor) commented Nov 3, 2025

What does this PR do?

  • If you use accelerate with transformers-based Trainers and launch the training job using the config-file based approach, fp8 via torchao doesn't work properly
  • This PR sets reasonable defaults when torchao is specified via accelerate configs, specifically enable_fsdp_float8_all_gather=True and pad_inner_dim=True
  • Added the ability to set these params via the CLI or accelerate configs, for example (a Python sketch of the equivalent defaults follows the examples below):
mixed_precision: fp8
fp8_config:
  backend: AO

or

fp8_config:
  backend: AO
  pad_inner_dim: true
  enable_fsdp_float8_all_gather: true
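
For reference, here is a minimal Python sketch of what these defaults correspond to when configuring Accelerate directly from code instead of a config file. The AORecipeKwargs / Float8LinearConfig wiring below illustrates the same settings and is not taken verbatim from this PR:

from accelerate import Accelerator
from accelerate.utils import AORecipeKwargs
from torchao.float8 import Float8LinearConfig

# Same defaults this PR applies when "backend: AO" is selected in the config file
ao_config = Float8LinearConfig(
    pad_inner_dim=True,                  # pad matmul dims to multiples of 16 for torch._scaled_mm
    enable_fsdp_float8_all_gather=True,  # cast params to FP8 before the FSDP2 all-gather
)
accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[AORecipeKwargs(config=ao_config)],
)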

Fixes #3830

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).

  • Did you read the contributor guideline,
    Pull Request section?

  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.

  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.

  • Did you write any new necessary tests?

  • Ran pytest tests/test_fp8.py -v successfully

Who can review?

@shimizust marked this pull request as ready for review November 3, 2025 18:30
@SunMarc (Member) left a comment

Thanks a lot for the changes, really appreciate it ! Left a few minor comments to make it better !

Comment on lines 317 to 324
The configuration for the FP8 training. If `None`, a default config will be created with sensible
defaults for most use cases:
- `pad_inner_dim=True`: Pads matrix dimensions to be divisible by 16, required for `torch._scaled_mm`
operations to prevent runtime errors.
- `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth
savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.
You can override these defaults by providing your own `Float8LinearConfig` instance.
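
As a quick illustration of the override mentioned above, a hedged sketch of building a custom Float8LinearConfig (how it is passed in depends on your setup, e.g. via AORecipeKwargs as in the earlier sketch):

from torchao.float8 import Float8LinearConfig

# Custom config: keep the padding but opt out of the FP8 all-gather behavior
custom_config = Float8LinearConfig(
    pad_inner_dim=True,
    enable_fsdp_float8_all_gather=False,
)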
@SunMarc (Member)

Nice, maybe we can also allow users to easily change that with an env var, and update the cluster.py file, which is responsible for the behavior of accelerate config? Here's a PR that should help with the changes: #2983

import os

env_prefix = "ACCELERATE_FP8_"
# Environment values come in as strings, so parse them into booleans explicitly
enable_fsdp_float8_all_gather = os.environ.get(env_prefix + "ENABLE_FSDP_FLOAT8_ALL_GATHER", "true").lower() == "true"
pad_inner_dim = os.environ.get(env_prefix + "PAD_INNER_DIM", "true").lower() == "true"

@shimizust (Contributor Author)

Makes sense, thanks for the reference. Will add that

@shimizust (Contributor Author)

Updated @SunMarc

@shimizust (Contributor Author) commented Dec 1, 2025

Hi @SunMarc let me know if there's anything else needed

@SunMarc (Member)

nope everything looks good !

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@SunMarc (Member) left a comment

That's really nice, thanks for fixing this !

@shimizust (Contributor Author)

@SunMarc Sorry, had to fix a style issue. Can you re-approve?

@SunMarc merged commit 75983a5 into huggingface:main Dec 3, 2025
24 of 25 checks passed

Development

Successfully merging this pull request may close these issues.

Torchao fp8 fails if using accelerate config file with Trainer

3 participants