[DRAFT][s4xbf16] JoinOp vectorization patch #22

Open · wants to merge 1 commit into base: llvm-head
Conversation


@ggengnv ggengnv commented Feb 26, 2025

Do not merge until rebased on upstream Triton up to triton-lang@c1ed673

This is the first of the two patches needed to improve int4xbf16 GEMM performance.

This is needed because JoinOp by default interleaves the two input matrices element by element. For bf16, this means Triton must extract the two bf16 values packed into each 32-bit register and re-insert them into new registers, which produces many mov instructions before the MMA. On certain shapes, this can cost roughly 10% of performance.

This PR addresses the above by situationally "vectorizing" the interleaving; namely, joining every two elements instead of every one. This avoids the need to extract values out of registers. Of course, this also requires modifying the inline_asm logic before the join so that it produces the matching layout.
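To make the layout difference concrete, here is a small NumPy sketch (purely illustrative, not the actual Triton patch) contrasting the default element-by-element interleave with the proposed pair-wise ("vectorized") interleave, where each pair stands in for two bf16 values sharing one 32-bit register:

```python
import numpy as np

# Stand-ins for the two join operands; each consecutive pair of elements
# represents two bf16 values packed into a single 32-bit register.
a = np.arange(8)        # a0..a7
b = np.arange(8) + 100  # b0..b7

# Default JoinOp layout: interleave every element -> a0, b0, a1, b1, ...
# Each bf16 value must be extracted from its packed register and re-inserted.
default_join = np.empty(16, dtype=a.dtype)
default_join[0::2] = a
default_join[1::2] = b

# Vectorized join: interleave every two elements -> a0, a1, b0, b1, ...
# bf16 pairs stay together, so no extraction from packed registers is needed.
vectorized_join = np.stack(
    [a.reshape(-1, 2), b.reshape(-1, 2)], axis=1
).reshape(-1)

print(default_join.tolist())     # [0, 100, 1, 101, 2, 102, ...]
print(vectorized_join.tolist())  # [0, 1, 100, 101, 2, 3, ...]
```

In the vectorized layout, every 32-bit register's contents (one pair) move as a unit, which is what eliminates the mov instructions described above.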

cc @gflegar

@ggengnv ggengnv changed the title [DRAFT] JoinOp vectorization patch [DRAFT][s4xbf16] JoinOp vectorization patch Feb 26, 2025

ggengnv commented Feb 26, 2025

For small-M shapes, best performance additionally requires XLA to swap A/B so that the LHS of the dot is the quantized operand, and setting the envvar DISABLE_MMA_V3 to force Ampere-style MMA.
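A minimal sketch of the envvar part, assuming DISABLE_MMA_V3 is read as described above (the benchmark command is a placeholder):

```shell
# Force Ampere-style MMA (instead of MMAv3) for small-M int4xbf16 GEMMs.
export DISABLE_MMA_V3=1
# ./run_gemm_benchmark  # placeholder for whatever workload you are measuring
echo "DISABLE_MMA_V3=$DISABLE_MMA_V3"
```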
