@samremes
Contributor

@samremes samremes commented Oct 22, 2025

Proposed changes

Introduces 2D block-scale support for the B matrix (grouping along both the N and K axes). The tile distribution for the scale matrix offers different options depending on the group size.
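
To illustrate what grouping on both axes means, here is a minimal host-side reference sketch. It is not the ck_tile API: the helper name `dequant_b`, the row-major layout of the scale array, and the use of `float` for the quantized values are all assumptions made for illustration. B is taken as column-major, matching the `ColumnMajor` B layout used in the tests.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (assumption, not part of ck_tile).
// B has K rows and N columns, stored column-major: element (k, n) is at n * K + k.
// One scale is shared by every element of a group_k x group_n block of B.
// Assumed scale layout: scales[(k / group_k) * (N / group_n) + (n / group_n)].
inline float dequant_b(const std::vector<float>& b_q,
                       const std::vector<float>& scales,
                       std::size_t K, std::size_t N,
                       std::size_t group_k, std::size_t group_n,
                       std::size_t k, std::size_t n)
{
    // The sketch only handles shapes that split evenly into groups.
    assert(K % group_k == 0 && N % group_n == 0);
    const std::size_t n_groups  = N / group_n;
    const std::size_t scale_idx = (k / group_k) * n_groups + (n / group_n);
    return b_q[n * K + k] * scales[scale_idx];
}
```

With `group_n == 1` this reduces to the 1D case, where each column of B has its own scale per K group.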

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which enables the maintainers with understanding the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

@illsilin
Collaborator

Hi @samremes, could you please resolve the merge conflicts?

Comment on lines 62 to 73
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,

std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>,

// 2d cases with grouping also on the n axis
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D>
Collaborator

It is awesome to have these unit tests 👍


std::string quant_mode = arg_parser.get_str("quant_mode");

using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
Contributor

Could we expose the quant group size through an interface? Currently, the quant dimension sizes have to be hard-coded manually.
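
One possible shape for such an interface, sketched under assumptions: the example could accept the group size as a string such as `1x8x128` and parse it before dispatching to the matching compile-time shape. The `GroupSize` struct, the `parse_group_size` helper, and the `MxNxK` string format are hypothetical and not part of ck_tile.

```cpp
#include <cassert>
#include <cstdio>
#include <string>

// Hypothetical run-time description of a quant group (assumption).
struct GroupSize
{
    int m;
    int n;
    int k;
};

// Parses "MxNxK" (for example "1x8x128") into a GroupSize.
// Returns false on malformed input, trailing characters, or non-positive sizes.
inline bool parse_group_size(const std::string& text, GroupSize& out)
{
    int m = 0, n = 0, k = 0;
    char tail = 0;
    // Exactly three integers must match; a fourth match means trailing junk.
    if(std::sscanf(text.c_str(), "%dx%dx%d%c", &m, &n, &k, &tail) != 3)
        return false;
    if(m <= 0 || n <= 0 || k <= 0)
        return false;
    out = GroupSize{m, n, k};
    return true;
}
```

The parsed value would still need a dispatch step that maps each supported combination to its compile-time group shape, since the kernel is instantiated per shape.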

@ThomasNing
Contributor

@CongMa13 Please try the tile distribution solution we discussed today and check the perf difference.

@samremes
Contributor Author

@ThomasNing @CongMa13 Did you have some ideas for the tile distribution? I think the current versions require that it exactly splits with NWarps and/or NIterPerWarp.

@CongMa13
Collaborator

I updated the bq distribution and the calculation of the bq offset.

There are 3 kinds of distribution according to the N group size.

  1. N group size < warp::N
    One warp needs multiple bq
  2. N group size <= warp::N * NWarp
    Warp group needs multiple bq
  3. other
    Multiple NIters share one bq

Tests with N group size {1, 8, 16, 32, 64, 128} passed.

I made an incorrect statement in an earlier comment, saying that the N group size should be greater than Warp::N. Clearly, 1 and 8 are also legal values for the N group size.
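
The three cases listed above can be sketched as a small selection function. This is only an illustration of the stated conditions, not the code in this PR: the enum, the function name, and the parameter names are hypothetical, and `warp_n` and `n_warps` stand in for warp::N and NWarp.

```cpp
#include <cassert>

// Hypothetical labels for the three bq distributions described above (assumption).
enum class BqDistribution
{
    PerWarpMultiple,    // case 1: one warp needs multiple bq
    WarpGroupMultiple,  // case 2: the warp group needs multiple bq
    SharedAcrossNIters  // case 3: multiple NIters share one bq
};

// Picks a case from the N group size, following the conditions in the list.
inline BqDistribution classify_n_group(int n_group, int warp_n, int n_warps)
{
    assert(n_group > 0 && warp_n > 0 && n_warps > 0);
    if(n_group < warp_n)
        return BqDistribution::PerWarpMultiple;
    if(n_group <= warp_n * n_warps)
        return BqDistribution::WarpGroupMultiple;
    return BqDistribution::SharedAcrossNIters;
}
```

For instance, with an assumed `warp_n` of 16 and `n_warps` of 4, an N group size of 8 falls into case 1, 16 and 64 fall into case 2, and 128 falls into case 3.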

@samremes
Contributor Author

Thanks a lot @CongMa13!

I've added a separate example for the 2D block scale, since the dispatching was getting a bit complex with non-preshuffle and the other quant modes. We can maybe merge them again once every variant supports 2D blocks too.

@ThomasNing ThomasNing requested review from a team and ddembeckAMD as code owners October 31, 2025 23:22
@ThomasNing
Contributor

Reformatted the example. It should be good now. If needed, we can separate the example out again.

@ThomasNing ThomasNing dismissed CongMa13’s stale review October 31, 2025 23:28

Review addressed

ThomasNing
ThomasNing previously approved these changes Oct 31, 2025

6 participants