diff --git a/requirements.txt b/requirements.txt
index 5d55195d544..961b19bdc71 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/rocMLIR@221ad7adb549ea2a39c8e15d07a7ada6e89cef88 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
+ROCm/composable_kernel@fbce6c7bb6dad3750e33e999d438197cdc5c7fe8 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On
+ROCm/rocMLIR@221ad7adb549ea2a39c8e15d07a7ada6e89cef88 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
\ No newline at end of file
diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt
index 497ac403cec..83e8bff2318 100644
--- a/src/targets/gpu/CMakeLists.txt
+++ b/src/targets/gpu/CMakeLists.txt
@@ -57,7 +57,7 @@ else()
endif()
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
+ find_package(composable_kernel_host 1.0.0 REQUIRED)
endif()
if(BUILD_DEV)
@@ -121,7 +121,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_WAVEFRONTSIZE=64
target_include_directories(kernel_file_check PRIVATE $)
target_link_libraries(kernel_file_check compile_for_gpu)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- target_link_libraries(kernel_file_check composable_kernel::jit_library)
+ target_link_libraries(kernel_file_check composable_kernel::ck_host)
endif()
rocm_clang_tidy_check(kernel_file_check)
@@ -389,7 +389,7 @@ else()
endif()
target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
+ target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index bf9a269f3e1..fa15250f2a6 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -90,11 +90,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
- auto a = ins->inputs().front()->get_shape();
- auto b = ins->inputs().back()->get_shape();
- auto m = a.lens()[a.lens().size() - 2];
- auto n = b.lens().back();
- auto k = a.lens().back();
+ auto a = ins->inputs().front()->get_shape();
+ auto b = ins->inputs().back()->get_shape();
+ auto m = a.lens()[a.lens().size() - 2];
+ auto n = b.lens().back();
+ auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies());
// Integer gemms must be divisible by 4 in ck
@@ -118,7 +118,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
}
return true;
}
- return k <= 2048;
+ return k <= 1024;
}
struct find_ck_gemm_pointwise
@@ -207,7 +207,8 @@ struct find_ck_gemm_softmax_gemm
void fuse_ck::apply(module_pass_manager& mpm) const
{
- match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
+ match::find_matches(mpm, find_ck_gemm_softmax_gemm{});
+ match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp
index 18d4dce25a2..55a08793417 100644
--- a/src/targets/gpu/include/migraphx/gpu/ck.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
- * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -30,8 +30,9 @@
#include
#include
-#include "ck/host/device_gemm_multiple_d.hpp"
-#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
+#include "ck/host/headers.hpp"
+#include "ck/host/device_gemm_multiple_d/problem.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp
index 392eaa0c67b..43b51e03ddd 100644
--- a/src/targets/gpu/jit/ck_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -37,6 +37,7 @@
#include
#include
#include
+#include
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -136,17 +137,21 @@ struct ck_gemm_compiler : compiler
{
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 34);
- auto batch_count = get_batch_count(c_shape);
- auto problem = create_problem(inputs, v);
+ auto batch_count = get_batch_count(c_shape);
+ auto problem = create_problem(inputs, v);
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.template_str;
- const auto blocks_per_batch = solution.grid_size;
- const auto block_size = solution.block_size;
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("MPerBlock");
+ const auto n_per_block = solution.GetTemplateParameter("NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.N, n_per_block);
hip_compile_options options;
+ options.emplace_param("-DCK_CODE_GEN_RTC");
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index 693153d0982..97a8edc4bbb 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -37,6 +37,7 @@
#include
#include
#include
+#include
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -137,11 +138,15 @@ struct ck_gemm_softmax_gemm_compiler : compiler
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.template_str;
- const auto blocks_per_batch = solution.grid_size;
- const auto block_size = solution.block_size;
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock");
+ const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.O, n1_per_block);
hip_compile_options options;
+ options.emplace_param("-DCK_CODE_GEN_RTC");
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);