Summary
The fused_sat_cast kernel can be found here: https://github.com/drisspg/driss_torch/blob/6d1be6ec21c5a56cf8ddfeb12a57cfce316e40bb/src/saturated_cast.cu#L243
The other two kernels can be found here: https://github.com/pytorch-labs/float8_experimental/pull/227/files#diff-a1a29e99a81b48419f66c77a301ca1f09c51bf754baf82e0962c9bc243d89310
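For reference, here is a minimal eager-mode sketch of what a saturated cast does (illustrative only; the actual CUDA kernel linked above fuses the scale multiply, clamp, and dtype conversion into a single pass, and the function name below is hypothetical):

```python
import torch

def saturated_cast_sketch(
    x: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # Scale the input, then clamp to the representable range of the fp8 dtype
    # before casting, so out-of-range values saturate instead of overflowing.
    finfo = torch.finfo(float8_dtype)
    x_scaled = x.float() * scale
    x_clamped = x_scaled.clamp(min=finfo.min, max=finfo.max)
    return x_clamped.to(float8_dtype)
```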
Eager numbers are from the script below, which is a stripped-down version of the full benchmark script used to compare fused vs. unfused casting:
https://github.com/pytorch-labs/float8_experimental/pull/227/files#diff-729d5216ec3b30dea879056f9eb4a9bac9127501b8c8e6d516b640abb2f106ae
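The timing methodology in the stripped-down script amounts to a warmup-then-average loop over the two code paths; a rough sketch of that kind of loop is below (the timer helper and the fp8 matmul names are placeholders, not the actual script's API):

```python
import torch

def benchmark_ms(fn, *args, warmup: int = 10, iters: int = 100) -> float:
    # Simple CUDA-event-based timer: warm up, then average over `iters` runs.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Usage sketch: compare the bf16 reference matmul against an fp8 path.
# ref_ms = benchmark_ms(torch.mm, a_bf16, b_bf16)
# fp8_ms = benchmark_ms(fp8_matmul, a_fp8, b_fp8)   # hypothetical fp8 path
# print(f"speedup: {ref_ms / fp8_ms:.3f}")
```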
Table
Key/Structure:
| shape | ref_dtype | fuse_cast | ref_time_sec | pt_fp8_time_sec | pt_fp8_speedup |
|---|---|---|---|---|---|
| (16384, 8192, 1280) | torch.bfloat16 | True | 0.002140 | 0.002688 | 0.796081 |
| (16384, 8192, 1280) | torch.bfloat16 | False | 0.002142 | 0.004025 | 0.532102 |
| (16384, 1024, 8192) | torch.bfloat16 | True | 0.001883 | 0.002398 | 0.785198 |
| (16384, 1024, 8192) | torch.bfloat16 | False | 0.001885 | 0.003384 | 0.556938 |
| (16384, 8192, 7168) | torch.bfloat16 | True | 0.010392 | 0.007928 | 1.310714 |
| (16384, 8192, 7168) | torch.bfloat16 | False | 0.010418 | 0.011007 | 0.946480 |
| (16384, 3584, 8192) | torch.bfloat16 | True | 0.005375 | 0.004720 | 1.138892 |
| (16384, 3584, 8192) | torch.bfloat16 | False | 0.005423 | 0.006589 | 0.823073 |
Traces:
Repro script: https://gist.github.com/drisspg/693b53527859433fc9d8987a1b7e464b
Things still left to do for more perf
During the backward pass we also want the option to transpose so that we can prepare inputs in the TN format. I still need to add support for this in the kernel, but beyond that it is a less contained change, since we need to signal this to the `to_fp8_no_autograd` constructor instead of relying on the compiler to generate it. The same could be done for the scale-inverse calls as well, but again we would need to cache these on the fp8 tensor.
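A hedged sketch of what threading that option through the cast path might look like; `to_fp8_no_autograd` is the real constructor mentioned above, but the `transpose_for_tn` flag and the wrapper function here are hypothetical:

```python
import torch

def cast_grad_for_backward(
    grad_output: torch.Tensor,
    scale: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e5m2,
    transpose_for_tn: bool = False,
) -> torch.Tensor:
    # Hypothetical wrapper: if the consumer needs the TN layout, produce the
    # transposed fp8 tensor directly (ideally inside the fused kernel), rather
    # than relying on the compiler to fuse a separate transpose afterwards.
    x = grad_output.t().contiguous() if transpose_for_tn else grad_output
    finfo = torch.finfo(float8_dtype)
    x_fp8 = (x.float() * scale).clamp(finfo.min, finfo.max).to(float8_dtype)
    # In the real code this result would be handed to to_fp8_no_autograd(...),
    # with the scale inverse cached on the resulting fp8 tensor.
    return x_fp8
```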