-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
float8 with delayed scaling: fix autocast handling (#1306)
Summary: Fixes a bug with delayed scaling + autocast. Before, the last input dtype when in autocast was queried from the input to `torch._scaled_mm`: ``` x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm ``` This is incorrect because the dtype was saved from before the place where autocast could change it. This happened to work if `x_hp` was already of the correct dtype, but did not work in cases such as the new test case added in this PR, or real models such as the repro from #1297. The reason we haven't caught this for so long is we've been using FSDP's mixed precision and not single-GPU autocast. The fix I'm taking here is to query the original post-autocast dtype based on the output of `torch._scaled_mm`. Since this dtype is based on the dtype of the input to `torch._scaled_mm`, this will properly capture autocasting: ``` x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here} ``` Test Plan: ``` // first, test the updated test case - it passes // second - test a modified version of the repro in // #1297: // code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7 // logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10 // we now see a speedup with float8 ``` Reviewers: Subscribers: Tasks: Tags:
- Loading branch information
Showing
3 changed files
with
15 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters