From ee0cb4f366a785c33cfaa384fc6eb925696f951a Mon Sep 17 00:00:00 2001
From: HTEC <>
Date: Mon, 2 Sep 2024 12:53:57 +0000
Subject: [PATCH 01/14] Integrate new CK API
---
requirements.txt | 2 +-
src/targets/gpu/CMakeLists.txt | 6 +++---
src/targets/gpu/compile_hip_code_object.cpp | 2 +-
src/targets/gpu/include/migraphx/gpu/ck.hpp | 5 +++--
src/targets/gpu/jit/ck_gemm.cpp | 21 ++++++++++++--------
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 21 +++++++++++++-------
6 files changed, 35 insertions(+), 22 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 21cd16b50aa..9623eee8d99 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+mirza-halilcevic/composable_kernel@2cd9b2b1ec1ade098a0f43c44fa5a0989e0b0a9d -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@3f153a79f86ef4dba62d0afe0d0e95e29518dc1e -DBUILD_FAT_LIBROCKCOMPILER=On
diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt
index 3e2201edcab..d099a366aec 100644
--- a/src/targets/gpu/CMakeLists.txt
+++ b/src/targets/gpu/CMakeLists.txt
@@ -47,7 +47,7 @@ else()
endif()
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
+ find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS ck_host)
endif()
if(BUILD_DEV)
@@ -111,7 +111,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256)
target_include_directories(kernel_file_check PRIVATE $)
target_link_libraries(kernel_file_check compile_for_gpu)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- target_link_libraries(kernel_file_check composable_kernel::jit_library)
+ target_link_libraries(kernel_file_check composable_kernel::ck_host)
endif()
rocm_clang_tidy_check(kernel_file_check)
@@ -362,7 +362,7 @@ else()
endif()
target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
+ target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()
diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp
index 2e3142bfa1f..d7ae4c0f3b0 100644
--- a/src/targets/gpu/compile_hip_code_object.cpp
+++ b/src/targets/gpu/compile_hip_code_object.cpp
@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
const auto& warnings = compiler_warnings();
options.params.insert(options.params.end(), warnings.begin(), warnings.end());
options.emplace_param("-ftemplate-backtrace-limit=0");
- options.emplace_param("-Werror");
+ // options.emplace_param("-Werror");
auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1)
MIGRAPHX_THROW("No code object");
diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp
index 18d4dce25a2..ea41b252547 100644
--- a/src/targets/gpu/include/migraphx/gpu/ck.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp
@@ -30,8 +30,9 @@
#include
#include
-#include "ck/host/device_gemm_multiple_d.hpp"
-#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
+#include "ck/host/headers.hpp"
+#include "ck/host/device_gemm_multiple_d/problem.hpp"
+#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp
index 7cd20b3931d..5d0b27953a6 100644
--- a/src/targets/gpu/jit/ck_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm.cpp
@@ -37,6 +37,7 @@
#include
#include
#include
+#include
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -136,15 +137,19 @@ struct ck_gemm_compiler : compiler
{
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 34);
- auto batch_count = get_batch_count(c_shape);
- auto problem = create_problem(inputs, v);
+ auto batch_count = get_batch_count(c_shape);
+ auto problem = create_problem(inputs, v);
- const auto include_header = problem.GetIncludeHeader();
- const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ const auto include_header = problem.GetIncludeHeader();
+ const auto solutions =
+ problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.template_str;
- const auto blocks_per_batch = solution.grid_size;
- const auto block_size = solution.block_size;
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("MPerBlock");
+ const auto n_per_block = solution.GetTemplateParameter("NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.N, n_per_block);
hip_compile_options options;
options.additional_src_files = ck_headers();
@@ -221,7 +226,7 @@ struct ck_gemm_compiler : compiler
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
- auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index 5fe60372b94..c5da7c4c4b3 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -37,6 +37,7 @@
#include
#include
#include
+#include
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -129,20 +130,26 @@ struct ck_gemm_softmax_gemm_compiler : compiler
operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
{
+
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 5);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
- const auto include_header = problem.GetIncludeHeader();
- const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
- const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.template_str;
- const auto blocks_per_batch = solution.grid_size;
- const auto block_size = solution.block_size;
+ const auto include_header = problem.GetIncludeHeader();
+ const auto solutions =
+ problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
+ const auto& solution = solutions.at(tuning_value);
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock");
+ const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.O, n1_per_block);
hip_compile_options options;
options.additional_src_files = ck_headers();
+
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
@@ -222,7 +229,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
- auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
+ auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
tc.solutions.resize(solutions.size());
std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};
From 448826917d79bf55d0492fde2d98235fd090438b Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic
Date: Mon, 21 Oct 2024 19:47:00 +0000
Subject: [PATCH 02/14] Matcher for gemm_gemm.
---
CMakeLists.txt | 2 +-
requirements.txt | 2 +-
src/pass_manager.cpp | 3 +
src/targets/gpu/CMakeLists.txt | 3 +-
src/targets/gpu/fuse_ck.cpp | 66 +++++
src/targets/gpu/include/migraphx/gpu/ck.hpp | 1 +
src/targets/gpu/jit/ck_gemm_gemm.cpp | 259 ++++++++++++++++++
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 3 +
.../include/migraphx/kernels/ck_gemm_gemm.hpp | 66 +++++
test/verify/run_verify.cpp | 7 +
test/verify/test_ck_gemm_gemm.cpp | 51 ++++
11 files changed, 460 insertions(+), 3 deletions(-)
create mode 100644 src/targets/gpu/jit/ck_gemm_gemm.cpp
create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
create mode 100644 test/verify/test_ck_gemm_gemm.cpp
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e4a884ffbd0..bc4af5ad06c 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,7 +76,7 @@ endif()
if(WIN32)
option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF)
else()
-option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" ON)
+option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF)
endif()
# By default build shared libraries
diff --git a/requirements.txt b/requirements.txt
index c686d609449..1316cdf0450 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-mirza-halilcevic/composable_kernel@2cd9b2b1ec1ade098a0f43c44fa5a0989e0b0a9d -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+mirza-halilcevic/composable_kernel@f22232bf4812f02b5afd403c560ea37b5d7168a3 -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@82378ac44b8627ecbaf9078353fb53588090fe28 -DBUILD_FAT_LIBROCKCOMPILER=On
diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp
index af8f0e4d6a7..8824426ff6c 100644
--- a/src/pass_manager.cpp
+++ b/src/pass_manager.cpp
@@ -178,6 +178,9 @@ void run_passes(program& prog, module_ref root_mod, const std::vector& pas
mpm.run_pass(p);
}
run_pass(prog, p, trace);
+
+ std::cout << "PASS " << p.name() << std::endl;
+ std::cout << prog << std::endl;
}
}
diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt
index 48d36dadeda..101f4c87269 100644
--- a/src/targets/gpu/CMakeLists.txt
+++ b/src/targets/gpu/CMakeLists.txt
@@ -131,7 +131,8 @@ file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM JIT_GPU_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
+ ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_gemm.cpp)
endif()
if(MIGRAPHX_USE_MIOPEN)
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index bf9a269f3e1..3b13c8bfb7a 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -76,6 +76,48 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
+struct ck_gemm_gemm
+{
+ operation op = make_op("dot");
+
+ template
+ static auto reflect(Self& self, F f)
+ {
+ return pack(f(self.op, "op"));
+ }
+
+ std::string name() const { return "gpu::ck_gemm_gemm"; }
+
+ void check_gemm_shape(const shape& s) const
+ {
+ if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1) and
+ not s.scalar())
+ MIGRAPHX_THROW("Invalid shape for " + name());
+ }
+
+ shape compute_shape(std::vector inputs, const std::vector&) const
+ {
+ check_shapes{inputs, *this}.same_ndims();
+ if(inputs.size() < 3)
+ MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
+
+ auto a = inputs[0];
+ auto b = inputs[1];
+ auto b1 = inputs.back();
+
+ for(const auto& input : inputs)
+ {
+ check_gemm_shape(input);
+ }
+
+ auto gemm0_shape = op.compute_shape({a, b});
+ return op.compute_shape({gemm0_shape, b1});
+ }
+
+ static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
+};
+MIGRAPHX_REGISTER_OP(ck_gemm_gemm);
+
struct ck_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
@@ -203,11 +245,35 @@ struct find_ck_gemm_softmax_gemm
}
};
+struct find_ck_gemm_gemm
+{
+ auto matcher() const
+ {
+ // TODO don't mix dot and quant_dot
+ auto gemm1 = match::skip(match::name("contiguous"))(
+ match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1")));
+ return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(gemm1));
+ }
+
+ void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+ {
+ auto ins = r.result;
+ auto gemm1_ins = r.instructions["gemm1"];
+ auto gemm2_ins = r.instructions["gemm2"];
+
+ auto inputs = gemm1_ins->inputs(); // A, B
+ inputs.push_back(gemm2_ins->inputs().back()); // B1
+
+ mpm.get_module().replace_instruction(ins, ck_gemm_gemm{gemm2_ins->get_operator()}, inputs);
+ }
+};
+
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
+ match::find_matches(mpm, find_ck_gemm_gemm{});
match::find_matches(mpm, find_ck_gemm{});
}
diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp
index ea41b252547..d252db2c8b9 100644
--- a/src/targets/gpu/include/migraphx/gpu/ck.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp
@@ -33,6 +33,7 @@
#include "ck/host/headers.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
+#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp
new file mode 100644
index 00000000000..dcce8611c19
--- /dev/null
+++ b/src/targets/gpu/jit/ck_gemm_gemm.cpp
@@ -0,0 +1,259 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+inline namespace MIGRAPHX_INLINE_NS {
+
+namespace gpu {
+
+using namespace migraphx::gpu::gen; // NOLINT
+
+// NOLINTNEXTLINE
+static const char* const ck_gemm_gemm_kernel = R"__migraphx__(
+#include
+#include
+#include
+#include
+#include <${include}>
+
+namespace migraphx {
+
+${preamble}
+
+extern "C" {
+
+MIGRAPHX_GLOBAL void ${kernel}(${params})
+{
+ transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
+ ck_gemm_gemm<${solution}, ${blocks_per_batch}>(xs...);
+ });
+}
+
+}
+
+} // namespace migraphx
+
+)__migraphx__";
+
+struct ck_gemm_gemm_compiler : compiler
+{
+ std::vector names() const { return {"ck_gemm_gemm", "gpu::ck_gemm_gemm"}; }
+
+ ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem
+ create_problem(const std::vector& inputs, const value& v) const
+ {
+ const auto& a0_shape = inputs[0];
+ const auto& b0_shape = inputs[1];
+ const auto& b1_shape = inputs[2];
+ const auto& e1_shape = inputs.back();
+
+ // cppcheck-suppress unreadVariable
+ auto rank = a0_shape.ndim();
+ auto batch_count = get_batch_count(e1_shape);
+ auto m = e1_shape.lens()[rank - 2];
+ m = can_fold_batch(inputs) ? m * batch_count : m;
+ auto n = b0_shape.lens().back();
+ auto k = a0_shape.lens().back();
+ auto o = e1_shape.lens().back();
+
+ const bool trans_a0 = transposed_matrix(a0_shape);
+ const bool trans_b0 = transposed_matrix(b0_shape);
+ const bool trans_b1 = transposed_matrix(b1_shape);
+ const bool trans_e1 = transposed_matrix(e1_shape);
+
+ std::vector trans_d0s;
+ std::transform(inputs.begin() + 3,
+ inputs.end() - 1,
+ std::back_inserter(trans_d0s),
+ [](const auto& i) { return transposed_matrix(i); });
+ // TODO trans_d1s
+
+ const auto a0_type = get_type(a0_shape);
+ const auto b0_type = get_type(b0_shape);
+ const auto b1_type = get_type(b1_shape);
+ const auto e1_type = get_type(e1_shape);
+
+ std::vector d0s_type;
+ std::transform(inputs.begin() + 3,
+ inputs.end() - 1,
+ std::back_inserter(d0s_type),
+ [](const auto& i) { return get_type(i); });
+ // TODO d1s_type
+
+ std::string ck_passthrough = "ck_passthrough";
+ std::string cde0_op = ck_passthrough;
+ std::string cde1_op = ck_passthrough;
+ assert(inputs.size() < 5 or v.contains("post"));
+ if(v.contains("post"))
+ {
+ cde0_op = v.at("post").to();
+ // TODO CDE1ElementOp
+ }
+
+ return ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem{m,
+ n,
+ k,
+ o,
+ trans_a0,
+ trans_b0,
+ trans_d0s,
+ trans_b1,
+ {}, // trans_d1s
+ trans_e1,
+ a0_type,
+ b0_type,
+ d0s_type,
+ b1_type,
+ {}, // d1s_type
+ e1_type,
+ ck_passthrough,
+ ck_passthrough,
+ cde0_op,
+ ck_passthrough,
+ cde1_op};
+ }
+
+ operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
+ {
+ const auto& e1_shape = inputs.back();
+ auto tuning_value = v.get("tuning_value", 0);
+ auto batch_count = get_batch_count(e1_shape);
+ auto problem = create_problem(inputs, v);
+
+ const auto include_header = problem.GetIncludeHeader();
+ const auto solutions =
+ problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
+ const auto& solution = solutions.at(tuning_value);
+ const auto template_str = solution.ToTemplateString();
+ const auto block_size = solution.GetTemplateParameter("BlockSize");
+ const auto m_per_block = solution.GetTemplateParameter("Gemm0MPerBlock");
+ const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock");
+ const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
+ ck::host::integer_divide_ceil(problem.O, n1_per_block);
+
+ hip_compile_options options;
+ options.additional_src_files = ck_headers();
+
+ auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
+ options.set_launch_params(v, grid_size * block_size, block_size);
+ options.inputs = inputs;
+ options.output = e1_shape;
+ options.kernel_name = v.get("kernel", "ck_gemm_gemm_kernel");
+ options.virtual_inputs = inputs;
+ if(can_fold_batch(inputs))
+ {
+ auto vinputs = inputs;
+ fold_batch_dims(vinputs[0]);
+ remove_batch_dims(vinputs[1]);
+ std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
+ options.virtual_inputs = vinputs;
+ }
+
+ if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
+ options.emplace_param("-DMIGRAPHX_CK_CHECK=1");
+
+ auto src = interpolate_string(ck_gemm_gemm_kernel,
+ {{"solution", template_str},
+ {"include", include_header},
+ {"params", enum_params(inputs.size(), "void * private_p")},
+ {"args", enum_params(inputs.size(), "private_p")},
+ {"blocks_per_batch", to_string(blocks_per_batch)},
+ {"preamble", v.get("preamble", std::string{})},
+ {"kernel", options.kernel_name}});
+
+ return compile_hip_code_object(src, options);
+ }
+
+ value create_settings(instruction_ref ins, const operation& op) const
+ {
+ auto v = op.to_value();
+ v["kernel"] = "ck_gemm_gemm_kernel";
+ if(not ins->module_inputs().empty())
+ {
+ auto* pm = ins->module_inputs().front();
+ v["preamble"] = generate_pointwise(*pm, "post_ck_gemm0_function") +
+ "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);";
+ v["post"] = "ck_function_adaptor";
+ v["kernel"] = to_c_id("ck_gemm_" + generate_name_from_ops(*pm) + "_gemm_kernel");
+ }
+ return v;
+ }
+
+ compiler_replace
+ compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
+ {
+ std::cout << "USING GEMM GEMM" << std::endl;
+
+ auto shapes = to_shapes(ins->inputs());
+ auto v = create_settings(ins, op);
+ if(not solution.is_null())
+ v["tuning_value"] = solution;
+ return {compile_op(ctx, shapes, v),
+ [=](module& m, instruction_ref ins2, const operation& code_object) {
+ if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
+ {
+ std::vector gemm_shapes{
+ shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
+ std::cout << "gpu::ck_gemm_gemm: " << to_json_string(to_value(gemm_shapes))
+ << std::endl;
+ }
+ m.replace_instruction(ins2, code_object, ins2->inputs());
+ }};
+ }
+
+ optional
+ get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
+ {
+ if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
+ return nullopt;
+ tuning_config tc;
+ auto shapes = to_shapes(ins->inputs());
+ auto problem = create_problem(shapes, create_settings(ins, op));
+ auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
+ tc.solutions.resize(solutions.size());
+ std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
+ std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};
+ tc.problem = to_value(gemm_shapes);
+ return tc;
+ }
+};
+
+} // namespace gpu
+} // namespace MIGRAPHX_INLINE_NS
+} // namespace migraphx
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index c5da7c4c4b3..fdd68bd077b 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -182,6 +182,9 @@ struct ck_gemm_softmax_gemm_compiler : compiler
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
+ // std::cout << "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" << std::endl;
+ // std::cout << src << std::endl;
+
return compile_hip_code_object(src, options);
}
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
new file mode 100644
index 00000000000..2b9e6f55b31
--- /dev/null
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
@@ -0,0 +1,66 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
+#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace migraphx {
+
+template
+__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, D0s... d0s)
+{
+ constexpr auto desc = G::make_descriptor(to_ck_tensor(),
+ to_ck_tensor>(),
+ ck::make_tuple(to_ck_tensor()...),
+ to_ck_tensor>(),
+ ck::make_tuple(),
+ to_ck_tensor());
+
+ MIGRAPHX_STATIC_ASSERT_FOR(desc.IsValid())
+ {
+ G::Run(desc,
+ to_ck_const_pointer(a0.data()),
+ to_ck_const_pointer(b0.data()),
+ ck::make_tuple(to_ck_const_pointer(d0s.data())...),
+ to_ck_const_pointer(b1.data()),
+ ck::make_tuple(),
+ to_ck_pointer(e1.data()));
+ }
+}
+
+template
+__device__ void ck_gemm_gemm(Ts... xs)
+{
+ gemm_batch_args(make_index(), _c, xs...)(
+ [](auto... ys) { ck_gemm_gemm_matrix(ys...); });
+}
+
+} // namespace migraphx
+#endif
diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp
index e7fdd535fcc..56e244f9d3b 100644
--- a/test/verify/run_verify.cpp
+++ b/test/verify/run_verify.cpp
@@ -132,6 +132,10 @@ run_verify::run_ref(migraphx::program p,
auto_print pp{p, t.name()};
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
+
+ std::cout << "REF PROGRAM" << std::endl;
+ p.debug_print();
+
return std::make_pair(std::move(p), p.eval(std::move(inputs)));
}
@@ -159,6 +163,9 @@ run_verify::run_target(const migraphx::target& t,
validate(t, p, m);
p.eval(m);
+ std::cout << "GPU PROGRAM" << std::endl;
+ p.debug_print();
+
auto tres = p.eval(m);
std::vector res(tres.size());
std::transform(
diff --git a/test/verify/test_ck_gemm_gemm.cpp b/test/verify/test_ck_gemm_gemm.cpp
new file mode 100644
index 00000000000..a3ae26d243e
--- /dev/null
+++ b/test/verify/test_ck_gemm_gemm.cpp
@@ -0,0 +1,51 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+struct test_ck_gemm_gemm : verify_program
+{
+ migraphx::program create_program() const
+ {
+ migraphx::program p;
+ auto* mm = p.get_main_module();
+
+ migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
+ migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
+ migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
+ auto a = mm->add_parameter("1", m1_shape);
+ auto b = mm->add_parameter("2", m2_shape);
+ auto b1 = mm->add_parameter("3", m3_shape);
+
+ b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+ auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+ mm->add_instruction(migraphx::make_op("dot"), gemm1, b1);
+
+ return p;
+ }
+ std::string section() const { return "gemm"; }
+};
From edaa04856a3af6e8f74bb96ec457fd5efae0c624 Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic
Date: Thu, 31 Oct 2024 08:57:26 +0000
Subject: [PATCH 03/14] Finalize gemm_gemm integration.
---
requirements.txt | 9 +-
src/pass_manager.cpp | 3 -
src/targets/gpu/fuse_ck.cpp | 158 +++++++++++++++---
src/targets/gpu/jit/ck_gemm_gemm.cpp | 81 ++++++---
.../include/migraphx/kernels/ck_gemm_gemm.hpp | 46 +++--
test/verify/test_ck_gemm_gemm.cpp | 12 +-
test/verify/test_ck_gemm_gemm_pointwise.cpp | 58 +++++++
test/verify/test_ck_gemm_pointwise_gemm.cpp | 58 +++++++
.../test_ck_gemm_pointwise_gemm_pointwise.cpp | 65 +++++++
..._gemm_pointwise_gemm_pointwise_rotated.cpp | 62 +++++++
10 files changed, 481 insertions(+), 71 deletions(-)
create mode 100644 test/verify/test_ck_gemm_gemm_pointwise.cpp
create mode 100644 test/verify/test_ck_gemm_pointwise_gemm.cpp
create mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
create mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp
diff --git a/requirements.txt b/requirements.txt
index d6d437a9ff1..37ef26395b7 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,10 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-<<<<<<< HEAD
-mirza-halilcevic/composable_kernel@f22232bf4812f02b5afd403c560ea37b5d7168a3 -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/rocMLIR@82378ac44b8627ecbaf9078353fb53588090fe28 -DBUILD_FAT_LIBROCKCOMPILER=On
-=======
-ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/rocMLIR@e454b5d06fc2f099f7de3ee43450e7a6b1efe015 -DBUILD_FAT_LIBROCKCOMPILER=On
->>>>>>> upstream/develop
+mirza-halilcevic/composable_kernel@089c978451b3a5c15e146c5372e8ef8c5d4bc7d9 -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+ROCm/rocMLIR@87a55290b7f89f9f8b97c611e0cf929399a9e0c5 -DBUILD_FAT_LIBROCKCOMPILER=On
diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp
index 8824426ff6c..af8f0e4d6a7 100644
--- a/src/pass_manager.cpp
+++ b/src/pass_manager.cpp
@@ -178,9 +178,6 @@ void run_passes(program& prog, module_ref root_mod, const std::vector& pas
mpm.run_pass(p);
}
run_pass(prog, p, trace);
-
- std::cout << "PASS " << p.name() << std::endl;
- std::cout << prog << std::endl;
}
}
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index 3b13c8bfb7a..515a44bfc55 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -78,12 +78,15 @@ MIGRAPHX_REGISTER_OP(ck_gemm);
struct ck_gemm_gemm
{
- operation op = make_op("dot");
+ operation op = make_op("dot");
+ size_t d0s_count = 0u;
+ size_t d1s_count = 0u;
template
static auto reflect(Self& self, F f)
{
- return pack(f(self.op, "op"));
+ return pack(
+ f(self.op, "op"), f(self.d0s_count, "d0s_count"), f(self.d1s_count, "d1s_count"));
}
std::string name() const { return "gpu::ck_gemm_gemm"; }
@@ -101,9 +104,9 @@ struct ck_gemm_gemm
if(inputs.size() < 3)
MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
- auto a = inputs[0];
- auto b = inputs[1];
- auto b1 = inputs.back();
+ auto a = inputs[0];
+ auto b = inputs[1];
+ auto b1 = inputs[2];
for(const auto& input : inputs)
{
@@ -132,11 +135,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
- auto a = ins->inputs().front()->get_shape();
- auto b = ins->inputs().back()->get_shape();
- auto m = a.lens()[a.lens().size() - 2];
- auto n = b.lens().back();
- auto k = a.lens().back();
+ auto a = ins->inputs().front()->get_shape();
+ auto b = ins->inputs().back()->get_shape();
+ auto m = a.lens()[a.lens().size() - 2];
+ auto n = b.lens().back();
+ auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies());
// Integer gemms must be divisible by 4 in ck
@@ -245,26 +248,137 @@ struct find_ck_gemm_softmax_gemm
}
};
-struct find_ck_gemm_gemm
+struct find_ck_gemm_pointwise_gemm
{
auto matcher() const
{
// TODO don't mix dot and quant_dot
- auto gemm1 = match::skip(match::name("contiguous"))(
- match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1")));
- return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(gemm1));
+ // TODO match used_once?
+ auto gemm0 = match::skip(match::name("contiguous"))(
+ match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm0")));
+ auto pw0 =
+ match::name("pointwise")(match::any_of[match::inputs()](gemm0.bind("x0")).bind("pw0"));
+ return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1"))(
+ match::arg(0)(match::any_of(pw0, gemm0)));
}
- void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+ bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
+
+ void apply(module_pass_manager& mpm, const match::matcher_result& r)
{
+ std::cout << "matching ck_gemm_gemm" << std::endl;
+
auto ins = r.result;
+ auto gemm0_ins = r.instructions["gemm0"];
auto gemm1_ins = r.instructions["gemm1"];
- auto gemm2_ins = r.instructions["gemm2"];
- auto inputs = gemm1_ins->inputs(); // A, B
- inputs.push_back(gemm2_ins->inputs().back()); // B1
+ auto inputs = gemm0_ins->inputs(); // A, B
+ inputs.push_back(gemm1_ins->inputs().back()); // B1
+
+ if (!transposed_matrix(inputs[1]->get_shape()))
+ return;
- mpm.get_module().replace_instruction(ins, ck_gemm_gemm{gemm2_ins->get_operator()}, inputs);
+ size_t d0s_count = 0, d1s_count = 0;
+ std::vector module_inputs;
+ if(r.instructions.find("pw0") != r.instructions.end())
+ {
+ auto pw0 = r.instructions["pw0"];
+ auto x0_ins = r.instructions["x0"];
+
+ if(gemm0_ins->get_shape().type() != shape::int32_type and
+ pw0->get_shape().type() != gemm0_ins->get_shape().type())
+ return;
+ if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) {
+ return not ck_gemm::is_ck_supported_type(input->get_shape().type());
+ }))
+ return;
+ if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) {
+ return not input->inputs().empty() and
+ input->inputs().front()->name() == "capture";
+ }))
+ return;
+
+ auto pw0_inputs = pw0->inputs();
+ auto* pw0m = pw0->module_inputs().front();
+
+ auto gemm_it = std::find(pw0_inputs.begin(), pw0_inputs.end(), x0_ins);
+ auto gemm_idx = gemm_it - pw0_inputs.begin();
+ if(gemm_idx != 0)
+ {
+ rotate_gemm_input(pw0m, gemm_idx);
+ }
+
+ pw0_inputs.erase(gemm_it);
+ inputs.insert(inputs.end(), pw0_inputs.begin(), pw0_inputs.end()); // D0s
+
+ d0s_count = pw0_inputs.size();
+ module_inputs.push_back(pw0m);
+ }
+ if(r.instructions.find("pw1") != r.instructions.end())
+ {
+ auto pw1 = r.instructions["pw1"];
+
+ if(gemm1_ins->get_shape().type() != shape::int32_type and
+ pw1->get_shape().type() != gemm1_ins->get_shape().type())
+ return;
+ if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) {
+ return not ck_gemm::is_ck_supported_type(input->get_shape().type());
+ }))
+ return;
+ if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) {
+ return not input->inputs().empty() and
+ input->inputs().front()->name() == "capture";
+ }))
+ return;
+
+ auto pw1_inputs = pw1->inputs();
+ auto* pw1m = pw1->module_inputs().front();
+
+ auto gemm_it = std::find(pw1_inputs.begin(), pw1_inputs.end(), gemm1_ins);
+ auto gemm_idx = gemm_it - pw1_inputs.begin();
+ if(gemm_idx != 0)
+ {
+ rotate_gemm_input(pw1m, gemm_idx);
+ }
+
+ pw1_inputs.erase(gemm_it);
+ inputs.insert(inputs.end(), pw1_inputs.begin(), pw1_inputs.end()); // D1s
+
+ d1s_count = pw1_inputs.size();
+ module_inputs.push_back(pw1m);
+ }
+
+ mpm.get_module().replace_instruction(
+ ins,
+ ck_gemm_gemm{gemm1_ins->get_operator(), d0s_count, d1s_count},
+ inputs,
+ module_inputs);
+ }
+
+ void rotate_gemm_input(module* pwm, size_t gemm_idx)
+ {
+ auto names = pwm->get_parameter_names();
+
+ auto first_param = pwm->get_parameter(names[0]);
+ auto gemm_param = pwm->get_parameter(names[gemm_idx]);
+
+ auto new_gemm_param = pwm->add_parameter(names[0] + "_0", gemm_param->get_shape());
+ auto new_first_param = pwm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
+
+ pwm->replace_instruction(gemm_param, new_gemm_param);
+ pwm->replace_instruction(first_param, new_first_param);
+ pwm->remove_instruction(first_param);
+ pwm->remove_instruction(gemm_param);
+ }
+};
+
+struct find_ck_gemm_pointwise_gemm_pointwise : find_ck_gemm_pointwise_gemm
+{
+ auto matcher() const
+ {
+ // TODO match used_once?
+ auto gemm1 = find_ck_gemm_pointwise_gemm::matcher();
+ return match::name("pointwise")(match::any_of[match::inputs()](gemm1).bind("pw1"));
}
};
@@ -272,8 +386,10 @@ struct find_ck_gemm_gemm
void fuse_ck::apply(module_pass_manager& mpm) const
{
- match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
- match::find_matches(mpm, find_ck_gemm_gemm{});
+ match::find_matches(mpm, find_ck_gemm_softmax_gemm{});
+ match::find_matches(mpm, find_ck_gemm_pointwise_gemm_pointwise{});
+ match::find_matches(mpm, find_ck_gemm_pointwise_gemm{});
+ match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp
index dcce8611c19..005cbda11e8 100644
--- a/src/targets/gpu/jit/ck_gemm_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_gemm.cpp
@@ -54,14 +54,16 @@ static const char* const ck_gemm_gemm_kernel = R"__migraphx__(
namespace migraphx {
-${preamble}
+${preamble0}
+
+${preamble1}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
- ck_gemm_gemm<${solution}, ${blocks_per_batch}>(xs...);
+ ck_gemm_gemm<${solution}, ${blocks_per_batch}, ${d0s_count}>(xs...);
});
}
@@ -97,12 +99,20 @@ struct ck_gemm_gemm_compiler : compiler
const bool trans_b1 = transposed_matrix(b1_shape);
const bool trans_e1 = transposed_matrix(e1_shape);
+ auto d0s_count = v.get("d0s_count", 0);
+ auto d1s_count = v.get("d1s_count", 0);
+
std::vector trans_d0s;
std::transform(inputs.begin() + 3,
- inputs.end() - 1,
+ inputs.end() - 1 - d1s_count,
std::back_inserter(trans_d0s),
[](const auto& i) { return transposed_matrix(i); });
- // TODO trans_d1s
+
+ std::vector trans_d1s;
+ std::transform(inputs.begin() + 3 + d0s_count,
+ inputs.end() - 1,
+ std::back_inserter(trans_d1s),
+ [](const auto& i) { return transposed_matrix(i); });
const auto a0_type = get_type(a0_shape);
const auto b0_type = get_type(b0_shape);
@@ -111,19 +121,27 @@ struct ck_gemm_gemm_compiler : compiler
std::vector d0s_type;
std::transform(inputs.begin() + 3,
- inputs.end() - 1,
+ inputs.end() - 1 - d1s_count,
std::back_inserter(d0s_type),
[](const auto& i) { return get_type(i); });
- // TODO d1s_type
+
+ std::vector d1s_type;
+ std::transform(inputs.begin() + 3 + d0s_count,
+ inputs.end() - 1,
+ std::back_inserter(d1s_type),
+ [](const auto& i) { return get_type(i); });
std::string ck_passthrough = "ck_passthrough";
std::string cde0_op = ck_passthrough;
std::string cde1_op = ck_passthrough;
- assert(inputs.size() < 5 or v.contains("post"));
- if(v.contains("post"))
+ assert(inputs.size() < 5 or v.contains("cde0_op") or v.containse("cde1_op"));
+ if(v.contains("cde0_op"))
+ {
+ cde0_op = v.at("cde0_op").to();
+ }
+ if(v.contains("cde1_op"))
{
- cde0_op = v.at("post").to();
- // TODO CDE1ElementOp
+ cde1_op = v.at("cde1_op").to();
}
return ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem{m,
@@ -134,13 +152,13 @@ struct ck_gemm_gemm_compiler : compiler
trans_b0,
trans_d0s,
trans_b1,
- {}, // trans_d1s
+ trans_d1s,
trans_e1,
a0_type,
b0_type,
d0s_type,
b1_type,
- {}, // d1s_type
+ d1s_type,
e1_type,
ck_passthrough,
ck_passthrough,
@@ -151,6 +169,8 @@ struct ck_gemm_gemm_compiler : compiler
operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
{
+ std::cout << "compiling ck_gemm_gemm: " << v.get("kernel", std::string{}) << std::endl;
+
const auto& e1_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 0);
auto batch_count = get_batch_count(e1_shape);
@@ -194,7 +214,9 @@ struct ck_gemm_gemm_compiler : compiler
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"blocks_per_batch", to_string(blocks_per_batch)},
- {"preamble", v.get("preamble", std::string{})},
+ {"d0s_count", to_string(v.get("d0s_count", 0))},
+ {"preamble0", v.get("preamble0", std::string{})},
+ {"preamble1", v.get("preamble1", std::string{})},
{"kernel", options.kernel_name}});
return compile_hip_code_object(src, options);
@@ -202,24 +224,37 @@ struct ck_gemm_gemm_compiler : compiler
value create_settings(instruction_ref ins, const operation& op) const
{
- auto v = op.to_value();
- v["kernel"] = "ck_gemm_gemm_kernel";
- if(not ins->module_inputs().empty())
+ auto v = op.to_value();
+
+ auto d0s_count = v.get("d0s_count", 0);
+ auto d1s_count = v.get("d1s_count", 0);
+
+ std::string pw0_name, pw1_name, kernel_name = "ck_gemm_";
+ if(d0s_count > 0)
{
- auto* pm = ins->module_inputs().front();
- v["preamble"] = generate_pointwise(*pm, "post_ck_gemm0_function") +
- "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);";
- v["post"] = "ck_function_adaptor";
- v["kernel"] = to_c_id("ck_gemm_" + generate_name_from_ops(*pm) + "_gemm_kernel");
+ auto* pw0m = ins->module_inputs().front();
+ v["preamble0"] = generate_pointwise(*pw0m, "post_ck_gemm0_function") +
+ "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);";
+ v["cde0_op"] = "ck_function_adaptor";
+ kernel_name += generate_name_from_ops(*pw0m) + "_";
}
+ kernel_name += "gemm_";
+ if(d1s_count > 0)
+ {
+ auto* pw1m = ins->module_inputs().back();
+ v["preamble1"] = generate_pointwise(*pw1m, "post_ck_gemm1_function") +
+ "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm1, post_ck_gemm1_function);";
+ v["cde1_op"] = "ck_function_adaptor";
+ kernel_name += generate_name_from_ops(*pw1m) + "_";
+ }
+
+ v["kernel"] = to_c_id(kernel_name + "kernel");
return v;
}
compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
{
- std::cout << "USING GEMM GEMM" << std::endl;
-
auto shapes = to_shapes(ins->inputs());
auto v = create_settings(ins, op);
if(not solution.is_null())
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
index 2b9e6f55b31..f0dbeefd333 100644
--- a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
+++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
+#include "migraphx/kernels/type_traits.hpp"
#include
#include
#include
@@ -33,33 +34,54 @@
namespace migraphx {
-template
-__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, D0s... d0s)
+template
+__device__ void ck_gemm_gemm_matrix_impl(
+ E1 e1, A0 a0, B0 b0, B1 b1, detail::seq, detail::seq, tuple ds)
{
- constexpr auto desc = G::make_descriptor(to_ck_tensor(),
- to_ck_tensor>(),
- ck::make_tuple(to_ck_tensor()...),
- to_ck_tensor>(),
- ck::make_tuple(),
- to_ck_tensor());
+ constexpr auto desc = G::make_descriptor(
+ to_ck_tensor(),
+ to_ck_tensor>(),
+ ck::make_tuple(
+ to_ck_tensor(ds))>>()...),
+ to_ck_tensor>(),
+ ck::make_tuple(
+ to_ck_tensor(ds))>>()...),
+ to_ck_tensor());
MIGRAPHX_STATIC_ASSERT_FOR(desc.IsValid())
{
G::Run(desc,
to_ck_const_pointer(a0.data()),
to_ck_const_pointer(b0.data()),
- ck::make_tuple(to_ck_const_pointer(d0s.data())...),
+ ck::make_tuple(to_ck_const_pointer(tuple_detail::get_element(ds).data())...),
to_ck_const_pointer(b1.data()),
- ck::make_tuple(),
+ ck::make_tuple(to_ck_const_pointer(
+ tuple_detail::get_element(ds).data())...),
to_ck_pointer(e1.data()));
}
}
-template
+template
+__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, Ds... ds)
+{
+ auto all_ds = make_tuple(ds...);
+ ck_gemm_gemm_matrix_impl(
+ e1, a0, b0, b1, detail::gens{}, detail::gens{}, all_ds);
+}
+
+template
__device__ void ck_gemm_gemm(Ts... xs)
{
gemm_batch_args(make_index(), _c, xs...)(
- [](auto... ys) { ck_gemm_gemm_matrix(ys...); });
+ [](auto... ys) { ck_gemm_gemm_matrix(ys...); });
}
} // namespace migraphx
diff --git a/test/verify/test_ck_gemm_gemm.cpp b/test/verify/test_ck_gemm_gemm.cpp
index a3ae26d243e..9595f0cd581 100644
--- a/test/verify/test_ck_gemm_gemm.cpp
+++ b/test/verify/test_ck_gemm_gemm.cpp
@@ -37,13 +37,15 @@ struct test_ck_gemm_gemm : verify_program
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
- auto a = mm->add_parameter("1", m1_shape);
- auto b = mm->add_parameter("2", m2_shape);
- auto b1 = mm->add_parameter("3", m3_shape);
+ migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
+
+ auto a = mm->add_parameter("1", m1_shape);
+ auto b = mm->add_parameter("2", m2_shape);
+ auto b1 = mm->add_parameter("3", m3_shape);
b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
- auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
- mm->add_instruction(migraphx::make_op("dot"), gemm1, b1);
+ auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+ mm->add_instruction(migraphx::make_op("dot"), gemm0, b1);
return p;
}
diff --git a/test/verify/test_ck_gemm_gemm_pointwise.cpp b/test/verify/test_ck_gemm_gemm_pointwise.cpp
new file mode 100644
index 00000000000..1dac3c5c101
--- /dev/null
+++ b/test/verify/test_ck_gemm_gemm_pointwise.cpp
@@ -0,0 +1,58 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+struct test_ck_gemm_gemm_pointwise : verify_program
+{
+ migraphx::program create_program() const
+ {
+ migraphx::program p;
+ auto* mm = p.get_main_module();
+
+ migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
+ migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
+ migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
+ migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 64}};
+
+ auto a = mm->add_parameter("1", m1_shape);
+ auto b = mm->add_parameter("2", m2_shape);
+ auto b1 = mm->add_parameter("3", m3_shape);
+ auto x = mm->add_parameter("x", x_shape);
+
+ b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+ auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+ auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), gemm0, b1);
+ auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm1);
+ auto cos = mm->add_instruction(migraphx::make_op("cos"), sin);
+ auto add = mm->add_instruction(migraphx::make_op("add"), cos, x);
+ mm->add_instruction(migraphx::make_op("add"), add, sin);
+
+ return p;
+ }
+ std::string section() const { return "gemm"; }
+};
diff --git a/test/verify/test_ck_gemm_pointwise_gemm.cpp b/test/verify/test_ck_gemm_pointwise_gemm.cpp
new file mode 100644
index 00000000000..655d7bd0d5a
--- /dev/null
+++ b/test/verify/test_ck_gemm_pointwise_gemm.cpp
@@ -0,0 +1,58 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+struct test_ck_gemm_pointwise_gemm : verify_program
+{
+ migraphx::program create_program() const
+ {
+ migraphx::program p;
+ auto* mm = p.get_main_module();
+
+ migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
+ migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
+ migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
+ migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
+
+ auto a = mm->add_parameter("1", m1_shape);
+ auto b = mm->add_parameter("2", m2_shape);
+ auto b1 = mm->add_parameter("3", m3_shape);
+ auto x = mm->add_parameter("x", x_shape);
+
+ b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+ auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+ auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0);
+ auto cos = mm->add_instruction(migraphx::make_op("cos"), sin);
+ auto add = mm->add_instruction(migraphx::make_op("add"), cos, x);
+ auto add2 = mm->add_instruction(migraphx::make_op("add"), add, sin);
+ mm->add_instruction(migraphx::make_op("dot"), add2, b1);
+
+ return p;
+ }
+ std::string section() const { return "gemm"; }
+};
diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
new file mode 100644
index 00000000000..7dccf55be8c
--- /dev/null
+++ b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
@@ -0,0 +1,65 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+struct test_ck_gemm_pointwise_gemm_pointwise : verify_program
+{
+ migraphx::program create_program() const
+ {
+ migraphx::program p;
+ auto* mm = p.get_main_module();
+
+ migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
+ migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 128, 512}};
+ migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
+ migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 64, 512}};
+ migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}};
+
+ auto a = mm->add_parameter("1", m1_shape);
+ auto b = mm->add_parameter("2", m2_shape);
+ auto b1 = mm->add_parameter("3", m3_shape);
+ auto x = mm->add_parameter("x", x_shape);
+ auto y = mm->add_parameter("y", y_shape);
+ auto z = mm->add_parameter("z", y_shape);
+
+ b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+ auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+ auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0);
+ auto cos = mm->add_instruction(migraphx::make_op("cos"), sin);
+ auto add0 = mm->add_instruction(migraphx::make_op("add"), x, cos);
+ auto add1 = mm->add_instruction(migraphx::make_op("add"), add0, sin);
+ b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
+ b1);
+ auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add1, b1);
+ auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, y);
+ mm->add_instruction(migraphx::make_op("mul"), sub, z);
+
+ return p;
+ }
+ std::string section() const { return "gemm"; }
+};
diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp
new file mode 100644
index 00000000000..86d7d740337
--- /dev/null
+++ b/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp
@@ -0,0 +1,62 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+
+#include "verify_program.hpp"
+#include
+#include
+#include
+
+struct test_ck_gemm_pointwise_gemm_pointwise_rotated
+ : verify_program
+{
+ migraphx::program create_program() const
+ {
+ migraphx::program p;
+ auto* mm = p.get_main_module();
+
+ migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
+ migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
+ migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
+ migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
+ migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}};
+
+ auto a = mm->add_parameter("1", m1_shape);
+ auto b = mm->add_parameter("2", m2_shape);
+ auto b1 = mm->add_parameter("3", m3_shape);
+ auto x = mm->add_parameter("x", x_shape);
+ auto y = mm->add_parameter("y", y_shape);
+ auto z = mm->add_parameter("z", y_shape);
+
+ b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+ auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+ auto add = mm->add_instruction(migraphx::make_op("sub"), x, gemm0);
+
+ auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add, b1);
+ auto sub = mm->add_instruction(migraphx::make_op("sub"), y, gemm1);
+ mm->add_instruction(migraphx::make_op("mul"), sub, z);
+
+ return p;
+ }
+ std::string section() const { return "gemm"; }
+};
From d3f1c294a9a4533b4d7a4af45eafca1746216ecd Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic
Date: Wed, 6 Nov 2024 12:32:49 +0000
Subject: [PATCH 04/14] Remove log lines.
---
CMakeLists.txt | 2 +-
src/targets/gpu/fuse_ck.cpp | 2 --
src/targets/gpu/jit/ck_gemm_gemm.cpp | 4 +---
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 3 ---
test/verify/run_verify.cpp | 6 ------
5 files changed, 2 insertions(+), 15 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index bc4af5ad06c..e4a884ffbd0 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -76,7 +76,7 @@ endif()
if(WIN32)
option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF)
else()
-option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF)
+option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" ON)
endif()
# By default build shared libraries
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index 515a44bfc55..ed92759a857 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -266,8 +266,6 @@ struct find_ck_gemm_pointwise_gemm
void apply(module_pass_manager& mpm, const match::matcher_result& r)
{
- std::cout << "matching ck_gemm_gemm" << std::endl;
-
auto ins = r.result;
auto gemm0_ins = r.instructions["gemm0"];
auto gemm1_ins = r.instructions["gemm1"];
diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp
index 005cbda11e8..a07ae01ad80 100644
--- a/src/targets/gpu/jit/ck_gemm_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_gemm.cpp
@@ -168,9 +168,7 @@ struct ck_gemm_gemm_compiler : compiler
}
operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
- {
- std::cout << "compiling ck_gemm_gemm: " << v.get("kernel", std::string{}) << std::endl;
-
+ {
const auto& e1_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 0);
auto batch_count = get_batch_count(e1_shape);
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index fdd68bd077b..c5da7c4c4b3 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -182,9 +182,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}});
- // std::cout << "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" << std::endl;
- // std::cout << src << std::endl;
-
return compile_hip_code_object(src, options);
}
diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp
index 56e244f9d3b..b40dd3a2cc3 100644
--- a/test/verify/run_verify.cpp
+++ b/test/verify/run_verify.cpp
@@ -133,9 +133,6 @@ run_verify::run_ref(migraphx::program p,
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
- std::cout << "REF PROGRAM" << std::endl;
- p.debug_print();
-
return std::make_pair(std::move(p), p.eval(std::move(inputs)));
}
@@ -163,9 +160,6 @@ run_verify::run_target(const migraphx::target& t,
validate(t, p, m);
p.eval(m);
- std::cout << "GPU PROGRAM" << std::endl;
- p.debug_print();
-
auto tres = p.eval(m);
std::vector res(tres.size());
std::transform(
From 3684677a3795207909e3414b25485c75eb3afe9f Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic
Date: Wed, 13 Nov 2024 16:11:41 +0000
Subject: [PATCH 05/14] Limit CK GEMMs to k <= 1024.
---
src/targets/gpu/fuse_ck.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index ed92759a857..2a768c200fe 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -163,7 +163,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
}
return true;
}
- return k <= 2048;
+ return k <= 1024;
}
struct find_ck_gemm_pointwise
From ab11f51cc6af8a1e63ac918606651b265441229b Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic
Date: Wed, 20 Nov 2024 00:33:51 +0000
Subject: [PATCH 06/14] Fix test.
---
test/verify/run_verify.cpp | 1 -
test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp | 2 +-
2 files changed, 1 insertion(+), 2 deletions(-)
diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp
index b40dd3a2cc3..e7fdd535fcc 100644
--- a/test/verify/run_verify.cpp
+++ b/test/verify/run_verify.cpp
@@ -132,7 +132,6 @@ run_verify::run_ref(migraphx::program p,
auto_print pp{p, t.name()};
auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{});
compile_check(p, t, c_opts, (trace_target == "ref"));
-
return std::make_pair(std::move(p), p.eval(std::move(inputs)));
}
diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
index 7dccf55be8c..06efcff220f 100644
--- a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
+++ b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
@@ -35,7 +35,7 @@ struct test_ck_gemm_pointwise_gemm_pointwise : verify_program
Date: Tue, 18 Feb 2025 22:04:21 +0000
Subject: [PATCH 07/14] Revert ck_gemm_gemm.
Signed-off-by: Mirza Halilcevic
---
requirements.txt | 2 +-
src/targets/gpu/CMakeLists.txt | 9 +-
src/targets/gpu/compile_hip_code_object.cpp | 2 +-
src/targets/gpu/fuse_ck.cpp | 179 -----------
src/targets/gpu/include/migraphx/gpu/ck.hpp | 1 -
src/targets/gpu/jit/ck_gemm_gemm.cpp | 292 ------------------
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 4 +-
.../include/migraphx/kernels/ck_gemm_gemm.hpp | 88 ------
test/verify/test_ck_gemm_gemm.cpp | 53 ----
test/verify/test_ck_gemm_gemm_pointwise.cpp | 58 ----
test/verify/test_ck_gemm_pointwise_gemm.cpp | 58 ----
.../test_ck_gemm_pointwise_gemm_pointwise.cpp | 65 ----
..._gemm_pointwise_gemm_pointwise_rotated.cpp | 62 ----
13 files changed, 9 insertions(+), 864 deletions(-)
delete mode 100644 src/targets/gpu/jit/ck_gemm_gemm.cpp
delete mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
delete mode 100644 test/verify/test_ck_gemm_gemm.cpp
delete mode 100644 test/verify/test_ck_gemm_gemm_pointwise.cpp
delete mode 100644 test/verify/test_ck_gemm_pointwise_gemm.cpp
delete mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
delete mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp
diff --git a/requirements.txt b/requirements.txt
index 3fd7523ccfc..cc94659fc97 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+/home/mhalilce/composable_kernel --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@da423a21b8aaf541af470999f34a1ff252bcdece -DBUILD_FAT_LIBROCKCOMPILER=On
diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt
index 65766f99637..b5a8e37b793 100644
--- a/src/targets/gpu/CMakeLists.txt
+++ b/src/targets/gpu/CMakeLists.txt
@@ -57,7 +57,11 @@ else()
endif()
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
- find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS ck_host)
+ find_package(composable_kernel_host 1.0.0 REQUIRED)
+ if(NOT TARGET composable_kernel::ck_host)
+ # Manually including targets
+ include(${composable_kernel_host_TARGET_FILE})
+ endif()
endif()
if(BUILD_DEV)
@@ -132,8 +136,7 @@ file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp)
if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
list(REMOVE_ITEM JIT_GPU_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_gemm.cpp)
+ ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp)
endif()
if(MIGRAPHX_USE_MIOPEN)
diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp
index aa294fc54c1..dfd18ad7d42 100644
--- a/src/targets/gpu/compile_hip_code_object.cpp
+++ b/src/targets/gpu/compile_hip_code_object.cpp
@@ -197,7 +197,7 @@ compile_hip_code_object(context& ctx, const std::string& content, hip_compile_op
const auto& warnings = compiler_warnings();
options.params.insert(options.params.end(), warnings.begin(), warnings.end());
options.emplace_param("-ftemplate-backtrace-limit=0");
- // options.emplace_param("-Werror");
+ options.emplace_param("-Werror");
auto cos = compile_hip_src(srcs, options.params, get_device_name());
if(cos.size() != 1)
MIGRAPHX_THROW("No code object");
diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp
index 2a768c200fe..c60ca163a45 100644
--- a/src/targets/gpu/fuse_ck.cpp
+++ b/src/targets/gpu/fuse_ck.cpp
@@ -76,51 +76,6 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP(ck_gemm);
-struct ck_gemm_gemm
-{
- operation op = make_op("dot");
- size_t d0s_count = 0u;
- size_t d1s_count = 0u;
-
- template
- static auto reflect(Self& self, F f)
- {
- return pack(
- f(self.op, "op"), f(self.d0s_count, "d0s_count"), f(self.d1s_count, "d1s_count"));
- }
-
- std::string name() const { return "gpu::ck_gemm_gemm"; }
-
- void check_gemm_shape(const shape& s) const
- {
- if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1) and
- not s.scalar())
- MIGRAPHX_THROW("Invalid shape for " + name());
- }
-
- shape compute_shape(std::vector inputs, const std::vector&) const
- {
- check_shapes{inputs, *this}.same_ndims();
- if(inputs.size() < 3)
- MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size()));
-
- auto a = inputs[0];
- auto b = inputs[1];
- auto b1 = inputs[2];
-
- for(const auto& input : inputs)
- {
- check_gemm_shape(input);
- }
-
- auto gemm0_shape = op.compute_shape({a, b});
- return op.compute_shape({gemm0_shape, b1});
- }
-
- static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); }
-};
-MIGRAPHX_REGISTER_OP(ck_gemm_gemm);
-
struct ck_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
@@ -248,145 +203,11 @@ struct find_ck_gemm_softmax_gemm
}
};
-struct find_ck_gemm_pointwise_gemm
-{
- auto matcher() const
- {
- // TODO don't mix dot and quant_dot
- // TODO match used_once?
- auto gemm0 = match::skip(match::name("contiguous"))(
- match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm0")));
- auto pw0 =
- match::name("pointwise")(match::any_of[match::inputs()](gemm0.bind("x0")).bind("pw0"));
- return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1"))(
- match::arg(0)(match::any_of(pw0, gemm0)));
- }
-
- bool transposed_matrix(const shape& s) { return s.strides().back() != 1; }
-
- void apply(module_pass_manager& mpm, const match::matcher_result& r)
- {
- auto ins = r.result;
- auto gemm0_ins = r.instructions["gemm0"];
- auto gemm1_ins = r.instructions["gemm1"];
-
- auto inputs = gemm0_ins->inputs(); // A, B
- inputs.push_back(gemm1_ins->inputs().back()); // B1
-
- if (!transposed_matrix(inputs[1]->get_shape()))
- return;
-
- size_t d0s_count = 0, d1s_count = 0;
- std::vector module_inputs;
- if(r.instructions.find("pw0") != r.instructions.end())
- {
- auto pw0 = r.instructions["pw0"];
- auto x0_ins = r.instructions["x0"];
-
- if(gemm0_ins->get_shape().type() != shape::int32_type and
- pw0->get_shape().type() != gemm0_ins->get_shape().type())
- return;
- if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) {
- return not ck_gemm::is_ck_supported_type(input->get_shape().type());
- }))
- return;
- if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) {
- return not input->inputs().empty() and
- input->inputs().front()->name() == "capture";
- }))
- return;
-
- auto pw0_inputs = pw0->inputs();
- auto* pw0m = pw0->module_inputs().front();
-
- auto gemm_it = std::find(pw0_inputs.begin(), pw0_inputs.end(), x0_ins);
- auto gemm_idx = gemm_it - pw0_inputs.begin();
- if(gemm_idx != 0)
- {
- rotate_gemm_input(pw0m, gemm_idx);
- }
-
- pw0_inputs.erase(gemm_it);
- inputs.insert(inputs.end(), pw0_inputs.begin(), pw0_inputs.end()); // D0s
-
- d0s_count = pw0_inputs.size();
- module_inputs.push_back(pw0m);
- }
- if(r.instructions.find("pw1") != r.instructions.end())
- {
- auto pw1 = r.instructions["pw1"];
-
- if(gemm1_ins->get_shape().type() != shape::int32_type and
- pw1->get_shape().type() != gemm1_ins->get_shape().type())
- return;
- if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) {
- return not ck_gemm::is_ck_supported_type(input->get_shape().type());
- }))
- return;
- if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) {
- return not input->inputs().empty() and
- input->inputs().front()->name() == "capture";
- }))
- return;
-
- auto pw1_inputs = pw1->inputs();
- auto* pw1m = pw1->module_inputs().front();
-
- auto gemm_it = std::find(pw1_inputs.begin(), pw1_inputs.end(), gemm1_ins);
- auto gemm_idx = gemm_it - pw1_inputs.begin();
- if(gemm_idx != 0)
- {
- rotate_gemm_input(pw1m, gemm_idx);
- }
-
- pw1_inputs.erase(gemm_it);
- inputs.insert(inputs.end(), pw1_inputs.begin(), pw1_inputs.end()); // D1s
-
- d1s_count = pw1_inputs.size();
- module_inputs.push_back(pw1m);
- }
-
- mpm.get_module().replace_instruction(
- ins,
- ck_gemm_gemm{gemm1_ins->get_operator(), d0s_count, d1s_count},
- inputs,
- module_inputs);
- }
-
- void rotate_gemm_input(module* pwm, size_t gemm_idx)
- {
- auto names = pwm->get_parameter_names();
-
- auto first_param = pwm->get_parameter(names[0]);
- auto gemm_param = pwm->get_parameter(names[gemm_idx]);
-
- auto new_gemm_param = pwm->add_parameter(names[0] + "_0", gemm_param->get_shape());
- auto new_first_param = pwm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape());
-
- pwm->replace_instruction(gemm_param, new_gemm_param);
- pwm->replace_instruction(first_param, new_first_param);
- pwm->remove_instruction(first_param);
- pwm->remove_instruction(gemm_param);
- }
-};
-
-struct find_ck_gemm_pointwise_gemm_pointwise : find_ck_gemm_pointwise_gemm
-{
- auto matcher() const
- {
- // TODO match used_once?
- auto gemm1 = find_ck_gemm_pointwise_gemm::matcher();
- return match::name("pointwise")(match::any_of[match::inputs()](gemm1).bind("pw1"));
- }
-};
-
} // namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_ck_gemm_softmax_gemm{});
- match::find_matches(mpm, find_ck_gemm_pointwise_gemm_pointwise{});
- match::find_matches(mpm, find_ck_gemm_pointwise_gemm{});
match::find_matches(mpm, find_ck_gemm_pointwise{});
match::find_matches(mpm, find_ck_gemm{});
}
diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp
index d252db2c8b9..ea41b252547 100644
--- a/src/targets/gpu/include/migraphx/gpu/ck.hpp
+++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp
@@ -33,7 +33,6 @@
#include "ck/host/headers.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
-#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp"
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp
deleted file mode 100644
index a07ae01ad80..00000000000
--- a/src/targets/gpu/jit/ck_gemm_gemm.cpp
+++ /dev/null
@@ -1,292 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace migraphx {
-inline namespace MIGRAPHX_INLINE_NS {
-
-namespace gpu {
-
-using namespace migraphx::gpu::gen; // NOLINT
-
-// NOLINTNEXTLINE
-static const char* const ck_gemm_gemm_kernel = R"__migraphx__(
-#include
-#include
-#include
-#include
-#include <${include}>
-
-namespace migraphx {
-
-${preamble0}
-
-${preamble1}
-
-extern "C" {
-
-MIGRAPHX_GLOBAL void ${kernel}(${params})
-{
- transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
- ck_gemm_gemm<${solution}, ${blocks_per_batch}, ${d0s_count}>(xs...);
- });
-}
-
-}
-
-} // namespace migraphx
-
-)__migraphx__";
-
-struct ck_gemm_gemm_compiler : compiler
-{
- std::vector names() const { return {"ck_gemm_gemm", "gpu::ck_gemm_gemm"}; }
-
- ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem
- create_problem(const std::vector& inputs, const value& v) const
- {
- const auto& a0_shape = inputs[0];
- const auto& b0_shape = inputs[1];
- const auto& b1_shape = inputs[2];
- const auto& e1_shape = inputs.back();
-
- // cppcheck-suppress unreadVariable
- auto rank = a0_shape.ndim();
- auto batch_count = get_batch_count(e1_shape);
- auto m = e1_shape.lens()[rank - 2];
- m = can_fold_batch(inputs) ? m * batch_count : m;
- auto n = b0_shape.lens().back();
- auto k = a0_shape.lens().back();
- auto o = e1_shape.lens().back();
-
- const bool trans_a0 = transposed_matrix(a0_shape);
- const bool trans_b0 = transposed_matrix(b0_shape);
- const bool trans_b1 = transposed_matrix(b1_shape);
- const bool trans_e1 = transposed_matrix(e1_shape);
-
- auto d0s_count = v.get("d0s_count", 0);
- auto d1s_count = v.get("d1s_count", 0);
-
- std::vector trans_d0s;
- std::transform(inputs.begin() + 3,
- inputs.end() - 1 - d1s_count,
- std::back_inserter(trans_d0s),
- [](const auto& i) { return transposed_matrix(i); });
-
- std::vector trans_d1s;
- std::transform(inputs.begin() + 3 + d0s_count,
- inputs.end() - 1,
- std::back_inserter(trans_d1s),
- [](const auto& i) { return transposed_matrix(i); });
-
- const auto a0_type = get_type(a0_shape);
- const auto b0_type = get_type(b0_shape);
- const auto b1_type = get_type(b1_shape);
- const auto e1_type = get_type(e1_shape);
-
- std::vector d0s_type;
- std::transform(inputs.begin() + 3,
- inputs.end() - 1 - d1s_count,
- std::back_inserter(d0s_type),
- [](const auto& i) { return get_type(i); });
-
- std::vector d1s_type;
- std::transform(inputs.begin() + 3 + d0s_count,
- inputs.end() - 1,
- std::back_inserter(d1s_type),
- [](const auto& i) { return get_type(i); });
-
- std::string ck_passthrough = "ck_passthrough";
- std::string cde0_op = ck_passthrough;
- std::string cde1_op = ck_passthrough;
- assert(inputs.size() < 5 or v.contains("cde0_op") or v.containse("cde1_op"));
- if(v.contains("cde0_op"))
- {
- cde0_op = v.at("cde0_op").to();
- }
- if(v.contains("cde1_op"))
- {
- cde1_op = v.at("cde1_op").to();
- }
-
- return ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem{m,
- n,
- k,
- o,
- trans_a0,
- trans_b0,
- trans_d0s,
- trans_b1,
- trans_d1s,
- trans_e1,
- a0_type,
- b0_type,
- d0s_type,
- b1_type,
- d1s_type,
- e1_type,
- ck_passthrough,
- ck_passthrough,
- cde0_op,
- ck_passthrough,
- cde1_op};
- }
-
- operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
- {
- const auto& e1_shape = inputs.back();
- auto tuning_value = v.get("tuning_value", 0);
- auto batch_count = get_batch_count(e1_shape);
- auto problem = create_problem(inputs, v);
-
- const auto include_header = problem.GetIncludeHeader();
- const auto solutions =
- problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
- const auto& solution = solutions.at(tuning_value);
- const auto template_str = solution.ToTemplateString();
- const auto block_size = solution.GetTemplateParameter("BlockSize");
- const auto m_per_block = solution.GetTemplateParameter("Gemm0MPerBlock");
- const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock");
- const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
- ck::host::integer_divide_ceil(problem.O, n1_per_block);
-
- hip_compile_options options;
- options.additional_src_files = ck_headers();
-
- auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
- options.set_launch_params(v, grid_size * block_size, block_size);
- options.inputs = inputs;
- options.output = e1_shape;
- options.kernel_name = v.get("kernel", "ck_gemm_gemm_kernel");
- options.virtual_inputs = inputs;
- if(can_fold_batch(inputs))
- {
- auto vinputs = inputs;
- fold_batch_dims(vinputs[0]);
- remove_batch_dims(vinputs[1]);
- std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
- options.virtual_inputs = vinputs;
- }
-
- if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
- options.emplace_param("-DMIGRAPHX_CK_CHECK=1");
-
- auto src = interpolate_string(ck_gemm_gemm_kernel,
- {{"solution", template_str},
- {"include", include_header},
- {"params", enum_params(inputs.size(), "void * private_p")},
- {"args", enum_params(inputs.size(), "private_p")},
- {"blocks_per_batch", to_string(blocks_per_batch)},
- {"d0s_count", to_string(v.get("d0s_count", 0))},
- {"preamble0", v.get("preamble0", std::string{})},
- {"preamble1", v.get("preamble1", std::string{})},
- {"kernel", options.kernel_name}});
-
- return compile_hip_code_object(src, options);
- }
-
- value create_settings(instruction_ref ins, const operation& op) const
- {
- auto v = op.to_value();
-
- auto d0s_count = v.get("d0s_count", 0);
- auto d1s_count = v.get("d1s_count", 0);
-
- std::string pw0_name, pw1_name, kernel_name = "ck_gemm_";
- if(d0s_count > 0)
- {
- auto* pw0m = ins->module_inputs().front();
- v["preamble0"] = generate_pointwise(*pw0m, "post_ck_gemm0_function") +
- "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);";
- v["cde0_op"] = "ck_function_adaptor";
- kernel_name += generate_name_from_ops(*pw0m) + "_";
- }
- kernel_name += "gemm_";
- if(d1s_count > 0)
- {
- auto* pw1m = ins->module_inputs().back();
- v["preamble1"] = generate_pointwise(*pw1m, "post_ck_gemm1_function") +
- "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm1, post_ck_gemm1_function);";
- v["cde1_op"] = "ck_function_adaptor";
- kernel_name += generate_name_from_ops(*pw1m) + "_";
- }
-
- v["kernel"] = to_c_id(kernel_name + "kernel");
- return v;
- }
-
- compiler_replace
- compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const
- {
- auto shapes = to_shapes(ins->inputs());
- auto v = create_settings(ins, op);
- if(not solution.is_null())
- v["tuning_value"] = solution;
- return {compile_op(ctx, shapes, v),
- [=](module& m, instruction_ref ins2, const operation& code_object) {
- if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
- {
- std::vector gemm_shapes{
- shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())};
- std::cout << "gpu::ck_gemm_gemm: " << to_json_string(to_value(gemm_shapes))
- << std::endl;
- }
- m.replace_instruction(ins2, code_object, ins2->inputs());
- }};
- }
-
- optional
- get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
- {
- if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
- return nullopt;
- tuning_config tc;
- auto shapes = to_shapes(ins->inputs());
- auto problem = create_problem(shapes, create_settings(ins, op));
- auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
- tc.solutions.resize(solutions.size());
- std::iota(tc.solutions.begin(), tc.solutions.end(), 0);
- std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()};
- tc.problem = to_value(gemm_shapes);
- return tc;
- }
-};
-
-} // namespace gpu
-} // namespace MIGRAPHX_INLINE_NS
-} // namespace migraphx
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index 839e888383e..ad40d84161d 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -130,7 +130,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler
operation compile_op(context& ctx, const std::vector& inputs, const value& v) const
{
-
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 5);
auto batch_count = get_batch_count(c_shape);
@@ -139,7 +138,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler
const auto include_header = problem.GetIncludeHeader();
const auto solutions =
problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", "");
- const auto& solution = solutions.at(tuning_value);
+ const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.ToTemplateString();
const auto block_size = solution.GetTemplateParameter("BlockSize");
const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock");
@@ -149,7 +148,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler
hip_compile_options options;
options.additional_src_files = ck_headers();
-
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs;
diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
deleted file mode 100644
index f0dbeefd333..00000000000
--- a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
-#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP
-
-#include "migraphx/kernels/type_traits.hpp"
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace migraphx {
-
-template
-__device__ void ck_gemm_gemm_matrix_impl(
- E1 e1, A0 a0, B0 b0, B1 b1, detail::seq, detail::seq, tuple ds)
-{
- constexpr auto desc = G::make_descriptor(
- to_ck_tensor(),
- to_ck_tensor>(),
- ck::make_tuple(
- to_ck_tensor(ds))>>()...),
- to_ck_tensor>(),
- ck::make_tuple(
- to_ck_tensor(ds))>>()...),
- to_ck_tensor());
-
- MIGRAPHX_STATIC_ASSERT_FOR(desc.IsValid())
- {
- G::Run(desc,
- to_ck_const_pointer(a0.data()),
- to_ck_const_pointer(b0.data()),
- ck::make_tuple(to_ck_const_pointer(tuple_detail::get_element(ds).data())...),
- to_ck_const_pointer(b1.data()),
- ck::make_tuple(to_ck_const_pointer(
- tuple_detail::get_element(ds).data())...),
- to_ck_pointer(e1.data()));
- }
-}
-
-template
-__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, Ds... ds)
-{
- auto all_ds = make_tuple(ds...);
- ck_gemm_gemm_matrix_impl(
- e1, a0, b0, b1, detail::gens{}, detail::gens{}, all_ds);
-}
-
-template
-__device__ void ck_gemm_gemm(Ts... xs)
-{
- gemm_batch_args(make_index(), _c, xs...)(
- [](auto... ys) { ck_gemm_gemm_matrix(ys...); });
-}
-
-} // namespace migraphx
-#endif
diff --git a/test/verify/test_ck_gemm_gemm.cpp b/test/verify/test_ck_gemm_gemm.cpp
deleted file mode 100644
index 9595f0cd581..00000000000
--- a/test/verify/test_ck_gemm_gemm.cpp
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-
-#include "verify_program.hpp"
-#include
-#include
-#include
-
-struct test_ck_gemm_gemm : verify_program
-{
- migraphx::program create_program() const
- {
- migraphx::program p;
- auto* mm = p.get_main_module();
-
- migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
- migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
- migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
- migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
-
- auto a = mm->add_parameter("1", m1_shape);
- auto b = mm->add_parameter("2", m2_shape);
- auto b1 = mm->add_parameter("3", m3_shape);
-
- b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
- auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
- mm->add_instruction(migraphx::make_op("dot"), gemm0, b1);
-
- return p;
- }
- std::string section() const { return "gemm"; }
-};
diff --git a/test/verify/test_ck_gemm_gemm_pointwise.cpp b/test/verify/test_ck_gemm_gemm_pointwise.cpp
deleted file mode 100644
index 1dac3c5c101..00000000000
--- a/test/verify/test_ck_gemm_gemm_pointwise.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-
-#include "verify_program.hpp"
-#include
-#include
-#include
-
-struct test_ck_gemm_gemm_pointwise : verify_program
-{
- migraphx::program create_program() const
- {
- migraphx::program p;
- auto* mm = p.get_main_module();
-
- migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
- migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
- migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
- migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 64}};
-
- auto a = mm->add_parameter("1", m1_shape);
- auto b = mm->add_parameter("2", m2_shape);
- auto b1 = mm->add_parameter("3", m3_shape);
- auto x = mm->add_parameter("x", x_shape);
-
- b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
- auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
- auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), gemm0, b1);
- auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm1);
- auto cos = mm->add_instruction(migraphx::make_op("cos"), sin);
- auto add = mm->add_instruction(migraphx::make_op("add"), cos, x);
- mm->add_instruction(migraphx::make_op("add"), add, sin);
-
- return p;
- }
- std::string section() const { return "gemm"; }
-};
diff --git a/test/verify/test_ck_gemm_pointwise_gemm.cpp b/test/verify/test_ck_gemm_pointwise_gemm.cpp
deleted file mode 100644
index 655d7bd0d5a..00000000000
--- a/test/verify/test_ck_gemm_pointwise_gemm.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-
-#include "verify_program.hpp"
-#include
-#include
-#include
-
-struct test_ck_gemm_pointwise_gemm : verify_program
-{
- migraphx::program create_program() const
- {
- migraphx::program p;
- auto* mm = p.get_main_module();
-
- migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
- migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
- migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
- migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
-
- auto a = mm->add_parameter("1", m1_shape);
- auto b = mm->add_parameter("2", m2_shape);
- auto b1 = mm->add_parameter("3", m3_shape);
- auto x = mm->add_parameter("x", x_shape);
-
- b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
- auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
- auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0);
- auto cos = mm->add_instruction(migraphx::make_op("cos"), sin);
- auto add = mm->add_instruction(migraphx::make_op("add"), cos, x);
- auto add2 = mm->add_instruction(migraphx::make_op("add"), add, sin);
- mm->add_instruction(migraphx::make_op("dot"), add2, b1);
-
- return p;
- }
- std::string section() const { return "gemm"; }
-};
diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
deleted file mode 100644
index 06efcff220f..00000000000
--- a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-
-#include "verify_program.hpp"
-#include
-#include
-#include
-
-struct test_ck_gemm_pointwise_gemm_pointwise : verify_program
-{
- migraphx::program create_program() const
- {
- migraphx::program p;
- auto* mm = p.get_main_module();
-
- migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
- migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
- migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
- migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 64, 512}};
- migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}};
-
- auto a = mm->add_parameter("1", m1_shape);
- auto b = mm->add_parameter("2", m2_shape);
- auto b1 = mm->add_parameter("3", m3_shape);
- auto x = mm->add_parameter("x", x_shape);
- auto y = mm->add_parameter("y", y_shape);
- auto z = mm->add_parameter("z", y_shape);
-
- b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
- auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
- auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0);
- auto cos = mm->add_instruction(migraphx::make_op("cos"), sin);
- auto add0 = mm->add_instruction(migraphx::make_op("add"), x, cos);
- auto add1 = mm->add_instruction(migraphx::make_op("add"), add0, sin);
- b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
- b1);
- auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add1, b1);
- auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, y);
- mm->add_instruction(migraphx::make_op("mul"), sub, z);
-
- return p;
- }
- std::string section() const { return "gemm"; }
-};
diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp
deleted file mode 100644
index 86d7d740337..00000000000
--- a/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * The MIT License (MIT)
- *
- * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in
- * all copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
- * THE SOFTWARE.
- */
-
-#include "verify_program.hpp"
-#include
-#include
-#include
-
-struct test_ck_gemm_pointwise_gemm_pointwise_rotated
- : verify_program
-{
- migraphx::program create_program() const
- {
- migraphx::program p;
- auto* mm = p.get_main_module();
-
- migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}};
- migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}};
- migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}};
- migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}};
- migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}};
-
- auto a = mm->add_parameter("1", m1_shape);
- auto b = mm->add_parameter("2", m2_shape);
- auto b1 = mm->add_parameter("3", m3_shape);
- auto x = mm->add_parameter("x", x_shape);
- auto y = mm->add_parameter("y", y_shape);
- auto z = mm->add_parameter("z", y_shape);
-
- b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
- auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b);
- auto add = mm->add_instruction(migraphx::make_op("sub"), x, gemm0);
-
- auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add, b1);
- auto sub = mm->add_instruction(migraphx::make_op("sub"), y, gemm1);
- mm->add_instruction(migraphx::make_op("mul"), sub, z);
-
- return p;
- }
- std::string section() const { return "gemm"; }
-};
From 301f48efafa30e9b80c8492b33a7370809e323c3 Mon Sep 17 00:00:00 2001
From: Mirza Halilcevic
Date: Wed, 5 Mar 2025 00:10:03 +0000
Subject: [PATCH 08/14] Update commit id for composable_kernel and add
CK_CODE_GEN_RTC flag to CK compilers.
---
requirements.txt | 2 +-
src/targets/gpu/jit/ck_gemm.cpp | 1 +
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 1 +
3 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 642fcb4a1c7..dd911134c82 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On
-ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
+ROCm/composable_kernel@fd06ed926c0d8b4a8f758cfb9aaa4d0418ca80b6 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@b34604275d8ebfc9a830be931f3909f9cdd00ac2 -DBUILD_FAT_LIBROCKCOMPILER=On
diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp
index 23376d6c666..0704a412bb9 100644
--- a/src/targets/gpu/jit/ck_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm.cpp
@@ -152,6 +152,7 @@ struct ck_gemm_compiler : compiler
ck::host::integer_divide_ceil(problem.N, n_per_block);
hip_compile_options options;
+ options.emplace_param("-DCK_CODE_GEN_RTC");
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
index ad40d84161d..32d84a911fa 100644
--- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
@@ -147,6 +147,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler