From 22ecb30caf2f3cf2166e60ccaf4d9107be3d47a8 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 29 Oct 2025 18:25:37 +0000 Subject: [PATCH 01/14] Fix splitK multiply_multiply_wp --- ..._multiply_multiply_xdl_fp8_bpreshuffle.cpp | 107 ++++++++++++++++-- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 45 +++++--- profiler/include/profiler/common.hpp | 103 +++++++++++++++++ ...profile_gemm_multiply_multiply_wp_impl.hpp | 29 ++++- 4 files changed, 260 insertions(+), 24 deletions(-) create mode 100644 profiler/include/profiler/common.hpp diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index 3ee4955ae4..dc8082e256 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -39,7 +39,7 @@ using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; using DsDataType = ck::Tuple; -using EDataType = F16; +using EDataType = BF16; using A0Layout = Row; using B0Layout = Col; @@ -48,6 +48,96 @@ using D1Layout = Col; using DsLayout = ck::Tuple; using ELayout = Row; +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + struct MultiplyMultiply { template @@ -139,14 +229,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 256, 256, 128, + 64, 128, 128, 16, 16, 16, 16, - 16, 4, + 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; + 2, 1, S<1, 32, 1, 8>, S<2, 2, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -405,8 +495,11 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err( - e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) + return ck::utils::check_err(e_m_n_device_result, + e_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()) ? 0 : 1; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 2e95ec0d52..aaec961cb9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -44,17 +44,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + const index_t num_k_per_block = + karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup); + const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, karg, karg.a_element_op, karg.b_element_op, - karg.c_element_op); + karg.c_element_op, + k_id); } #else ignore = karg; @@ -80,10 +84,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + const index_t num_k_per_block = + karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup); + const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, @@ -91,7 +98,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg, karg.a_element_op, karg.b_element_op, - karg.c_element_op); + karg.c_element_op, + k_id); } #else ignore = karg; @@ -1163,7 +1171,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + index_t k_id) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run( @@ -1176,7 +1185,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + k_id); } template (b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1597,7 +1610,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + index_t k_id) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run_2Lds( @@ -1611,7 +1625,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + k_id); } template (b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment diff --git a/profiler/include/profiler/common.hpp b/profiler/include/profiler/common.hpp new file mode 100644 index 0000000000..2f72e67c6b --- /dev/null +++ b/profiler/include/profiler/common.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace profiler { + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp index c76387e2b0..21613e49c6 100644 --- a/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { @@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0); + StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes(); @@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); d1_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); break; default: @@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, is_same_v)) { std::string msg = "Error: Incorrect results!"; - double rtol = 1e-3; - double atol = 5e-2; + double rtol = get_rtol(); + double atol = get_atol(); pass = pass & ck::utils::check_err( e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); } From 764c05e8ea183e899593d833b4204360bcf02d77 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 30 Oct 2025 11:26:00 +0000 Subject: [PATCH 02/14] Add tests for gemm_multiply_multiply_wp --- ..._multiply_multiply_xdl_fp8_bpreshuffle.cpp | 6 +- test/CMakeLists.txt | 1 + test/gemm_multiply_multiply_wp/CMakeLists.txt | 4 + .../test_gemm_common.hpp | 93 +++++++++++++++++++ ...test_gemm_multiply_multiply_wp_xdl_fp8.cpp | 82 ++++++++++++++++ 5 files changed, 183 insertions(+), 3 deletions(-) create mode 100644 test/gemm_multiply_multiply_wp/CMakeLists.txt create mode 100644 test/gemm_multiply_multiply_wp/test_gemm_common.hpp create mode 100644 test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index dc8082e256..b37711b1f9 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -385,10 +385,10 @@ int main(int argc, char* argv[]) d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); break; default: - a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 0.5}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 0.5}); } DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 810ae8d231..d344ac1a70 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -247,6 +247,7 @@ add_subdirectory(gemm) add_subdirectory(gemm_add) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_multi_abd) +add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) add_subdirectory(gemm_b_scale) diff --git a/test/gemm_multiply_multiply_wp/CMakeLists.txt b/test/gemm_multiply_multiply_wp/CMakeLists.txt new file mode 100644 index 0000000000..4a479e321b --- /dev/null +++ b/test/gemm_multiply_multiply_wp/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance) +endif() diff --git a/test/gemm_multiply_multiply_wp/test_gemm_common.hpp b/test/gemm_multiply_multiply_wp/test_gemm_common.hpp new file mode 100644 index 0000000000..37e2b353e6 --- /dev/null +++ b/test/gemm_multiply_multiply_wp/test_gemm_common.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multiply_multiply_wp_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmMultiplyMultiplyWPCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = Row; + using ADataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using ComputeDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using EDataType = std::tuple_element_t<9, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 4}; } + + void Run(const int M, const int N, const int K) + { + for(size_t i = 0; i < k_batches_.size(); i++) + { + RunSingle(M, N, K, k_batches_[i]); + } + } + + void RunSingle( + const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v, Row> ? K : M; + int StrideB = std::is_same_v, Row> ? N : K; + int StrideD0 = std::is_same_v, Row> ? N : M; + int StrideD1 = std::is_same_v, Row> ? N : M; + int StrideE = std::is_same_v ? N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl( + verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideD0, + StrideD1, + StrideE, + kbatch, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp new file mode 100644 index 0000000000..846747d548 --- /dev/null +++ b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmMultiplyMultiplyWP_FP8_MK_NK + : public ck::test::TestGemmMultiplyMultiplyWPCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F8, F8, F8, F32, F32, F16>, + std::tuple< F8, F8, F8, F32, F32, BF16> +#endif + >; + +TYPED_TEST_SUITE(TestGemmMultiplyMultiplyWP_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular2) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 448; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +// int main(int argc, char** argv) +// { +// testing::InitGoogleTest(&argc, argv); +// return RUN_ALL_TESTS(); +// } From 66e12d786d0bd2df6dbfdf752c66c42d365ec1f9 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 30 Oct 2025 16:10:47 +0000 Subject: [PATCH 03/14] Add tests for gemm_universal_preshuffle (KBatch = 1) --- ...vice_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 5 ++ ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 5 ++ test/CMakeLists.txt | 1 + ...test_gemm_multiply_multiply_wp_xdl_fp8.cpp | 6 -- test/gemm_universal_preshuffle/CMakeLists.txt | 4 + .../test_gemm_common.hpp | 79 +++++++++++++++++++ ...test_gemm_universal_preshuffle_xdl_fp8.cpp | 76 ++++++++++++++++++ 7 files changed, 170 insertions(+), 6 deletions(-) create mode 100644 test/gemm_universal_preshuffle/CMakeLists.txt create mode 100644 test/gemm_universal_preshuffle/test_gemm_common.hpp create mode 100644 test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index ebd168a7d0..ea4e6de6fd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -425,6 +425,11 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle 0) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 78546c4f99..49ab8d0ad6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -900,6 +900,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index d344ac1a70..3f98a3447e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -250,6 +250,7 @@ add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) +add_subdirectory(gemm_universal_preshuffle) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) diff --git a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp index 846747d548..0dfbf685db 100644 --- a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp +++ b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp @@ -74,9 +74,3 @@ TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular2) for(int M : Ms) this->Run(M, N, K); } - -// int main(int argc, char** argv) -// { -// testing::InitGoogleTest(&argc, argv); -// return RUN_ALL_TESTS(); -// } diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..45effc2f70 --- /dev/null +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) +endif() diff --git a/test/gemm_universal_preshuffle/test_gemm_common.hpp b/test/gemm_universal_preshuffle/test_gemm_common.hpp new file mode 100644 index 0000000000..5f18205915 --- /dev/null +++ b/test/gemm_universal_preshuffle/test_gemm_common.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_universal_preshuffle_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmUniversalPreshuffleCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using CDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1}; } + + void Run(const int M, const int N, const int K) + { + for(size_t i = 0; i < k_batches_.size(); i++) + { + RunSingle(M, N, K, k_batches_[i]); + } + } + + void RunSingle( + const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v ? K : M; + int StrideB = std::is_same_v ? N : K; + int StrideC = std::is_same_v ? N : M; + + all_success = all_success & + ck::profiler::profile_gemm_universal_preshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp new file mode 100644 index 0000000000..53465bfc66 --- /dev/null +++ b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversalPreshuffle_FP8_MK_NK + : public ck::test::TestGemmUniversalPreshuffleCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F8, F8, F8, F16>, + std::tuple< F8, F8, F8, BF16> +#endif + >; + +TYPED_TEST_SUITE(TestGemmUniversalPreshuffle_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular2) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 448; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} From 81879421ff7c66dcad4af7a9839fecb66ec845dc Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 30 Oct 2025 20:39:23 +0000 Subject: [PATCH 04/14] Add tests gemm_blockscale_wp --- test/CMakeLists.txt | 1 + test/gemm_blockscale_wp/CMakeLists.txt | 4 + .../test_gemm_blockscale_wp_xdl_fp8.cpp | 63 +++++++++++++++ test/gemm_blockscale_wp/test_gemm_common.hpp | 77 +++++++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 test/gemm_blockscale_wp/CMakeLists.txt create mode 100644 test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp create mode 100644 test/gemm_blockscale_wp/test_gemm_common.hpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3f98a3447e..d47e55db64 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -245,6 +245,7 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) +add_subdirectory(gemm_blockscale_wp) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_multi_abd) add_subdirectory(gemm_multiply_multiply_wp) diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt new file mode 100644 index 0000000000..791b5b4d8a --- /dev/null +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) +endif() diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp new file mode 100644 index 0000000000..4986649ba5 --- /dev/null +++ b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBlockScaleWP_FP8_MK_NK : public ck::test::TestGemmBlockscaleWPCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) + std::tuple< F8, F32, F8, F32, F8, BF16> +#endif + >; + +TYPED_TEST_SUITE(TestGemmBlockScaleWP_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_blockscale_wp/test_gemm_common.hpp b/test/gemm_blockscale_wp/test_gemm_common.hpp new file mode 100644 index 0000000000..daf5ed7d27 --- /dev/null +++ b/test/gemm_blockscale_wp/test_gemm_common.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_blockscale_wp_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmBlockscaleWPCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using A0DataType = std::tuple_element_t<2, Tuple>; + using A1DataType = std::tuple_element_t<3, Tuple>; + using B0DataType = std::tuple_element_t<4, Tuple>; + using B1DataType = std::tuple_element_t<5, Tuple>; + using ComputeDataType = std::tuple_element_t<6, Tuple>; + using CDataType = std::tuple_element_t<7, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + static constexpr index_t ScaleBlockM = 1; + static constexpr index_t ScaleBlockN = 128; + static constexpr index_t ScaleBlockK = 128; + + void Run(const int M, const int N, const int K, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v ? K : M; + int StrideB = std::is_same_v ? N : K; + int StrideC = std::is_same_v ? N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck From 00b87db8217113c2e010cbdb00544f538033b82a Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Thu, 30 Oct 2025 21:24:58 +0000 Subject: [PATCH 05/14] Fix splitk gemm universal preshuffle --- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 51 ++++++++++++------- .../test_gemm_common.hpp | 2 +- 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 49ab8d0ad6..b48d9b620b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -42,12 +42,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + const index_t num_k_per_block = karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared, - karg); + karg, + k_id); } #else ignore = karg; @@ -74,15 +78,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + const index_t num_k_per_block = karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared_0, p_shared_1, - karg); + karg, + k_id); } #else ignore = karg; @@ -1139,7 +1146,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) + c_grid_desc_mblock_mperblock_nblock_nperblock, + index_t k_id) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1231,7 +1239,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle true>(b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1470,10 +1478,13 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const BDataType* p_b_grid, CDataType* p_c_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + index_t k_id) { - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + // recompute K without splitK for matrix B + const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1); + index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = @@ -1496,7 +1507,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bpreshuffled, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id); } template ( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1611,7 +1624,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle true>(b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1854,10 +1867,13 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle CDataType* p_c_grid, void* p_shared_0, void* p_shared_1, - const Problem& problem) + const Problem& problem, + index_t k_id) { - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); + // recompute K without splitK for matrix B + const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1); + index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = @@ -1882,7 +1898,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bpreshuffled, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id); } }; diff --git a/test/gemm_universal_preshuffle/test_gemm_common.hpp b/test/gemm_universal_preshuffle/test_gemm_common.hpp index 5f18205915..367c1a9c7e 100644 --- a/test/gemm_universal_preshuffle/test_gemm_common.hpp +++ b/test/gemm_universal_preshuffle/test_gemm_common.hpp @@ -30,7 +30,7 @@ class TestGemmUniversalPreshuffleCommon : public ::testing::Test static constexpr bool bench_ = false; std::vector k_batches_; - void SetUp() override { k_batches_ = {1}; } + void SetUp() override { k_batches_ = {1, 2, 4}; } void Run(const int M, const int N, const int K) { From 34df427fbdaf54fdae82bfdd7ae831ef1a9697f2 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 31 Oct 2025 08:09:05 +0000 Subject: [PATCH 06/14] Run new tests on arch supporting fp8 --- test/gemm_blockscale_wp/CMakeLists.txt | 8 +++++--- test/gemm_multiply_multiply_wp/CMakeLists.txt | 8 +++++--- test/gemm_universal_preshuffle/CMakeLists.txt | 8 +++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt index 791b5b4d8a..d198db0870 100644 --- a/test/gemm_blockscale_wp/CMakeLists.txt +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -1,4 +1,6 @@ -add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) + endif() endif() diff --git a/test/gemm_multiply_multiply_wp/CMakeLists.txt b/test/gemm_multiply_multiply_wp/CMakeLists.txt index 4a479e321b..4302084a6f 100644 --- a/test/gemm_multiply_multiply_wp/CMakeLists.txt +++ b/test/gemm_multiply_multiply_wp/CMakeLists.txt @@ -1,4 +1,6 @@ -add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance) +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance) + endif() endif() diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt index 45effc2f70..0d8955f6a4 100644 --- a/test/gemm_universal_preshuffle/CMakeLists.txt +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -1,4 +1,6 @@ -add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) -if(result EQUAL 0) - target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) + endif() endif() From aca078d533ffb4ae5062143bbe4474e8ec6f6899 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 31 Oct 2025 08:31:12 +0000 Subject: [PATCH 07/14] Restore example --- ..._multiply_multiply_xdl_fp8_bpreshuffle.cpp | 113 ++---------------- 1 file changed, 10 insertions(+), 103 deletions(-) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp index b37711b1f9..3ee4955ae4 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp @@ -39,7 +39,7 @@ using CShuffleDataType = F32; using D0DataType = F32; using D1DataType = F32; using DsDataType = ck::Tuple; -using EDataType = BF16; +using EDataType = F16; using A0Layout = Row; using B0Layout = Col; @@ -48,96 +48,6 @@ using D1Layout = Col; using DsLayout = ck::Tuple; using ELayout = Row; -template -inline __host__ __device__ constexpr double get_rtol() -{ - if constexpr(std::is_same_v && std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline __host__ __device__ constexpr double get_atol() -{ - if constexpr(std::is_same_v && std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - struct MultiplyMultiply { template @@ -229,14 +139,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu // clang-format off < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, - 64, 128, 128, + 256, 256, 128, 16, 16, 16, 16, - 4, 2, + 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 1, S<1, 32, 1, 8>, S<2, 2, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; + 2, 1, S<1, 32, 1, 8>, S<8, 8, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; // clang-format on int main(int argc, char* argv[]) @@ -385,10 +295,10 @@ int main(int argc, char* argv[]) d1_m_n.GenerateTensorValue(GeneratorTensor_1{}); break; default: - a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 0.5}); - d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 0.5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); @@ -495,11 +405,8 @@ int main(int argc, char* argv[]) e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - return ck::utils::check_err(e_m_n_device_result, - e_m_n_host_result, - "Error: Incorrect results!", - get_rtol(), - get_atol()) + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 : 1; } From e54e5be3a35942acc90d5ea0087ed7a54ec6b68e Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 31 Oct 2025 09:27:15 +0000 Subject: [PATCH 08/14] Fix strides profiler --- .../profile_gemm_blockscale_wp_impl.hpp | 20 +++++++++++++++++++ ...profile_gemm_universal_preshuffle_impl.hpp | 20 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 0921b48842..23c93f7aee 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, ALayout{}, StrideA); + StrideB = get_stride(b0_k_n, BLayout{}, StrideB); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + int total_gemm_needed = a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() + a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes(); diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index e537cf2770..e7f13580db 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -99,6 +99,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC); + std::size_t total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); int rotating_count = std::max( From e3e414df1ff2febb94ce205ac332fcff9355ebf4 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 31 Oct 2025 09:40:25 +0000 Subject: [PATCH 09/14] Fix tests --- test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp | 2 +- .../test_gemm_multiply_multiply_wp_xdl_fp8.cpp | 2 +- .../test_gemm_universal_preshuffle_xdl_fp8.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp index 4986649ba5..0efe810feb 100644 --- a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp +++ b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp @@ -35,7 +35,7 @@ class TestGemmBlockScaleWP_FP8_MK_NK : public ck::test::TestGemmBlockscaleWPComm // clang-format off using KernelTypes_MK_NK = ::testing::Types< -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) +#if defined(CK_ENABLE_FP8) std::tuple< F8, F32, F8, F32, F8, BF16> #endif >; diff --git a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp index 0dfbf685db..1b6b5a0f15 100644 --- a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp +++ b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp @@ -37,7 +37,7 @@ class TestGemmMultiplyMultiplyWP_FP8_MK_NK // clang-format off using KernelTypes_MK_NK = ::testing::Types< -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) +#if defined(CK_ENABLE_FP8) std::tuple< F8, F8, F8, F32, F32, F16>, std::tuple< F8, F8, F8, F32, F32, BF16> #endif diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp index 53465bfc66..4f831ca33f 100644 --- a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp +++ b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp @@ -37,7 +37,7 @@ class TestGemmUniversalPreshuffle_FP8_MK_NK // clang-format off using KernelTypes_MK_NK = ::testing::Types< -#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) +#if defined(CK_ENABLE_FP8) std::tuple< F8, F8, F8, F16>, std::tuple< F8, F8, F8, BF16> #endif From 348a407d66e33902afcb0147c0947773c200667d Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 31 Oct 2025 09:42:08 +0000 Subject: [PATCH 10/14] Fix clang format --- .../include/profiler/profile_gemm_blockscale_wp_impl.hpp | 6 +++--- .../profiler/profile_gemm_universal_preshuffle_impl.hpp | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 23c93f7aee..7863ec9f1a 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -142,9 +142,9 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, return current_stride; }; - StrideA = get_stride(a0_m_k, ALayout{}, StrideA); - StrideB = get_stride(b0_k_n, BLayout{}, StrideB); - StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + StrideA = get_stride(a0_m_k, ALayout{}, StrideA); + StrideB = get_stride(b0_k_n, BLayout{}, StrideB); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); int total_gemm_needed = a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() + diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index e7f13580db..da3d6a1ef3 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -115,9 +115,9 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, return current_stride; }; - StrideA = get_stride(a_m_k, ALayout{}, StrideA); - StrideB = get_stride(b_k_n, BLayout{}, StrideB); - StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC); + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC); std::size_t total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); From 8eb93a4acfc5ae011b3b6782ceae9ec422c81fd6 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 31 Oct 2025 10:47:22 +0000 Subject: [PATCH 11/14] Finalize profiler preshuffle with tolerances --- ...profile_gemm_universal_preshuffle_impl.hpp | 5 +- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 83 +------------------ 2 files changed, 4 insertions(+), 84 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index da3d6a1ef3..5ec056efd1 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { @@ -337,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, is_same_v) { std::string msg = "Error: Incorrect results!"; - double rtol = 1e-1; - double atol = 1e-1; + double rtol = get_rtol(); + double atol = get_atol(); pass = pass & ck::utils::check_err( c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index b553e07735..ae12070014 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -5,92 +5,11 @@ #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { -template -inline constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template Date: Fri, 31 Oct 2025 10:49:01 +0000 Subject: [PATCH 12/14] Minor improvements to splitk related changes --- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 37 ++++++++------- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 45 ++++++++++++------- 2 files changed, 50 insertions(+), 32 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index b48d9b620b..b2def215c0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -40,9 +40,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - const index_t num_k_per_block = karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPack); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run( @@ -51,7 +54,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared, karg, - k_id); + k_id, + Kt); } #else ignore = karg; @@ -78,8 +82,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - const index_t num_k_per_block = karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPack); + // Full K needed for matrix B + const index_t Kt = karg.K; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run_2Lds( @@ -89,7 +97,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) p_shared_0, p_shared_1, karg, - k_id); + k_id, + Kt); } #else ignore = karg; @@ -1147,7 +1156,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - index_t k_id) + const index_t k_id) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1479,11 +1488,10 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle CDataType* p_c_grid, void* p_shared, const Problem& problem, - index_t k_id) + const index_t k_id, + const index_t Kt) { - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - // recompute K without splitK for matrix B - const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); @@ -1527,7 +1535,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - index_t k_id) + const index_t k_id) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1868,11 +1876,10 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle void* p_shared_0, void* p_shared_1, const Problem& problem, - index_t k_id) + const index_t k_id, + const index_t Kt) { - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - // recompute K without splitK for matrix B - const index_t Kt = problem.K + problem.KRead * (problem.KBatch - 1); + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index aaec961cb9..c3177f46fd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -43,10 +43,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - const index_t num_k_per_block = - karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup); - const index_t k_id = blockIdx.z * num_k_per_block; + + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, @@ -58,7 +61,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.a_element_op, karg.b_element_op, karg.c_element_op, - k_id); + k_id, + Kt); } #else ignore = karg; @@ -83,10 +87,13 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - const index_t num_k_per_block = - karg.K / (GridwiseGemm::KLane * GridwiseGemm::KPackPerGroup); - const index_t k_id = blockIdx.z * num_k_per_block; + + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, @@ -99,7 +106,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.a_element_op, karg.b_element_op, karg.c_element_op, - k_id); + k_id, + Kt); } #else ignore = karg; @@ -1172,7 +1180,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, - index_t k_id) + const index_t k_id, + const index_t Kt) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run( @@ -1186,7 +1195,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle b_element_op, c_element_op, block_2_ctile_map, - k_id); + k_id, + Kt); } template ( @@ -1626,7 +1636,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle b_element_op, c_element_op, block_2_ctile_map, - k_id); + k_id, + Kt); } template Date: Mon, 3 Nov 2025 10:12:43 +0000 Subject: [PATCH 13/14] Address review comments: clang format and ckProfiler typo --- .../profile_gemm_blockscale_wp_impl.hpp | 26 +++++----- profiler/src/profile_gemm_blockscale_wp.cpp | 26 +++++----- .../test_gemm_blockscale_wp_xdl_fp8.cpp | 1 + test/gemm_blockscale_wp/test_gemm_common.hpp | 48 +++++++++---------- ...test_gemm_multiply_multiply_wp_xdl_fp8.cpp | 1 + ...test_gemm_universal_preshuffle_xdl_fp8.cpp | 1 + 6 files changed, 53 insertions(+), 50 deletions(-) diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 7863ec9f1a..da0dc60760 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -69,19 +69,19 @@ template -bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideE, - int n_warmup, - int n_iter, - uint64_t rotating = 0) +bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE, + int n_warmup, + int n_iter, + uint64_t rotating = 0) { bool pass = true; diff --git a/profiler/src/profile_gemm_blockscale_wp.cpp b/profiler/src/profile_gemm_blockscale_wp.cpp index e6a2fbb8f6..d5f66c0b65 100644 --- a/profiler/src/profile_gemm_blockscale_wp.cpp +++ b/profiler/src/profile_gemm_blockscale_wp.cpp @@ -126,19 +126,19 @@ int profile_gemm_blockscale_weighpreshuffle(int argc, char* argv[]) const int DefaultStrideB = ck::is_same_v ? N : K; const int DefaultStrideE = ck::is_same_v ? N : M; - bool pass = ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl( + bool pass = ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl( do_verification, init_method, do_log, diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp index 0efe810feb..5d88e04690 100644 --- a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp +++ b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp @@ -39,6 +39,7 @@ using KernelTypes_MK_NK = ::testing::Types< std::tuple< F8, F32, F8, F32, F8, BF16> #endif >; +// clang-format on TYPED_TEST_SUITE(TestGemmBlockScaleWP_FP8_MK_NK, KernelTypes_MK_NK); diff --git a/test/gemm_blockscale_wp/test_gemm_common.hpp b/test/gemm_blockscale_wp/test_gemm_common.hpp index daf5ed7d27..25ed67a737 100644 --- a/test/gemm_blockscale_wp/test_gemm_common.hpp +++ b/test/gemm_blockscale_wp/test_gemm_common.hpp @@ -44,30 +44,30 @@ class TestGemmBlockscaleWPCommon : public ::testing::Test all_success = all_success & - ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl(verify_, - init_method_, - log_, - bench_, - M, - N, - K, - StrideA, - StrideB, - StrideC, - n_warmup, - n_iter); + ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + n_warmup, + n_iter); EXPECT_TRUE(all_success); } diff --git a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp index 1b6b5a0f15..bf9b909628 100644 --- a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp +++ b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp @@ -42,6 +42,7 @@ using KernelTypes_MK_NK = ::testing::Types< std::tuple< F8, F8, F8, F32, F32, BF16> #endif >; +// clang-format on TYPED_TEST_SUITE(TestGemmMultiplyMultiplyWP_FP8_MK_NK, KernelTypes_MK_NK); diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp index 4f831ca33f..06dca026ee 100644 --- a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp +++ b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp @@ -42,6 +42,7 @@ using KernelTypes_MK_NK = ::testing::Types< std::tuple< F8, F8, F8, BF16> #endif >; +// clang-format on TYPED_TEST_SUITE(TestGemmUniversalPreshuffle_FP8_MK_NK, KernelTypes_MK_NK); From 3c5465a89d717f5a8c7dde6b9aa253f38f97943f Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 3 Nov 2025 15:48:43 +0000 Subject: [PATCH 14/14] Remove b_k_split_offset from SplitKBatchOffset struct --- ...wise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp | 20 ------------------- ...m_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp | 11 ---------- 2 files changed, 31 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index b2def215c0..6ce2f63e3a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -674,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; } - if constexpr(is_same_v) - { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - if constexpr(!PermuteB) - { - // b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; - - b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize; - } - else - { - const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; - } - } - if(blockIdx.z < static_cast(karg.KBatch - 1)) { karg.K = karg.KRead; @@ -713,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle } index_t a_k_split_offset; - index_t b_k_split_offset; index_t c_reduce_offset; }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index c3177f46fd..f2f1530599 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -707,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_k_split_offset = k_id * karg.KRead * karg.StrideA; } - if constexpr(is_same_v) - { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - // KPack * NLane * KLane * K0 * N0 - b_k_split_offset = k_id * karg.KRead * NLane; - } - if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; @@ -728,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } index_t a_k_split_offset; - index_t b_k_split_offset; }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()