[CK_TILE] B matrix 2D block scale gemm #3074
base: develop
Hi @samremes, could you please resolve the merge conflicts?
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,

std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,

// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>
It is awesome to have these unit tests 👍
std::string quant_mode = arg_parser.get_str("quant_mode");

using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
Could we expose the quant group size through an interface? Currently, the quant group dimensions have to be set by hand.
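One possible shape for such an interface, as a minimal sketch only: a runtime N group size is mapped onto a fixed set of compile-time shapes. Everything here is hypothetical and not part of CK_TILE: `GroupShape` is a stand-in for the real shape type, `run_with_shape` is a placeholder for a kernel launch, and the K group size of 128 is assumed from the line above.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical stand-in for a compile-time quant group shape (M x N x K).
// This is NOT ck_tile::QuantGroupShape; it only mimics the idea.
template <std::size_t M_, std::size_t N_, std::size_t K_>
struct GroupShape
{
    static constexpr std::size_t M = M_;
    static constexpr std::size_t N = N_;
    static constexpr std::size_t K = K_;
};

// Placeholder for launching a kernel instantiated on Shape.
// For illustration it returns how many elements share one scale value.
template <typename Shape>
std::size_t run_with_shape()
{
    return Shape::M * Shape::N * Shape::K;
}

// Maps a runtime N group size to one of a fixed set of instantiations.
// The K group size is assumed to be 128 in every case.
inline std::size_t dispatch_group_n(std::size_t group_n)
{
    switch(group_n)
    {
    case 1: return run_with_shape<GroupShape<1, 1, 128>>();
    case 8: return run_with_shape<GroupShape<1, 8, 128>>();
    case 16: return run_with_shape<GroupShape<1, 16, 128>>();
    case 32: return run_with_shape<GroupShape<1, 32, 128>>();
    case 64: return run_with_shape<GroupShape<1, 64, 128>>();
    case 128: return run_with_shape<GroupShape<1, 128, 128>>();
    default: throw std::invalid_argument("unsupported N group size");
    }
}
```

With something along these lines, the example could read the group size from a command-line argument and call `dispatch_group_n` instead of editing the `using` declaration.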
@CongMa13 Please try the tile distribution solution we discussed today and see the perf difference.
@ThomasNing @CongMa13 Did you have some ideas for the tile distribution? I think the current versions require that it exactly splits with NWarps and/or NIterPerWarp.
I updated the distribution and the calculation of the bq offset. There are 3 kinds of distribution, depending on the N group size.
Tests with N group sizes {1, 8, 16, 32, 64, 128} passed. I made an incorrect statement in an earlier comment that the N group size should be greater than Warp::N. Clearly, 1 and 8 are both legal N group sizes.
Thanks a lot @CongMa13! I've added a separate example for the 2d block scale, since the dispatching was getting a bit complex with non-preshuffle and other quants. We can maybe merge them again once every variant supports 2d blocks too.
Reformatted the example. It should be good now. If needed, we could separate the example out again.
Proposed changes
Introduces 2d block scale support for B matrix (grouping both on N and K axes). The tile distribution for the scale matrix has different options depending on the group size.
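To make the indexing concrete, here is a minimal host-side reference sketch of a GEMM with a 2D block-scaled B matrix. It is not the CK_TILE kernel: the function name, the use of `float` for the quantized values, and the scale layout (row-major over K groups by N groups) are assumptions made for illustration. A is row-major and B is column-major, matching the layouts in the test tuples.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference, not the CK_TILE implementation.
// Computes C = A * dequant(B), where each (group_k x group_n) block of B
// shares a single scale value.
//   a:       M x K, row-major
//   b_q:     K x N, column-major (element (k, n) is at n * K + k)
//   b_scale: (K / group_k) x (N / group_n), row-major (assumed layout)
//   result:  M x N, row-major
std::vector<float> gemm_bquant_2d_ref(const std::vector<float>& a,
                                      const std::vector<float>& b_q,
                                      const std::vector<float>& b_scale,
                                      std::size_t M,
                                      std::size_t N,
                                      std::size_t K,
                                      std::size_t group_n,
                                      std::size_t group_k)
{
    assert(N % group_n == 0 && K % group_k == 0);
    const std::size_t scale_cols = N / group_n; // number of groups along N

    std::vector<float> c(M * N, 0.0f);
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
            {
                // One scale per block: index by group along K and along N.
                const float s = b_scale[(k / group_k) * scale_cols + (n / group_n)];
                acc += a[m * K + k] * (b_q[n * K + k] * s);
            }
            c[m * N + n] = acc;
        }
    }
    return c;
}
```

Setting `group_n` to 1 reduces this to the 1D case, where grouping applies only along K.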
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered