Skip to content
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8bb5255
Refactor quant group size to be configurable for M/N/K, not just K
samremes Oct 13, 2025
98365f5
add some asserts for configurations not implemented
samremes Oct 13, 2025
f6b07dc
start setting of group size for N dimension
samremes Oct 13, 2025
22362f2
enable 2d for reference quant gemm
samremes Oct 14, 2025
9988a46
WIP: trying to figure out tile dstr and/or indexing for scale matrix
samremes Oct 16, 2025
36b88c6
WIP
samremes Oct 20, 2025
bb52cd9
Fix handling of n dim blocks in tile windows etc
samremes Oct 21, 2025
f179a8a
remove commented code and enable all tests again
samremes Oct 22, 2025
d100ab6
fix formatting
samremes Oct 22, 2025
37738e4
Add more specialized tile distributions
samremes Oct 27, 2025
98deefa
Enable NWarps replication for bquant tile dstr
samremes Oct 27, 2025
2d86cd0
fix formatting
samremes Oct 27, 2025
470d6e4
Merge remote-tracking branch 'origin/develop' into samremes/bmatrix_2…
samremes Oct 27, 2025
1f13003
fix format
samremes Oct 27, 2025
a449728
Merge remote-tracking branch 'origin/develop' into samremes/bmatrix_2…
samremes Oct 28, 2025
e12ab56
Fix some issues from the merge
samremes Oct 28, 2025
7c93551
fix formatting
samremes Oct 28, 2025
e1475d4
one more fix to tile dstr, and revert debug initialization
samremes Oct 28, 2025
5e0a356
Remove commented code
samremes Oct 29, 2025
1290b1b
simplify conditions that are needed for tile distributions
samremes Oct 29, 2025
306e25a
only enable the working group sizes in tests
samremes Oct 29, 2025
68e41da
fix formatting
samremes Oct 30, 2025
bcccafe
Update tile distribution for 2D bquant
CongMa13 Oct 31, 2025
fe92102
add some documentation and 2d block scale example
samremes Oct 31, 2025
6f90564
fix formatting
samremes Oct 31, 2025
89be44d
Add in Changelog and restructure the quant 2d example
ThomasNing Oct 31, 2025
346ee26
solve the merge conflict
ThomasNing Oct 31, 2025
6b4b6fb
fix CMake
ThomasNing Nov 2, 2025
c494b23
support the change for blockscale 2d
ThomasNing Nov 2, 2025
a25f7cd
fix the test file
ThomasNing Nov 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added WMMA (gfx12) support for FMHA.
* Added pooling kernel in CK_TILE
* Added top-k sigmoid kernel in CK_TILE
* Added 2D block scale support for CK_TILE GEMM.

### Changed

Expand Down
3 changes: 3 additions & 0 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

add_executable(tile_example_gemm_quant_2d_block EXCLUDE_FROM_ALL gemm_quant_2d_block.cpp)
target_compile_options(tile_example_gemm_quant_2d_block PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()
265 changes: 114 additions & 151 deletions example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

// This example demonstrates 2D block scale quantization (N×K) for BQuant
// using a non-preshuffled configuration.
// NOTE: Once broader 2D support is ready, all 2D quant types can be migrated
// into this example. They are kept separate for now to avoid overly verbose
// dispatching.

#include <cstring>
#include <iostream>
#include <ostream>
Expand All @@ -17,7 +22,7 @@ template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize,
typename QuantGroupSize,
ck_tile::QuantType QuantMode,
typename CDEElementWise>
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
Expand Down Expand Up @@ -57,11 +62,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmTraits,
ComputeDataType>;

// This example only supports BQuant (no AQuant)
// For non-preshuffled BQuant, use BaseBQuantGemmPipelineAgBgCrCompV3
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
ck_tile::BaseAQuantGemmPipelineAgBgCrMem<GemmPipelineProblem>>; // memory pipeline hardcoded
// for aquant
ck_tile::BaseBQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;

const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
Expand Down Expand Up @@ -229,7 +235,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str

template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename QuantGroupSize,
ck_tile::QuantType QuantMode>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
Expand Down Expand Up @@ -266,146 +272,99 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}

// Forward declaration for dispatch function
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[]);

// Helper function to parse a group size string of the form "MxNxK".
//
// Accepts either a single integer, interpreted as the K group size with
// M = N = 1 (e.g. "128" -> {1, 1, 128}), or a full triple
// (e.g. "1x32x128" -> {1, 32, 128}).
//
// @param group_size_str  user-supplied group size specification
// @return tuple {M, N, K} of quantization group sizes
// @throws std::runtime_error  if the string is not of the form "K" or "MxNxK"
//         (including trailing components such as "1x32x128x4", which the
//         previous version silently ignored)
// @throws std::invalid_argument / std::out_of_range  from std::stoi when a
//         component is non-numeric or does not fit in an int
std::tuple<int, int, int> parse_group_size(const std::string& group_size_str)
{
    const size_t first_x = group_size_str.find('x');
    if(first_x == std::string::npos)
    {
        // Single number provided, assume it's the K dimension
        return {1, 1, std::stoi(group_size_str)};
    }

    const size_t second_x = group_size_str.find('x', first_x + 1);
    if(second_x == std::string::npos)
    {
        throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
    }

    // Reject extra components (e.g. "1x32x128x4") instead of silently
    // parsing only the first three.
    if(group_size_str.find('x', second_x + 1) != std::string::npos)
    {
        throw std::runtime_error("Invalid group_size format! Expected MxNxK (e.g., 1x32x128)");
    }

    const int m = std::stoi(group_size_str.substr(0, first_x));
    const int n = std::stoi(group_size_str.substr(first_x + 1, second_x - first_x - 1));
    const int k = std::stoi(group_size_str.substr(second_x + 1));

    return {m, n, k};
}

