[DRAFT][s4xbf16] JoinOp vectorization patch #22
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Do not merge until rebased on upstream Triton up to triton-lang@c1ed673
This is 1 of the 2 patches needed to improve int4xbf16 GEMM perf.
This is needed because joinOp by default interleaves every element of the two input matrices. In the case of bf16, this means Triton will extract the 2x bf16 values out of the 32-bit register and re-insert them into a new register. This results in many
mov
instructions before MMA. On certain shapes, this could mean a ~10% perf penalty.This PR addresses the above by situationally "vectorizing" the interleaving; namely, join every two elements instead of one. This avoids the need to extract values out of registers. Of course, this would also require one to modify the inline_asm logic before the join to produce the correct layout.
cc @gflegar