// Example driver after config selection: parses CLI arguments, converts the
// runtime "MxNxK" group size into a compile-time QuantGroupShape, and
// forwards to the data-type dispatcher.
//
// @tparam GemmConfig  GEMM configuration template selected in main()
// @return the dispatcher's status, or -1 when argument parsing fails
// NOTE(review): dispatch_group_size_ct is defined later in this file —
// presumably forward-declared in a region not shown here; verify the
// declaration order compiles.
// NOTE(review): the unaligned/aligned pairs of declarations below are
// diff-view residue (removed vs. added lines); the real file keeps one set.
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string data_type = arg_parser.get_str("prec");
    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");
    std::string data_type  = arg_parser.get_str("prec");
    std::string a_layout   = arg_parser.get_str("a_layout");
    std::string b_layout   = arg_parser.get_str("b_layout");
    std::string quant_mode = arg_parser.get_str("quant_mode");
    std::string group_size_str = arg_parser.get_str("group_size");

    auto [m_group, n_group, k_group] = parse_group_size(group_size_str);

    // Dispatch based on group size (M, N, K)
    return dispatch_group_size_ct<GemmConfig>(m_group, n_group, k_group, [&](auto QGS_) {
        using QuantGroupSize = decltype(QGS_);
        return dispatch_by_data_type<GemmConfig, QuantGroupSize>(
            data_type, quant_mode, a_layout, b_layout, argc, argv);
    });
}

std::string quant_mode = arg_parser.get_str("quant_mode");
template <template <typename PreType> typename GemmConfig, typename QuantGroupSize>
int dispatch_by_data_type(const std::string& data_type,
const std::string& quant_mode,
const std::string& a_layout,
const std::string& b_layout,
int argc,
char* argv[])
{
// This example ONLY supports BQuant for 2D block scale quantization
if(quant_mode != "bquant")
{
throw std::runtime_error("This example only supports BQuant! Use --quant_mode=bquant");
}

if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});

if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});

if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "rowcol")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::RowColQuant>(
a_layout, b_layout, argc, argv);
}
else if(quant_mode == "tensor")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::TensorQuant>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'");
}
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});

if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});

if(quant_mode == "aquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::AQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'aquant'.");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8i4")
{
Expand All @@ -414,19 +373,11 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::half_t,
ck_tile::fp8_t>{});

if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'bquant'.");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8i4")
{
Expand All @@ -435,27 +386,39 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::half_t,
ck_tile::bf8_t>{});

if(quant_mode == "bquant")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
128,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error(
"Unsupported quantization mode for this datatype! Use 'bquant'.");
}
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
TypeConfig,
QuantGroupSize,
ck_tile::QuantType::BQuantGrouped>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}

// Maps the runtime group size (m, n, k) onto a compile-time
// ck_tile::QuantGroupShape and invokes the callback `f` with an instance of
// the matching type. The set of supported shapes is enumerated by the
// project macro CK_TILE_SUPPORTED_QUANT_GROUPS(X), which applies X to each
// supported (M, N, K) triple.
//
// @param m, n, k  runtime-selected quantization group sizes per dimension
// @param f        generic callable invoked with the matching QuantGroupShape
// @return the callback's return value
// @throws std::runtime_error when (m, n, k) is not in the supported list
// NOTE(review): the GemmConfig template parameter is unused here —
// presumably kept for symmetry with the other dispatch helpers; confirm.
template <template <typename> typename GemmConfig, typename F>
int dispatch_group_size_ct(int m, int n, int k, F&& f)
{
    // This expands into a sequence of `if (m==M && n==N && k==K) { ... }`
#define DISPATCH_ONE(M, N, K)                                                    \
    if(m == M && n == N && k == K)                                               \
    {                                                                            \
        using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<M, N, K>>; \
        return f(QuantGroupSize{});                                              \
    }

    CK_TILE_SUPPORTED_QUANT_GROUPS(DISPATCH_ONE)

#undef DISPATCH_ONE

    throw std::runtime_error(
        "Unsupported group size! Please add it to CK_TILE_SUPPORTED_QUANT_GROUPS(X).");
}

// Program entry: runs the BQuant 2D block scale example.
// The result of run_gemm_example is negated for the process exit code —
// presumably it returns nonzero on success. NOTE(review): a -1 arg-parse
// failure also negates to 0 (success exit code) — confirm that is intended.
// NOTE(review): the two return lines below are diff-view residue (old vs.
// new config); the real file keeps only the second one.
int main(int argc, char* argv[])
{
    return !run_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
    // Use non-preshuffled GemmConfig for 2D block scale support
    return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
}
Loading