From ee0cb4f366a785c33cfaa384fc6eb925696f951a Mon Sep 17 00:00:00 2001 From: HTEC <> Date: Mon, 2 Sep 2024 12:53:57 +0000 Subject: [PATCH 01/14] Integrate new CK API --- requirements.txt | 2 +- src/targets/gpu/CMakeLists.txt | 6 +++--- src/targets/gpu/compile_hip_code_object.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/ck.hpp | 5 +++-- src/targets/gpu/jit/ck_gemm.cpp | 21 ++++++++++++-------- src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 21 +++++++++++++------- 6 files changed, 35 insertions(+), 22 deletions(-) diff --git a/requirements.txt b/requirements.txt index 21cd16b50aa..9623eee8d99 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +mirza-halilcevic/composable_kernel@2cd9b2b1ec1ade098a0f43c44fa5a0989e0b0a9d -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@3f153a79f86ef4dba62d0afe0d0e95e29518dc1e -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 3e2201edcab..d099a366aec 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -47,7 +47,7 @@ else() endif() if(MIGRAPHX_USE_COMPOSABLEKERNEL) - find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library) + find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS ck_host) endif() if(BUILD_DEV) @@ -111,7 +111,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256) target_include_directories(kernel_file_check PRIVATE $) target_link_libraries(kernel_file_check compile_for_gpu) if(MIGRAPHX_USE_COMPOSABLEKERNEL) - target_link_libraries(kernel_file_check composable_kernel::jit_library) + target_link_libraries(kernel_file_check composable_kernel::ck_host) endif() rocm_clang_tidy_check(kernel_file_check) @@ -362,7 +362,7 @@ else() endif() target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels) if(MIGRAPHX_USE_COMPOSABLEKERNEL) - target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) + target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host) target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1) endif() diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp index 2e3142bfa1f..d7ae4c0f3b0 100644 --- a/src/targets/gpu/compile_hip_code_object.cpp +++ b/src/targets/gpu/compile_hip_code_object.cpp @@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option const auto& warnings = compiler_warnings(); options.params.insert(options.params.end(), warnings.begin(), warnings.end()); options.emplace_param("-ftemplate-backtrace-limit=0"); - options.emplace_param("-Werror"); + // options.emplace_param("-Werror"); auto cos = compile_hip_src(srcs, options.params, get_device_name()); if(cos.size() != 1) MIGRAPHX_THROW("No code object"); diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp index 18d4dce25a2..ea41b252547 100644 --- a/src/targets/gpu/include/migraphx/gpu/ck.hpp +++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp @@ -30,8 +30,9 @@ #include #include -#include "ck/host/device_gemm_multiple_d.hpp" -#include "ck/host/device_batched_gemm_softmax_gemm.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 7cd20b3931d..5d0b27953a6 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -136,15 +137,19 @@ struct ck_gemm_compiler : compiler { const auto& c_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 34); - auto batch_count = get_batch_count(c_shape); - auto problem = create_problem(inputs, v); + auto batch_count = get_batch_count(c_shape); + auto problem = create_problem(inputs, v); - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = + problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.template_str; - const auto blocks_per_batch = solution.grid_size; - const auto block_size = solution.block_size; + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + const auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.N, n_per_block); hip_compile_options options; options.additional_src_files = ck_headers(); @@ -221,7 +226,7 @@ struct ck_gemm_compiler : compiler tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); tc.solutions.resize(solutions.size()); std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 5fe60372b94..c5da7c4c4b3 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -129,20 +130,26 @@ struct ck_gemm_softmax_gemm_compiler : compiler operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { + const auto& c_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 5); auto batch_count = get_batch_count(c_shape); auto problem = create_problem(inputs, v); - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); - const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.template_str; - const auto blocks_per_batch = solution.grid_size; - const auto block_size = solution.block_size; + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = + problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + const auto& solution = solutions.at(tuning_value); + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); + const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.O, n1_per_block); hip_compile_options options; options.additional_src_files = ck_headers(); + auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; options.set_launch_params(v, grid_size * block_size, block_size); options.inputs = inputs; @@ -222,7 +229,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); tc.solutions.resize(solutions.size()); std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; From 448826917d79bf55d0492fde2d98235fd090438b Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Mon, 21 Oct 2024 19:47:00 +0000 Subject: [PATCH 02/14] Matcher for gemm_gemm. --- CMakeLists.txt | 2 +- requirements.txt | 2 +- src/pass_manager.cpp | 3 + src/targets/gpu/CMakeLists.txt | 3 +- src/targets/gpu/fuse_ck.cpp | 66 +++++ src/targets/gpu/include/migraphx/gpu/ck.hpp | 1 + src/targets/gpu/jit/ck_gemm_gemm.cpp | 259 ++++++++++++++++++ src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 3 + .../include/migraphx/kernels/ck_gemm_gemm.hpp | 66 +++++ test/verify/run_verify.cpp | 7 + test/verify/test_ck_gemm_gemm.cpp | 51 ++++ 11 files changed, 460 insertions(+), 3 deletions(-) create mode 100644 src/targets/gpu/jit/ck_gemm_gemm.cpp create mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp create mode 100644 test/verify/test_ck_gemm_gemm.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index e4a884ffbd0..bc4af5ad06c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ endif() if(WIN32) option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF) else() -option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" ON) +option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF) endif() # By default build shared libraries diff --git a/requirements.txt b/requirements.txt index c686d609449..1316cdf0450 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -mirza-halilcevic/composable_kernel@2cd9b2b1ec1ade098a0f43c44fa5a0989e0b0a9d -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +mirza-halilcevic/composable_kernel@f22232bf4812f02b5afd403c560ea37b5d7168a3 -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@82378ac44b8627ecbaf9078353fb53588090fe28 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index af8f0e4d6a7..8824426ff6c 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -178,6 +178,9 @@ void run_passes(program& prog, module_ref root_mod, const std::vector& pas mpm.run_pass(p); } run_pass(prog, p, trace); + + std::cout << "PASS " << p.name() << std::endl; + std::cout << prog << std::endl; } } diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 48d36dadeda..101f4c87269 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -131,7 +131,8 @@ file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL) list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_gemm.cpp) endif() if(MIGRAPHX_USE_MIOPEN) diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index bf9a269f3e1..3b13c8bfb7a 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -76,6 +76,48 @@ struct ck_gemm }; MIGRAPHX_REGISTER_OP(ck_gemm); +struct ck_gemm_gemm +{ + operation op = make_op("dot"); + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.op, "op")); + } + + std::string name() const { return "gpu::ck_gemm_gemm"; } + + void check_gemm_shape(const shape& s) const + { + if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1) and + not s.scalar()) + MIGRAPHX_THROW("Invalid shape for " + name()); + } + + shape compute_shape(std::vector inputs, const std::vector&) const + { + check_shapes{inputs, *this}.same_ndims(); + if(inputs.size() < 3) + MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size())); + + auto a = inputs[0]; + auto b = inputs[1]; + auto b1 = inputs.back(); + + for(const auto& input : inputs) + { + check_gemm_shape(input); + } + + auto gemm0_shape = op.compute_shape({a, b}); + return op.compute_shape({gemm0_shape, b1}); + } + + static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); } +}; +MIGRAPHX_REGISTER_OP(ck_gemm_gemm); + struct ck_gemm_softmax_gemm : gemm_softmax_gemm { std::string name() const { return "gpu::ck_gemm_softmax_gemm"; } @@ -203,11 +245,35 @@ struct find_ck_gemm_softmax_gemm } }; +struct find_ck_gemm_gemm +{ + auto matcher() const + { + // TODO don't mix dot and quant_dot + auto gemm1 = match::skip(match::name("contiguous"))( + match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1"))); + return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(gemm1)); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto gemm1_ins = r.instructions["gemm1"]; + auto gemm2_ins = r.instructions["gemm2"]; + + auto inputs = gemm1_ins->inputs(); // A, B + inputs.push_back(gemm2_ins->inputs().back()); // B1 + + mpm.get_module().replace_instruction(ins, ck_gemm_gemm{gemm2_ins->get_operator()}, inputs); + } +}; + } // namespace void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{}); + match::find_matches(mpm, find_ck_gemm_gemm{}); match::find_matches(mpm, find_ck_gemm{}); } diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp index ea41b252547..d252db2c8b9 100644 --- a/src/targets/gpu/include/migraphx/gpu/ck.hpp +++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp @@ -33,6 +33,7 @@ #include "ck/host/headers.hpp" #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" +#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp" namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp new file mode 100644 index 00000000000..dcce8611c19 --- /dev/null +++ b/src/targets/gpu/jit/ck_gemm_gemm.cpp @@ -0,0 +1,259 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +namespace gpu { + +using namespace migraphx::gpu::gen; // NOLINT + +// NOLINTNEXTLINE +static const char* const ck_gemm_gemm_kernel = R"__migraphx__( +#include +#include +#include +#include +#include <${include}> + +namespace migraphx { + +${preamble} + +extern "C" { + +MIGRAPHX_GLOBAL void ${kernel}(${params}) +{ + transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) { + ck_gemm_gemm<${solution}, ${blocks_per_batch}>(xs...); + }); +} + +} + +} // namespace migraphx + +)__migraphx__"; + +struct ck_gemm_gemm_compiler : compiler +{ + std::vector names() const { return {"ck_gemm_gemm", "gpu::ck_gemm_gemm"}; } + + ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem + create_problem(const std::vector& inputs, const value& v) const + { + const auto& a0_shape = inputs[0]; + const auto& b0_shape = inputs[1]; + const auto& b1_shape = inputs[2]; + const auto& e1_shape = inputs.back(); + + // cppcheck-suppress unreadVariable + auto rank = a0_shape.ndim(); + auto batch_count = get_batch_count(e1_shape); + auto m = e1_shape.lens()[rank - 2]; + m = can_fold_batch(inputs) ? m * batch_count : m; + auto n = b0_shape.lens().back(); + auto k = a0_shape.lens().back(); + auto o = e1_shape.lens().back(); + + const bool trans_a0 = transposed_matrix(a0_shape); + const bool trans_b0 = transposed_matrix(b0_shape); + const bool trans_b1 = transposed_matrix(b1_shape); + const bool trans_e1 = transposed_matrix(e1_shape); + + std::vector trans_d0s; + std::transform(inputs.begin() + 3, + inputs.end() - 1, + std::back_inserter(trans_d0s), + [](const auto& i) { return transposed_matrix(i); }); + // TODO trans_d1s + + const auto a0_type = get_type(a0_shape); + const auto b0_type = get_type(b0_shape); + const auto b1_type = get_type(b1_shape); + const auto e1_type = get_type(e1_shape); + + std::vector d0s_type; + std::transform(inputs.begin() + 3, + inputs.end() - 1, + std::back_inserter(d0s_type), + [](const auto& i) { return get_type(i); }); + // TODO d1s_type + + std::string ck_passthrough = "ck_passthrough"; + std::string cde0_op = ck_passthrough; + std::string cde1_op = ck_passthrough; + assert(inputs.size() < 5 or v.contains("post")); + if(v.contains("post")) + { + cde0_op = v.at("post").to(); + // TODO CDE1ElementOp + } + + return ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem{m, + n, + k, + o, + trans_a0, + trans_b0, + trans_d0s, + trans_b1, + {}, // trans_d1s + trans_e1, + a0_type, + b0_type, + d0s_type, + b1_type, + {}, // d1s_type + e1_type, + ck_passthrough, + ck_passthrough, + cde0_op, + ck_passthrough, + cde1_op}; + } + + operation compile_op(context& ctx, const std::vector& inputs, const value& v) const + { + const auto& e1_shape = inputs.back(); + auto tuning_value = v.get("tuning_value", 0); + auto batch_count = get_batch_count(e1_shape); + auto problem = create_problem(inputs, v); + + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = + problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + const auto& solution = solutions.at(tuning_value); + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("Gemm0MPerBlock"); + const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.O, n1_per_block); + + hip_compile_options options; + options.additional_src_files = ck_headers(); + + auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; + options.set_launch_params(v, grid_size * block_size, block_size); + options.inputs = inputs; + options.output = e1_shape; + options.kernel_name = v.get("kernel", "ck_gemm_gemm_kernel"); + options.virtual_inputs = inputs; + if(can_fold_batch(inputs)) + { + auto vinputs = inputs; + fold_batch_dims(vinputs[0]); + remove_batch_dims(vinputs[1]); + std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims); + options.virtual_inputs = vinputs; + } + + if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{})) + options.emplace_param("-DMIGRAPHX_CK_CHECK=1"); + + auto src = interpolate_string(ck_gemm_gemm_kernel, + {{"solution", template_str}, + {"include", include_header}, + {"params", enum_params(inputs.size(), "void * private_p")}, + {"args", enum_params(inputs.size(), "private_p")}, + {"blocks_per_batch", to_string(blocks_per_batch)}, + {"preamble", v.get("preamble", std::string{})}, + {"kernel", options.kernel_name}}); + + return compile_hip_code_object(src, options); + } + + value create_settings(instruction_ref ins, const operation& op) const + { + auto v = op.to_value(); + v["kernel"] = "ck_gemm_gemm_kernel"; + if(not ins->module_inputs().empty()) + { + auto* pm = ins->module_inputs().front(); + v["preamble"] = generate_pointwise(*pm, "post_ck_gemm0_function") + + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);"; + v["post"] = "ck_function_adaptor"; + v["kernel"] = to_c_id("ck_gemm_" + generate_name_from_ops(*pm) + "_gemm_kernel"); + } + return v; + } + + compiler_replace + compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const + { + std::cout << "USING GEMM GEMM" << std::endl; + + auto shapes = to_shapes(ins->inputs()); + auto v = create_settings(ins, op); + if(not solution.is_null()) + v["tuning_value"] = solution; + return {compile_op(ctx, shapes, v), + [=](module& m, instruction_ref ins2, const operation& code_object) { + if(enabled(MIGRAPHX_LOG_CK_GEMM{})) + { + std::vector gemm_shapes{ + shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())}; + std::cout << "gpu::ck_gemm_gemm: " << to_json_string(to_value(gemm_shapes)) + << std::endl; + } + m.replace_instruction(ins2, code_object, ins2->inputs()); + }}; + } + + optional + get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const + { + if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{})) + return nullopt; + tuning_config tc; + auto shapes = to_shapes(ins->inputs()); + auto problem = create_problem(shapes, create_settings(ins, op)); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + tc.solutions.resize(solutions.size()); + std::iota(tc.solutions.begin(), tc.solutions.end(), 0); + std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; + tc.problem = to_value(gemm_shapes); + return tc; + } +}; + +} // namespace gpu +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index c5da7c4c4b3..fdd68bd077b 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -182,6 +182,9 @@ struct ck_gemm_softmax_gemm_compiler : compiler {"preamble", v.get("preamble", std::string{})}, {"kernel", options.kernel_name}}); + // std::cout << "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" << std::endl; + // std::cout << src << std::endl; + return compile_hip_code_object(src, options); } diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp new file mode 100644 index 00000000000..2b9e6f55b31 --- /dev/null +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp @@ -0,0 +1,66 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP +#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP + +#include +#include +#include +#include +#include +#include + +namespace migraphx { + +template +__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, D0s... d0s) +{ + constexpr auto desc = G::make_descriptor(to_ck_tensor(), + to_ck_tensor>(), + ck::make_tuple(to_ck_tensor()...), + to_ck_tensor>(), + ck::make_tuple(), + to_ck_tensor()); + + MIGRAPHX_STATIC_ASSERT_FOR(desc.IsValid()) + { + G::Run(desc, + to_ck_const_pointer(a0.data()), + to_ck_const_pointer(b0.data()), + ck::make_tuple(to_ck_const_pointer(d0s.data())...), + to_ck_const_pointer(b1.data()), + ck::make_tuple(), + to_ck_pointer(e1.data())); + } +} + +template +__device__ void ck_gemm_gemm(Ts... xs) +{ + gemm_batch_args(make_index(), _c, xs...)( + [](auto... ys) { ck_gemm_gemm_matrix(ys...); }); +} + +} // namespace migraphx +#endif diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index e7fdd535fcc..56e244f9d3b 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -132,6 +132,10 @@ run_verify::run_ref(migraphx::program p, auto_print pp{p, t.name()}; auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); compile_check(p, t, c_opts, (trace_target == "ref")); + + std::cout << "REF PROGRAM" << std::endl; + p.debug_print(); + return std::make_pair(std::move(p), p.eval(std::move(inputs))); } @@ -159,6 +163,9 @@ run_verify::run_target(const migraphx::target& t, validate(t, p, m); p.eval(m); + std::cout << "GPU PROGRAM" << std::endl; + p.debug_print(); + auto tres = p.eval(m); std::vector res(tres.size()); std::transform( diff --git a/test/verify/test_ck_gemm_gemm.cpp b/test/verify/test_ck_gemm_gemm.cpp new file mode 100644 index 00000000000..a3ae26d243e --- /dev/null +++ b/test/verify/test_ck_gemm_gemm.cpp @@ -0,0 +1,51 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_ck_gemm_gemm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m2_shape); + auto b1 = mm->add_parameter("3", m3_shape); + + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); + mm->add_instruction(migraphx::make_op("dot"), gemm1, b1); + + return p; + } + std::string section() const { return "gemm"; } +}; From edaa04856a3af6e8f74bb96ec457fd5efae0c624 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Thu, 31 Oct 2024 08:57:26 +0000 Subject: [PATCH 03/14] Finalize gemm_gemm integration. --- requirements.txt | 9 +- src/pass_manager.cpp | 3 - src/targets/gpu/fuse_ck.cpp | 158 +++++++++++++++--- src/targets/gpu/jit/ck_gemm_gemm.cpp | 81 ++++++--- .../include/migraphx/kernels/ck_gemm_gemm.hpp | 46 +++-- test/verify/test_ck_gemm_gemm.cpp | 12 +- test/verify/test_ck_gemm_gemm_pointwise.cpp | 58 +++++++ test/verify/test_ck_gemm_pointwise_gemm.cpp | 58 +++++++ .../test_ck_gemm_pointwise_gemm_pointwise.cpp | 65 +++++++ ..._gemm_pointwise_gemm_pointwise_rotated.cpp | 62 +++++++ 10 files changed, 481 insertions(+), 71 deletions(-) create mode 100644 test/verify/test_ck_gemm_gemm_pointwise.cpp create mode 100644 test/verify/test_ck_gemm_pointwise_gemm.cpp create mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp create mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp diff --git a/requirements.txt b/requirements.txt index d6d437a9ff1..37ef26395b7 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,10 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -<<<<<<< HEAD -mirza-halilcevic/composable_kernel@f22232bf4812f02b5afd403c560ea37b5d7168a3 -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@82378ac44b8627ecbaf9078353fb53588090fe28 -DBUILD_FAT_LIBROCKCOMPILER=On -======= -ROCm/composable_kernel@57cdd70b7cb14e5e3b60cd9a5f96ba8dc343763e -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@e454b5d06fc2f099f7de3ee43450e7a6b1efe015 -DBUILD_FAT_LIBROCKCOMPILER=On ->>>>>>> upstream/develop +mirza-halilcevic/composable_kernel@089c978451b3a5c15e146c5372e8ef8c5d4bc7d9 -DCK_BUILD_HOST_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/rocMLIR@87a55290b7f89f9f8b97c611e0cf929399a9e0c5 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/pass_manager.cpp b/src/pass_manager.cpp index 8824426ff6c..af8f0e4d6a7 100644 --- a/src/pass_manager.cpp +++ b/src/pass_manager.cpp @@ -178,9 +178,6 @@ void run_passes(program& prog, module_ref root_mod, const std::vector& pas mpm.run_pass(p); } run_pass(prog, p, trace); - - std::cout << "PASS " << p.name() << std::endl; - std::cout << prog << std::endl; } } diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index 3b13c8bfb7a..515a44bfc55 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -78,12 +78,15 @@ MIGRAPHX_REGISTER_OP(ck_gemm); struct ck_gemm_gemm { - operation op = make_op("dot"); + operation op = make_op("dot"); + size_t d0s_count = 0u; + size_t d1s_count = 0u; template static auto reflect(Self& self, F f) { - return pack(f(self.op, "op")); + return pack( + f(self.op, "op"), f(self.d0s_count, "d0s_count"), f(self.d1s_count, "d1s_count")); } std::string name() const { return "gpu::ck_gemm_gemm"; } @@ -101,9 +104,9 @@ struct ck_gemm_gemm if(inputs.size() < 3) MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size())); - auto a = inputs[0]; - auto b = inputs[1]; - auto b1 = inputs.back(); + auto a = inputs[0]; + auto b = inputs[1]; + auto b1 = inputs[2]; for(const auto& input : inputs) { @@ -132,11 +135,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) return false; if(not ck_gemm::is_ck_supported_type(ins->get_shape().type())) return false; - auto a = ins->inputs().front()->get_shape(); - auto b = ins->inputs().back()->get_shape(); - auto m = a.lens()[a.lens().size() - 2]; - auto n = b.lens().back(); - auto k = a.lens().back(); + auto a = ins->inputs().front()->get_shape(); + auto b = ins->inputs().back()->get_shape(); + auto m = a.lens()[a.lens().size() - 2]; + auto n = b.lens().back(); + auto k = a.lens().back(); auto batch_size = std::accumulate( a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies()); // Integer gemms must be divisible by 4 in ck @@ -245,26 +248,137 @@ struct find_ck_gemm_softmax_gemm } }; -struct find_ck_gemm_gemm +struct find_ck_gemm_pointwise_gemm { auto matcher() const { // TODO don't mix dot and quant_dot - auto gemm1 = match::skip(match::name("contiguous"))( - match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1"))); - return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(gemm1)); + // TODO match used_once? + auto gemm0 = match::skip(match::name("contiguous"))( + match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm0"))); + auto pw0 = + match::name("pointwise")(match::any_of[match::inputs()](gemm0.bind("x0")).bind("pw0")); + return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1"))( + match::arg(0)(match::any_of(pw0, gemm0))); } - void apply(module_pass_manager& mpm, const match::matcher_result& r) const + bool transposed_matrix(const shape& s) { return s.strides().back() != 1; } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) { + std::cout << "matching ck_gemm_gemm" << std::endl; + auto ins = r.result; + auto gemm0_ins = r.instructions["gemm0"]; auto gemm1_ins = r.instructions["gemm1"]; - auto gemm2_ins = r.instructions["gemm2"]; - auto inputs = gemm1_ins->inputs(); // A, B - inputs.push_back(gemm2_ins->inputs().back()); // B1 + auto inputs = gemm0_ins->inputs(); // A, B + inputs.push_back(gemm1_ins->inputs().back()); // B1 + + if (!transposed_matrix(inputs[1]->get_shape())) + return; - mpm.get_module().replace_instruction(ins, ck_gemm_gemm{gemm2_ins->get_operator()}, inputs); + size_t d0s_count = 0, d1s_count = 0; + std::vector module_inputs; + if(r.instructions.find("pw0") != r.instructions.end()) + { + auto pw0 = r.instructions["pw0"]; + auto x0_ins = r.instructions["x0"]; + + if(gemm0_ins->get_shape().type() != shape::int32_type and + pw0->get_shape().type() != gemm0_ins->get_shape().type()) + return; + if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) { + return not ck_gemm::is_ck_supported_type(input->get_shape().type()); + })) + return; + if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) { + return not input->inputs().empty() and + input->inputs().front()->name() == "capture"; + })) + return; + + auto pw0_inputs = pw0->inputs(); + auto* pw0m = pw0->module_inputs().front(); + + auto gemm_it = std::find(pw0_inputs.begin(), pw0_inputs.end(), x0_ins); + auto gemm_idx = gemm_it - pw0_inputs.begin(); + if(gemm_idx != 0) + { + rotate_gemm_input(pw0m, gemm_idx); + } + + pw0_inputs.erase(gemm_it); + inputs.insert(inputs.end(), pw0_inputs.begin(), pw0_inputs.end()); // D0s + + d0s_count = pw0_inputs.size(); + module_inputs.push_back(pw0m); + } + if(r.instructions.find("pw1") != r.instructions.end()) + { + auto pw1 = r.instructions["pw1"]; + + if(gemm1_ins->get_shape().type() != shape::int32_type and + pw1->get_shape().type() != gemm1_ins->get_shape().type()) + return; + if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) { + return not ck_gemm::is_ck_supported_type(input->get_shape().type()); + })) + return; + if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) { + return not input->inputs().empty() and + input->inputs().front()->name() == "capture"; + })) + return; + + auto pw1_inputs = pw1->inputs(); + auto* pw1m = pw1->module_inputs().front(); + + auto gemm_it = std::find(pw1_inputs.begin(), pw1_inputs.end(), gemm1_ins); + auto gemm_idx = gemm_it - pw1_inputs.begin(); + if(gemm_idx != 0) + { + rotate_gemm_input(pw1m, gemm_idx); + } + + pw1_inputs.erase(gemm_it); + inputs.insert(inputs.end(), pw1_inputs.begin(), pw1_inputs.end()); // D1s + + d1s_count = pw1_inputs.size(); + module_inputs.push_back(pw1m); + } + + mpm.get_module().replace_instruction( + ins, + ck_gemm_gemm{gemm1_ins->get_operator(), d0s_count, d1s_count}, + inputs, + module_inputs); + } + + void rotate_gemm_input(module* pwm, size_t gemm_idx) + { + auto names = pwm->get_parameter_names(); + + auto first_param = pwm->get_parameter(names[0]); + auto gemm_param = pwm->get_parameter(names[gemm_idx]); + + auto new_gemm_param = pwm->add_parameter(names[0] + "_0", gemm_param->get_shape()); + auto new_first_param = pwm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape()); + + pwm->replace_instruction(gemm_param, new_gemm_param); + pwm->replace_instruction(first_param, new_first_param); + pwm->remove_instruction(first_param); + pwm->remove_instruction(gemm_param); + } +}; + +struct find_ck_gemm_pointwise_gemm_pointwise : find_ck_gemm_pointwise_gemm +{ + auto matcher() const + { + // TODO match used_once? + auto gemm1 = find_ck_gemm_pointwise_gemm::matcher(); + return match::name("pointwise")(match::any_of[match::inputs()](gemm1).bind("pw1")); } }; @@ -272,8 +386,10 @@ struct find_ck_gemm_gemm void fuse_ck::apply(module_pass_manager& mpm) const { - match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{}); - match::find_matches(mpm, find_ck_gemm_gemm{}); + match::find_matches(mpm, find_ck_gemm_softmax_gemm{}); + match::find_matches(mpm, find_ck_gemm_pointwise_gemm_pointwise{}); + match::find_matches(mpm, find_ck_gemm_pointwise_gemm{}); + match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm{}); } diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp index dcce8611c19..005cbda11e8 100644 --- a/src/targets/gpu/jit/ck_gemm_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_gemm.cpp @@ -54,14 +54,16 @@ static const char* const ck_gemm_gemm_kernel = R"__migraphx__( namespace migraphx { -${preamble} +${preamble0} + +${preamble1} extern "C" { MIGRAPHX_GLOBAL void ${kernel}(${params}) { transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) { - ck_gemm_gemm<${solution}, ${blocks_per_batch}>(xs...); + ck_gemm_gemm<${solution}, ${blocks_per_batch}, ${d0s_count}>(xs...); }); } @@ -97,12 +99,20 @@ struct ck_gemm_gemm_compiler : compiler const bool trans_b1 = transposed_matrix(b1_shape); const bool trans_e1 = transposed_matrix(e1_shape); + auto d0s_count = v.get("d0s_count", 0); + auto d1s_count = v.get("d1s_count", 0); + std::vector trans_d0s; std::transform(inputs.begin() + 3, - inputs.end() - 1, + inputs.end() - 1 - d1s_count, std::back_inserter(trans_d0s), [](const auto& i) { return transposed_matrix(i); }); - // TODO trans_d1s + + std::vector trans_d1s; + std::transform(inputs.begin() + 3 + d0s_count, + inputs.end() - 1, + std::back_inserter(trans_d1s), + [](const auto& i) { return transposed_matrix(i); }); const auto a0_type = get_type(a0_shape); const auto b0_type = get_type(b0_shape); @@ -111,19 +121,27 @@ struct ck_gemm_gemm_compiler : compiler std::vector d0s_type; std::transform(inputs.begin() + 3, - inputs.end() - 1, + inputs.end() - 1 - d1s_count, std::back_inserter(d0s_type), [](const auto& i) { return get_type(i); }); - // TODO d1s_type + + std::vector d1s_type; + std::transform(inputs.begin() + 3 + d0s_count, + inputs.end() - 1, + std::back_inserter(d1s_type), + [](const auto& i) { return get_type(i); }); std::string ck_passthrough = "ck_passthrough"; std::string cde0_op = ck_passthrough; std::string cde1_op = ck_passthrough; - assert(inputs.size() < 5 or v.contains("post")); - if(v.contains("post")) + assert(inputs.size() < 5 or v.contains("cde0_op") or v.containse("cde1_op")); + if(v.contains("cde0_op")) + { + cde0_op = v.at("cde0_op").to(); + } + if(v.contains("cde1_op")) { - cde0_op = v.at("post").to(); - // TODO CDE1ElementOp + cde1_op = v.at("cde1_op").to(); } return ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem{m, @@ -134,13 +152,13 @@ struct ck_gemm_gemm_compiler : compiler trans_b0, trans_d0s, trans_b1, - {}, // trans_d1s + trans_d1s, trans_e1, a0_type, b0_type, d0s_type, b1_type, - {}, // d1s_type + d1s_type, e1_type, ck_passthrough, ck_passthrough, @@ -151,6 +169,8 @@ struct ck_gemm_gemm_compiler : compiler operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { + std::cout << "compiling ck_gemm_gemm: " << v.get("kernel", std::string{}) << std::endl; + const auto& e1_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 0); auto batch_count = get_batch_count(e1_shape); @@ -194,7 +214,9 @@ struct ck_gemm_gemm_compiler : compiler {"params", enum_params(inputs.size(), "void * private_p")}, {"args", enum_params(inputs.size(), "private_p")}, {"blocks_per_batch", to_string(blocks_per_batch)}, - {"preamble", v.get("preamble", std::string{})}, + {"d0s_count", to_string(v.get("d0s_count", 0))}, + {"preamble0", v.get("preamble0", std::string{})}, + {"preamble1", v.get("preamble1", std::string{})}, {"kernel", options.kernel_name}}); return compile_hip_code_object(src, options); @@ -202,24 +224,37 @@ struct ck_gemm_gemm_compiler : compiler value create_settings(instruction_ref ins, const operation& op) const { - auto v = op.to_value(); - v["kernel"] = "ck_gemm_gemm_kernel"; - if(not ins->module_inputs().empty()) + auto v = op.to_value(); + + auto d0s_count = v.get("d0s_count", 0); + auto d1s_count = v.get("d1s_count", 0); + + std::string pw0_name, pw1_name, kernel_name = "ck_gemm_"; + if(d0s_count > 0) { - auto* pm = ins->module_inputs().front(); - v["preamble"] = generate_pointwise(*pm, "post_ck_gemm0_function") + - "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);"; - v["post"] = "ck_function_adaptor"; - v["kernel"] = to_c_id("ck_gemm_" + generate_name_from_ops(*pm) + "_gemm_kernel"); + auto* pw0m = ins->module_inputs().front(); + v["preamble0"] = generate_pointwise(*pw0m, "post_ck_gemm0_function") + + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);"; + v["cde0_op"] = "ck_function_adaptor"; + kernel_name += generate_name_from_ops(*pw0m) + "_"; } + kernel_name += "gemm_"; + if(d1s_count > 0) + { + auto* pw1m = ins->module_inputs().back(); + v["preamble1"] = generate_pointwise(*pw1m, "post_ck_gemm1_function") + + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm1, post_ck_gemm1_function);"; + v["cde1_op"] = "ck_function_adaptor"; + kernel_name += generate_name_from_ops(*pw1m) + "_"; + } + + v["kernel"] = to_c_id(kernel_name + "kernel"); return v; } compiler_replace compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const { - std::cout << "USING GEMM GEMM" << std::endl; - auto shapes = to_shapes(ins->inputs()); auto v = create_settings(ins, op); if(not solution.is_null()) diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp index 2b9e6f55b31..f0dbeefd333 100644 --- a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp +++ b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp @@ -24,6 +24,7 @@ #ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP #define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP +#include "migraphx/kernels/type_traits.hpp" #include #include #include @@ -33,33 +34,54 @@ namespace migraphx { -template -__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, D0s... d0s) +template +__device__ void ck_gemm_gemm_matrix_impl( + E1 e1, A0 a0, B0 b0, B1 b1, detail::seq, detail::seq, tuple ds) { - constexpr auto desc = G::make_descriptor(to_ck_tensor(), - to_ck_tensor>(), - ck::make_tuple(to_ck_tensor()...), - to_ck_tensor>(), - ck::make_tuple(), - to_ck_tensor()); + constexpr auto desc = G::make_descriptor( + to_ck_tensor(), + to_ck_tensor>(), + ck::make_tuple( + to_ck_tensor(ds))>>()...), + to_ck_tensor>(), + ck::make_tuple( + to_ck_tensor(ds))>>()...), + to_ck_tensor()); MIGRAPHX_STATIC_ASSERT_FOR(desc.IsValid()) { G::Run(desc, to_ck_const_pointer(a0.data()), to_ck_const_pointer(b0.data()), - ck::make_tuple(to_ck_const_pointer(d0s.data())...), + ck::make_tuple(to_ck_const_pointer(tuple_detail::get_element(ds).data())...), to_ck_const_pointer(b1.data()), - ck::make_tuple(), + ck::make_tuple(to_ck_const_pointer( + tuple_detail::get_element(ds).data())...), to_ck_pointer(e1.data())); } } -template +template +__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, Ds... ds) +{ + auto all_ds = make_tuple(ds...); + ck_gemm_gemm_matrix_impl( + e1, a0, b0, b1, detail::gens{}, detail::gens{}, all_ds); +} + +template __device__ void ck_gemm_gemm(Ts... xs) { gemm_batch_args(make_index(), _c, xs...)( - [](auto... ys) { ck_gemm_gemm_matrix(ys...); }); + [](auto... ys) { ck_gemm_gemm_matrix(ys...); }); } } // namespace migraphx diff --git a/test/verify/test_ck_gemm_gemm.cpp b/test/verify/test_ck_gemm_gemm.cpp index a3ae26d243e..9595f0cd581 100644 --- a/test/verify/test_ck_gemm_gemm.cpp +++ b/test/verify/test_ck_gemm_gemm.cpp @@ -37,13 +37,15 @@ struct test_ck_gemm_gemm : verify_program migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; - auto a = mm->add_parameter("1", m1_shape); - auto b = mm->add_parameter("2", m2_shape); - auto b1 = mm->add_parameter("3", m3_shape); + migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; + + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m2_shape); + auto b1 = mm->add_parameter("3", m3_shape); b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b); - mm->add_instruction(migraphx::make_op("dot"), gemm1, b1); + auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); + mm->add_instruction(migraphx::make_op("dot"), gemm0, b1); return p; } diff --git a/test/verify/test_ck_gemm_gemm_pointwise.cpp b/test/verify/test_ck_gemm_gemm_pointwise.cpp new file mode 100644 index 00000000000..1dac3c5c101 --- /dev/null +++ b/test/verify/test_ck_gemm_gemm_pointwise.cpp @@ -0,0 +1,58 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_ck_gemm_gemm_pointwise : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; + migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 64}}; + + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m2_shape); + auto b1 = mm->add_parameter("3", m3_shape); + auto x = mm->add_parameter("x", x_shape); + + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), gemm0, b1); + auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm1); + auto cos = mm->add_instruction(migraphx::make_op("cos"), sin); + auto add = mm->add_instruction(migraphx::make_op("add"), cos, x); + mm->add_instruction(migraphx::make_op("add"), add, sin); + + return p; + } + std::string section() const { return "gemm"; } +}; diff --git a/test/verify/test_ck_gemm_pointwise_gemm.cpp b/test/verify/test_ck_gemm_pointwise_gemm.cpp new file mode 100644 index 00000000000..655d7bd0d5a --- /dev/null +++ b/test/verify/test_ck_gemm_pointwise_gemm.cpp @@ -0,0 +1,58 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_ck_gemm_pointwise_gemm : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; + migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; + + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m2_shape); + auto b1 = mm->add_parameter("3", m3_shape); + auto x = mm->add_parameter("x", x_shape); + + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0); + auto cos = mm->add_instruction(migraphx::make_op("cos"), sin); + auto add = mm->add_instruction(migraphx::make_op("add"), cos, x); + auto add2 = mm->add_instruction(migraphx::make_op("add"), add, sin); + mm->add_instruction(migraphx::make_op("dot"), add2, b1); + + return p; + } + std::string section() const { return "gemm"; } +}; diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp new file mode 100644 index 00000000000..7dccf55be8c --- /dev/null +++ b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp @@ -0,0 +1,65 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_ck_gemm_pointwise_gemm_pointwise : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 128, 512}}; + migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 64, 512}}; + migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}}; + + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m2_shape); + auto b1 = mm->add_parameter("3", m3_shape); + auto x = mm->add_parameter("x", x_shape); + auto y = mm->add_parameter("y", y_shape); + auto z = mm->add_parameter("z", y_shape); + + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0); + auto cos = mm->add_instruction(migraphx::make_op("cos"), sin); + auto add0 = mm->add_instruction(migraphx::make_op("add"), x, cos); + auto add1 = mm->add_instruction(migraphx::make_op("add"), add0, sin); + b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), + b1); + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add1, b1); + auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, y); + mm->add_instruction(migraphx::make_op("mul"), sub, z); + + return p; + } + std::string section() const { return "gemm"; } +}; diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp new file mode 100644 index 00000000000..86d7d740337 --- /dev/null +++ b/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include "verify_program.hpp" +#include +#include +#include + +struct test_ck_gemm_pointwise_gemm_pointwise_rotated + : verify_program +{ + migraphx::program create_program() const + { + migraphx::program p; + auto* mm = p.get_main_module(); + + migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; + migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; + migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; + migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; + migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}}; + + auto a = mm->add_parameter("1", m1_shape); + auto b = mm->add_parameter("2", m2_shape); + auto b1 = mm->add_parameter("3", m3_shape); + auto x = mm->add_parameter("x", x_shape); + auto y = mm->add_parameter("y", y_shape); + auto z = mm->add_parameter("z", y_shape); + + b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); + auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); + auto add = mm->add_instruction(migraphx::make_op("sub"), x, gemm0); + + auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add, b1); + auto sub = mm->add_instruction(migraphx::make_op("sub"), y, gemm1); + mm->add_instruction(migraphx::make_op("mul"), sub, z); + + return p; + } + std::string section() const { return "gemm"; } +}; From d3f1c294a9a4533b4d7a4af45eafca1746216ecd Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 6 Nov 2024 12:32:49 +0000 Subject: [PATCH 04/14] Remove log lines. --- CMakeLists.txt | 2 +- src/targets/gpu/fuse_ck.cpp | 2 -- src/targets/gpu/jit/ck_gemm_gemm.cpp | 4 +--- src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 3 --- test/verify/run_verify.cpp | 6 ------ 5 files changed, 2 insertions(+), 15 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bc4af5ad06c..e4a884ffbd0 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -76,7 +76,7 @@ endif() if(WIN32) option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF) else() -option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" OFF) +option(MIGRAPHX_USE_HIPBLASLT "Enable MIGraphX to use hipBLASLt" ON) endif() # By default build shared libraries diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index 515a44bfc55..ed92759a857 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -266,8 +266,6 @@ struct find_ck_gemm_pointwise_gemm void apply(module_pass_manager& mpm, const match::matcher_result& r) { - std::cout << "matching ck_gemm_gemm" << std::endl; - auto ins = r.result; auto gemm0_ins = r.instructions["gemm0"]; auto gemm1_ins = r.instructions["gemm1"]; diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp index 005cbda11e8..a07ae01ad80 100644 --- a/src/targets/gpu/jit/ck_gemm_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_gemm.cpp @@ -168,9 +168,7 @@ struct ck_gemm_gemm_compiler : compiler } operation compile_op(context& ctx, const std::vector& inputs, const value& v) const - { - std::cout << "compiling ck_gemm_gemm: " << v.get("kernel", std::string{}) << std::endl; - + { const auto& e1_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 0); auto batch_count = get_batch_count(e1_shape); diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index fdd68bd077b..c5da7c4c4b3 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -182,9 +182,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler {"preamble", v.get("preamble", std::string{})}, {"kernel", options.kernel_name}}); - // std::cout << "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" << std::endl; - // std::cout << src << std::endl; - return compile_hip_code_object(src, options); } diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index 56e244f9d3b..b40dd3a2cc3 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -133,9 +133,6 @@ run_verify::run_ref(migraphx::program p, auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); compile_check(p, t, c_opts, (trace_target == "ref")); - std::cout << "REF PROGRAM" << std::endl; - p.debug_print(); - return std::make_pair(std::move(p), p.eval(std::move(inputs))); } @@ -163,9 +160,6 @@ run_verify::run_target(const migraphx::target& t, validate(t, p, m); p.eval(m); - std::cout << "GPU PROGRAM" << std::endl; - p.debug_print(); - auto tres = p.eval(m); std::vector res(tres.size()); std::transform( From 3684677a3795207909e3414b25485c75eb3afe9f Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 13 Nov 2024 16:11:41 +0000 Subject: [PATCH 05/14] Limit CK GEMMs to k <= 1024. --- src/targets/gpu/fuse_ck.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index ed92759a857..2a768c200fe 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -163,7 +163,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) } return true; } - return k <= 2048; + return k <= 1024; } struct find_ck_gemm_pointwise From ab11f51cc6af8a1e63ac918606651b265441229b Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 20 Nov 2024 00:33:51 +0000 Subject: [PATCH 06/14] Fix test. --- test/verify/run_verify.cpp | 1 - test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/test/verify/run_verify.cpp b/test/verify/run_verify.cpp index b40dd3a2cc3..e7fdd535fcc 100644 --- a/test/verify/run_verify.cpp +++ b/test/verify/run_verify.cpp @@ -132,7 +132,6 @@ run_verify::run_ref(migraphx::program p, auto_print pp{p, t.name()}; auto trace_target = migraphx::string_value_of(MIGRAPHX_TRACE_TEST_COMPILE{}); compile_check(p, t, c_opts, (trace_target == "ref")); - return std::make_pair(std::move(p), p.eval(std::move(inputs))); } diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp index 7dccf55be8c..06efcff220f 100644 --- a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp +++ b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp @@ -35,7 +35,7 @@ struct test_ck_gemm_pointwise_gemm_pointwise : verify_program Date: Tue, 18 Feb 2025 22:04:21 +0000 Subject: [PATCH 07/14] Revert ck_gemm_gemm. Signed-off-by: Mirza Halilcevic --- requirements.txt | 2 +- src/targets/gpu/CMakeLists.txt | 9 +- src/targets/gpu/compile_hip_code_object.cpp | 2 +- src/targets/gpu/fuse_ck.cpp | 179 ----------- src/targets/gpu/include/migraphx/gpu/ck.hpp | 1 - src/targets/gpu/jit/ck_gemm_gemm.cpp | 292 ------------------ src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 4 +- .../include/migraphx/kernels/ck_gemm_gemm.hpp | 88 ------ test/verify/test_ck_gemm_gemm.cpp | 53 ---- test/verify/test_ck_gemm_gemm_pointwise.cpp | 58 ---- test/verify/test_ck_gemm_pointwise_gemm.cpp | 58 ---- .../test_ck_gemm_pointwise_gemm_pointwise.cpp | 65 ---- ..._gemm_pointwise_gemm_pointwise_rotated.cpp | 62 ---- 13 files changed, 9 insertions(+), 864 deletions(-) delete mode 100644 src/targets/gpu/jit/ck_gemm_gemm.cpp delete mode 100644 src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp delete mode 100644 test/verify/test_ck_gemm_gemm.cpp delete mode 100644 test/verify/test_ck_gemm_gemm_pointwise.cpp delete mode 100644 test/verify/test_ck_gemm_pointwise_gemm.cpp delete mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp delete mode 100644 test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp diff --git a/requirements.txt b/requirements.txt index 3fd7523ccfc..cc94659fc97 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +/home/mhalilce/composable_kernel --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@da423a21b8aaf541af470999f34a1ff252bcdece -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 65766f99637..b5a8e37b793 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -57,7 +57,11 @@ else() endif() if(MIGRAPHX_USE_COMPOSABLEKERNEL) - find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS ck_host) + find_package(composable_kernel_host 1.0.0 REQUIRED) + if(NOT TARGET composable_kernel::ck_host) + # Manually including targets + include(${composable_kernel_host_TARGET_FILE}) + endif() endif() if(BUILD_DEV) @@ -132,8 +136,7 @@ file(GLOB JIT_GPU_SRCS CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/jit/*.cpp) if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL) list(REMOVE_ITEM JIT_GPU_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_gemm.cpp) + ${CMAKE_CURRENT_SOURCE_DIR}/jit/ck_gemm_softmax_gemm.cpp) endif() if(MIGRAPHX_USE_MIOPEN) diff --git a/src/targets/gpu/compile_hip_code_object.cpp b/src/targets/gpu/compile_hip_code_object.cpp index aa294fc54c1..dfd18ad7d42 100644 --- a/src/targets/gpu/compile_hip_code_object.cpp +++ b/src/targets/gpu/compile_hip_code_object.cpp @@ -197,7 +197,7 @@ compile_hip_code_object(context& ctx, const std::string& content, hip_compile_op const auto& warnings = compiler_warnings(); options.params.insert(options.params.end(), warnings.begin(), warnings.end()); options.emplace_param("-ftemplate-backtrace-limit=0"); - // options.emplace_param("-Werror"); + options.emplace_param("-Werror"); auto cos = compile_hip_src(srcs, options.params, get_device_name()); if(cos.size() != 1) MIGRAPHX_THROW("No code object"); diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index 2a768c200fe..c60ca163a45 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -76,51 +76,6 @@ struct ck_gemm }; MIGRAPHX_REGISTER_OP(ck_gemm); -struct ck_gemm_gemm -{ - operation op = make_op("dot"); - size_t d0s_count = 0u; - size_t d1s_count = 0u; - - template - static auto reflect(Self& self, F f) - { - return pack( - f(self.op, "op"), f(self.d0s_count, "d0s_count"), f(self.d1s_count, "d1s_count")); - } - - std::string name() const { return "gpu::ck_gemm_gemm"; } - - void check_gemm_shape(const shape& s) const - { - if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1) and - not s.scalar()) - MIGRAPHX_THROW("Invalid shape for " + name()); - } - - shape compute_shape(std::vector inputs, const std::vector&) const - { - check_shapes{inputs, *this}.same_ndims(); - if(inputs.size() < 3) - MIGRAPHX_THROW(name() + ": Expected 3 inputs but got " + to_string(inputs.size())); - - auto a = inputs[0]; - auto b = inputs[1]; - auto b1 = inputs[2]; - - for(const auto& input : inputs) - { - check_gemm_shape(input); - } - - auto gemm0_shape = op.compute_shape({a, b}); - return op.compute_shape({gemm0_shape, b1}); - } - - static bool is_ck_supported_type(shape::type_t t) { return contains({shape::half_type}, t); } -}; -MIGRAPHX_REGISTER_OP(ck_gemm_gemm); - struct ck_gemm_softmax_gemm : gemm_softmax_gemm { std::string name() const { return "gpu::ck_gemm_softmax_gemm"; } @@ -248,145 +203,11 @@ struct find_ck_gemm_softmax_gemm } }; -struct find_ck_gemm_pointwise_gemm -{ - auto matcher() const - { - // TODO don't mix dot and quant_dot - // TODO match used_once? - auto gemm0 = match::skip(match::name("contiguous"))( - match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm0"))); - auto pw0 = - match::name("pointwise")(match::any_of[match::inputs()](gemm0.bind("x0")).bind("pw0")); - return match::name("dot", "quant_dot")(is_ck_gemm().bind("gemm1"))( - match::arg(0)(match::any_of(pw0, gemm0))); - } - - bool transposed_matrix(const shape& s) { return s.strides().back() != 1; } - - void apply(module_pass_manager& mpm, const match::matcher_result& r) - { - auto ins = r.result; - auto gemm0_ins = r.instructions["gemm0"]; - auto gemm1_ins = r.instructions["gemm1"]; - - auto inputs = gemm0_ins->inputs(); // A, B - inputs.push_back(gemm1_ins->inputs().back()); // B1 - - if (!transposed_matrix(inputs[1]->get_shape())) - return; - - size_t d0s_count = 0, d1s_count = 0; - std::vector module_inputs; - if(r.instructions.find("pw0") != r.instructions.end()) - { - auto pw0 = r.instructions["pw0"]; - auto x0_ins = r.instructions["x0"]; - - if(gemm0_ins->get_shape().type() != shape::int32_type and - pw0->get_shape().type() != gemm0_ins->get_shape().type()) - return; - if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) { - return not ck_gemm::is_ck_supported_type(input->get_shape().type()); - })) - return; - if(std::any_of(pw0->inputs().begin(), pw0->inputs().end(), [](auto input) { - return not input->inputs().empty() and - input->inputs().front()->name() == "capture"; - })) - return; - - auto pw0_inputs = pw0->inputs(); - auto* pw0m = pw0->module_inputs().front(); - - auto gemm_it = std::find(pw0_inputs.begin(), pw0_inputs.end(), x0_ins); - auto gemm_idx = gemm_it - pw0_inputs.begin(); - if(gemm_idx != 0) - { - rotate_gemm_input(pw0m, gemm_idx); - } - - pw0_inputs.erase(gemm_it); - inputs.insert(inputs.end(), pw0_inputs.begin(), pw0_inputs.end()); // D0s - - d0s_count = pw0_inputs.size(); - module_inputs.push_back(pw0m); - } - if(r.instructions.find("pw1") != r.instructions.end()) - { - auto pw1 = r.instructions["pw1"]; - - if(gemm1_ins->get_shape().type() != shape::int32_type and - pw1->get_shape().type() != gemm1_ins->get_shape().type()) - return; - if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) { - return not ck_gemm::is_ck_supported_type(input->get_shape().type()); - })) - return; - if(std::any_of(pw1->inputs().begin(), pw1->inputs().end(), [](auto input) { - return not input->inputs().empty() and - input->inputs().front()->name() == "capture"; - })) - return; - - auto pw1_inputs = pw1->inputs(); - auto* pw1m = pw1->module_inputs().front(); - - auto gemm_it = std::find(pw1_inputs.begin(), pw1_inputs.end(), gemm1_ins); - auto gemm_idx = gemm_it - pw1_inputs.begin(); - if(gemm_idx != 0) - { - rotate_gemm_input(pw1m, gemm_idx); - } - - pw1_inputs.erase(gemm_it); - inputs.insert(inputs.end(), pw1_inputs.begin(), pw1_inputs.end()); // D1s - - d1s_count = pw1_inputs.size(); - module_inputs.push_back(pw1m); - } - - mpm.get_module().replace_instruction( - ins, - ck_gemm_gemm{gemm1_ins->get_operator(), d0s_count, d1s_count}, - inputs, - module_inputs); - } - - void rotate_gemm_input(module* pwm, size_t gemm_idx) - { - auto names = pwm->get_parameter_names(); - - auto first_param = pwm->get_parameter(names[0]); - auto gemm_param = pwm->get_parameter(names[gemm_idx]); - - auto new_gemm_param = pwm->add_parameter(names[0] + "_0", gemm_param->get_shape()); - auto new_first_param = pwm->add_parameter(names[gemm_idx] + "_0", first_param->get_shape()); - - pwm->replace_instruction(gemm_param, new_gemm_param); - pwm->replace_instruction(first_param, new_first_param); - pwm->remove_instruction(first_param); - pwm->remove_instruction(gemm_param); - } -}; - -struct find_ck_gemm_pointwise_gemm_pointwise : find_ck_gemm_pointwise_gemm -{ - auto matcher() const - { - // TODO match used_once? - auto gemm1 = find_ck_gemm_pointwise_gemm::matcher(); - return match::name("pointwise")(match::any_of[match::inputs()](gemm1).bind("pw1")); - } -}; - } // namespace void fuse_ck::apply(module_pass_manager& mpm) const { match::find_matches(mpm, find_ck_gemm_softmax_gemm{}); - match::find_matches(mpm, find_ck_gemm_pointwise_gemm_pointwise{}); - match::find_matches(mpm, find_ck_gemm_pointwise_gemm{}); match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm{}); } diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp index d252db2c8b9..ea41b252547 100644 --- a/src/targets/gpu/include/migraphx/gpu/ck.hpp +++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp @@ -33,7 +33,6 @@ #include "ck/host/headers.hpp" #include "ck/host/device_gemm_multiple_d/problem.hpp" #include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" -#include "ck/host/device_batched_gemm_multiple_d_gemm_multiple_d/problem.hpp" namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/jit/ck_gemm_gemm.cpp b/src/targets/gpu/jit/ck_gemm_gemm.cpp deleted file mode 100644 index a07ae01ad80..00000000000 --- a/src/targets/gpu/jit/ck_gemm_gemm.cpp +++ /dev/null @@ -1,292 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace migraphx { -inline namespace MIGRAPHX_INLINE_NS { - -namespace gpu { - -using namespace migraphx::gpu::gen; // NOLINT - -// NOLINTNEXTLINE -static const char* const ck_gemm_gemm_kernel = R"__migraphx__( -#include -#include -#include -#include -#include <${include}> - -namespace migraphx { - -${preamble0} - -${preamble1} - -extern "C" { - -MIGRAPHX_GLOBAL void ${kernel}(${params}) -{ - transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) { - ck_gemm_gemm<${solution}, ${blocks_per_batch}, ${d0s_count}>(xs...); - }); -} - -} - -} // namespace migraphx - -)__migraphx__"; - -struct ck_gemm_gemm_compiler : compiler -{ - std::vector names() const { return {"ck_gemm_gemm", "gpu::ck_gemm_gemm"}; } - - ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem - create_problem(const std::vector& inputs, const value& v) const - { - const auto& a0_shape = inputs[0]; - const auto& b0_shape = inputs[1]; - const auto& b1_shape = inputs[2]; - const auto& e1_shape = inputs.back(); - - // cppcheck-suppress unreadVariable - auto rank = a0_shape.ndim(); - auto batch_count = get_batch_count(e1_shape); - auto m = e1_shape.lens()[rank - 2]; - m = can_fold_batch(inputs) ? m * batch_count : m; - auto n = b0_shape.lens().back(); - auto k = a0_shape.lens().back(); - auto o = e1_shape.lens().back(); - - const bool trans_a0 = transposed_matrix(a0_shape); - const bool trans_b0 = transposed_matrix(b0_shape); - const bool trans_b1 = transposed_matrix(b1_shape); - const bool trans_e1 = transposed_matrix(e1_shape); - - auto d0s_count = v.get("d0s_count", 0); - auto d1s_count = v.get("d1s_count", 0); - - std::vector trans_d0s; - std::transform(inputs.begin() + 3, - inputs.end() - 1 - d1s_count, - std::back_inserter(trans_d0s), - [](const auto& i) { return transposed_matrix(i); }); - - std::vector trans_d1s; - std::transform(inputs.begin() + 3 + d0s_count, - inputs.end() - 1, - std::back_inserter(trans_d1s), - [](const auto& i) { return transposed_matrix(i); }); - - const auto a0_type = get_type(a0_shape); - const auto b0_type = get_type(b0_shape); - const auto b1_type = get_type(b1_shape); - const auto e1_type = get_type(e1_shape); - - std::vector d0s_type; - std::transform(inputs.begin() + 3, - inputs.end() - 1 - d1s_count, - std::back_inserter(d0s_type), - [](const auto& i) { return get_type(i); }); - - std::vector d1s_type; - std::transform(inputs.begin() + 3 + d0s_count, - inputs.end() - 1, - std::back_inserter(d1s_type), - [](const auto& i) { return get_type(i); }); - - std::string ck_passthrough = "ck_passthrough"; - std::string cde0_op = ck_passthrough; - std::string cde1_op = ck_passthrough; - assert(inputs.size() < 5 or v.contains("cde0_op") or v.containse("cde1_op")); - if(v.contains("cde0_op")) - { - cde0_op = v.at("cde0_op").to(); - } - if(v.contains("cde1_op")) - { - cde1_op = v.at("cde1_op").to(); - } - - return ck::host::device_batched_gemm_multiple_d_gemm_multiple_d::Problem{m, - n, - k, - o, - trans_a0, - trans_b0, - trans_d0s, - trans_b1, - trans_d1s, - trans_e1, - a0_type, - b0_type, - d0s_type, - b1_type, - d1s_type, - e1_type, - ck_passthrough, - ck_passthrough, - cde0_op, - ck_passthrough, - cde1_op}; - } - - operation compile_op(context& ctx, const std::vector& inputs, const value& v) const - { - const auto& e1_shape = inputs.back(); - auto tuning_value = v.get("tuning_value", 0); - auto batch_count = get_batch_count(e1_shape); - auto problem = create_problem(inputs, v); - - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = - problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); - const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.ToTemplateString(); - const auto block_size = solution.GetTemplateParameter("BlockSize"); - const auto m_per_block = solution.GetTemplateParameter("Gemm0MPerBlock"); - const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); - const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * - ck::host::integer_divide_ceil(problem.O, n1_per_block); - - hip_compile_options options; - options.additional_src_files = ck_headers(); - - auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; - options.set_launch_params(v, grid_size * block_size, block_size); - options.inputs = inputs; - options.output = e1_shape; - options.kernel_name = v.get("kernel", "ck_gemm_gemm_kernel"); - options.virtual_inputs = inputs; - if(can_fold_batch(inputs)) - { - auto vinputs = inputs; - fold_batch_dims(vinputs[0]); - remove_batch_dims(vinputs[1]); - std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims); - options.virtual_inputs = vinputs; - } - - if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{})) - options.emplace_param("-DMIGRAPHX_CK_CHECK=1"); - - auto src = interpolate_string(ck_gemm_gemm_kernel, - {{"solution", template_str}, - {"include", include_header}, - {"params", enum_params(inputs.size(), "void * private_p")}, - {"args", enum_params(inputs.size(), "private_p")}, - {"blocks_per_batch", to_string(blocks_per_batch)}, - {"d0s_count", to_string(v.get("d0s_count", 0))}, - {"preamble0", v.get("preamble0", std::string{})}, - {"preamble1", v.get("preamble1", std::string{})}, - {"kernel", options.kernel_name}}); - - return compile_hip_code_object(src, options); - } - - value create_settings(instruction_ref ins, const operation& op) const - { - auto v = op.to_value(); - - auto d0s_count = v.get("d0s_count", 0); - auto d1s_count = v.get("d1s_count", 0); - - std::string pw0_name, pw1_name, kernel_name = "ck_gemm_"; - if(d0s_count > 0) - { - auto* pw0m = ins->module_inputs().front(); - v["preamble0"] = generate_pointwise(*pw0m, "post_ck_gemm0_function") + - "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm0, post_ck_gemm0_function);"; - v["cde0_op"] = "ck_function_adaptor"; - kernel_name += generate_name_from_ops(*pw0m) + "_"; - } - kernel_name += "gemm_"; - if(d1s_count > 0) - { - auto* pw1m = ins->module_inputs().back(); - v["preamble1"] = generate_pointwise(*pw1m, "post_ck_gemm1_function") + - "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm1, post_ck_gemm1_function);"; - v["cde1_op"] = "ck_function_adaptor"; - kernel_name += generate_name_from_ops(*pw1m) + "_"; - } - - v["kernel"] = to_c_id(kernel_name + "kernel"); - return v; - } - - compiler_replace - compile(context& ctx, instruction_ref ins, const operation& op, const value& solution) const - { - auto shapes = to_shapes(ins->inputs()); - auto v = create_settings(ins, op); - if(not solution.is_null()) - v["tuning_value"] = solution; - return {compile_op(ctx, shapes, v), - [=](module& m, instruction_ref ins2, const operation& code_object) { - if(enabled(MIGRAPHX_LOG_CK_GEMM{})) - { - std::vector gemm_shapes{ - shapes[0], shapes[1], shapes.back().with_type(shapes[0].type())}; - std::cout << "gpu::ck_gemm_gemm: " << to_json_string(to_value(gemm_shapes)) - << std::endl; - } - m.replace_instruction(ins2, code_object, ins2->inputs()); - }}; - } - - optional - get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const - { - if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{})) - return nullopt; - tuning_config tc; - auto shapes = to_shapes(ins->inputs()); - auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); - tc.solutions.resize(solutions.size()); - std::iota(tc.solutions.begin(), tc.solutions.end(), 0); - std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; - tc.problem = to_value(gemm_shapes); - return tc; - } -}; - -} // namespace gpu -} // namespace MIGRAPHX_INLINE_NS -} // namespace migraphx diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 839e888383e..ad40d84161d 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -130,7 +130,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler operation compile_op(context& ctx, const std::vector& inputs, const value& v) const { - const auto& c_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 5); auto batch_count = get_batch_count(c_shape); @@ -139,7 +138,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler const auto include_header = problem.GetIncludeHeader(); const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); - const auto& solution = solutions.at(tuning_value); + const auto& solution = solutions.at(tuning_value); const auto template_str = solution.ToTemplateString(); const auto block_size = solution.GetTemplateParameter("BlockSize"); const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); @@ -149,7 +148,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler hip_compile_options options; options.additional_src_files = ck_headers(); - auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; options.set_launch_params(v, grid_size * block_size, block_size); options.inputs = inputs; diff --git a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp b/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp deleted file mode 100644 index f0dbeefd333..00000000000 --- a/src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_gemm.hpp +++ /dev/null @@ -1,88 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ -#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP -#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_HPP - -#include "migraphx/kernels/type_traits.hpp" -#include -#include -#include -#include -#include -#include - -namespace migraphx { - -template -__device__ void ck_gemm_gemm_matrix_impl( - E1 e1, A0 a0, B0 b0, B1 b1, detail::seq, detail::seq, tuple ds) -{ - constexpr auto desc = G::make_descriptor( - to_ck_tensor(), - to_ck_tensor>(), - ck::make_tuple( - to_ck_tensor(ds))>>()...), - to_ck_tensor>(), - ck::make_tuple( - to_ck_tensor(ds))>>()...), - to_ck_tensor()); - - MIGRAPHX_STATIC_ASSERT_FOR(desc.IsValid()) - { - G::Run(desc, - to_ck_const_pointer(a0.data()), - to_ck_const_pointer(b0.data()), - ck::make_tuple(to_ck_const_pointer(tuple_detail::get_element(ds).data())...), - to_ck_const_pointer(b1.data()), - ck::make_tuple(to_ck_const_pointer( - tuple_detail::get_element(ds).data())...), - to_ck_pointer(e1.data())); - } -} - -template -__device__ void ck_gemm_gemm_matrix(E1 e1, A0 a0, B0 b0, B1 b1, Ds... ds) -{ - auto all_ds = make_tuple(ds...); - ck_gemm_gemm_matrix_impl( - e1, a0, b0, b1, detail::gens{}, detail::gens{}, all_ds); -} - -template -__device__ void ck_gemm_gemm(Ts... xs) -{ - gemm_batch_args(make_index(), _c, xs...)( - [](auto... ys) { ck_gemm_gemm_matrix(ys...); }); -} - -} // namespace migraphx -#endif diff --git a/test/verify/test_ck_gemm_gemm.cpp b/test/verify/test_ck_gemm_gemm.cpp deleted file mode 100644 index 9595f0cd581..00000000000 --- a/test/verify/test_ck_gemm_gemm.cpp +++ /dev/null @@ -1,53 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_ck_gemm_gemm : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; - migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; - - auto a = mm->add_parameter("1", m1_shape); - auto b = mm->add_parameter("2", m2_shape); - auto b1 = mm->add_parameter("3", m3_shape); - - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); - mm->add_instruction(migraphx::make_op("dot"), gemm0, b1); - - return p; - } - std::string section() const { return "gemm"; } -}; diff --git a/test/verify/test_ck_gemm_gemm_pointwise.cpp b/test/verify/test_ck_gemm_gemm_pointwise.cpp deleted file mode 100644 index 1dac3c5c101..00000000000 --- a/test/verify/test_ck_gemm_gemm_pointwise.cpp +++ /dev/null @@ -1,58 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_ck_gemm_gemm_pointwise : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; - migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 64}}; - - auto a = mm->add_parameter("1", m1_shape); - auto b = mm->add_parameter("2", m2_shape); - auto b1 = mm->add_parameter("3", m3_shape); - auto x = mm->add_parameter("x", x_shape); - - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), gemm0, b1); - auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm1); - auto cos = mm->add_instruction(migraphx::make_op("cos"), sin); - auto add = mm->add_instruction(migraphx::make_op("add"), cos, x); - mm->add_instruction(migraphx::make_op("add"), add, sin); - - return p; - } - std::string section() const { return "gemm"; } -}; diff --git a/test/verify/test_ck_gemm_pointwise_gemm.cpp b/test/verify/test_ck_gemm_pointwise_gemm.cpp deleted file mode 100644 index 655d7bd0d5a..00000000000 --- a/test/verify/test_ck_gemm_pointwise_gemm.cpp +++ /dev/null @@ -1,58 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_ck_gemm_pointwise_gemm : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; - migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; - - auto a = mm->add_parameter("1", m1_shape); - auto b = mm->add_parameter("2", m2_shape); - auto b1 = mm->add_parameter("3", m3_shape); - auto x = mm->add_parameter("x", x_shape); - - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0); - auto cos = mm->add_instruction(migraphx::make_op("cos"), sin); - auto add = mm->add_instruction(migraphx::make_op("add"), cos, x); - auto add2 = mm->add_instruction(migraphx::make_op("add"), add, sin); - mm->add_instruction(migraphx::make_op("dot"), add2, b1); - - return p; - } - std::string section() const { return "gemm"; } -}; diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp deleted file mode 100644 index 06efcff220f..00000000000 --- a/test/verify/test_ck_gemm_pointwise_gemm_pointwise.cpp +++ /dev/null @@ -1,65 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_ck_gemm_pointwise_gemm_pointwise : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; - migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 64, 512}}; - migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}}; - - auto a = mm->add_parameter("1", m1_shape); - auto b = mm->add_parameter("2", m2_shape); - auto b1 = mm->add_parameter("3", m3_shape); - auto x = mm->add_parameter("x", x_shape); - auto y = mm->add_parameter("y", y_shape); - auto z = mm->add_parameter("z", y_shape); - - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto sin = mm->add_instruction(migraphx::make_op("sin"), gemm0); - auto cos = mm->add_instruction(migraphx::make_op("cos"), sin); - auto add0 = mm->add_instruction(migraphx::make_op("add"), x, cos); - auto add1 = mm->add_instruction(migraphx::make_op("add"), add0, sin); - b1 = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), - b1); - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add1, b1); - auto sub = mm->add_instruction(migraphx::make_op("sub"), gemm1, y); - mm->add_instruction(migraphx::make_op("mul"), sub, z); - - return p; - } - std::string section() const { return "gemm"; } -}; diff --git a/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp b/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp deleted file mode 100644 index 86d7d740337..00000000000 --- a/test/verify/test_ck_gemm_pointwise_gemm_pointwise_rotated.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/* - * The MIT License (MIT) - * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -#include "verify_program.hpp" -#include -#include -#include - -struct test_ck_gemm_pointwise_gemm_pointwise_rotated - : verify_program -{ - migraphx::program create_program() const - { - migraphx::program p; - auto* mm = p.get_main_module(); - - migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 128}}; - migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 512, 128}}; - migraphx::shape x_shape{migraphx::shape::half_type, {1, 12, 256, 512}}; - migraphx::shape m3_shape{migraphx::shape::half_type, {1, 12, 512, 64}}; - migraphx::shape y_shape{migraphx::shape::half_type, {1, 12, 256, 64}}; - - auto a = mm->add_parameter("1", m1_shape); - auto b = mm->add_parameter("2", m2_shape); - auto b1 = mm->add_parameter("3", m3_shape); - auto x = mm->add_parameter("x", x_shape); - auto y = mm->add_parameter("y", y_shape); - auto z = mm->add_parameter("z", y_shape); - - b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b); - auto gemm0 = mm->add_instruction(migraphx::make_op("dot"), a, b); - auto add = mm->add_instruction(migraphx::make_op("sub"), x, gemm0); - - auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), add, b1); - auto sub = mm->add_instruction(migraphx::make_op("sub"), y, gemm1); - mm->add_instruction(migraphx::make_op("mul"), sub, z); - - return p; - } - std::string section() const { return "gemm"; } -}; From 301f48efafa30e9b80c8492b33a7370809e323c3 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 5 Mar 2025 00:10:03 +0000 Subject: [PATCH 08/14] Update commit id for composable_kernel and add CK_CODE_GEN_RTC flag to CK compilers. --- requirements.txt | 2 +- src/targets/gpu/jit/ck_gemm.cpp | 1 + src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 642fcb4a1c7..dd911134c82 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/composable_kernel@fd06ed926c0d8b4a8f758cfb9aaa4d0418ca80b6 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@b34604275d8ebfc9a830be931f3909f9cdd00ac2 -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 23376d6c666..0704a412bb9 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -152,6 +152,7 @@ struct ck_gemm_compiler : compiler ck::host::integer_divide_ceil(problem.N, n_per_block); hip_compile_options options; + options.emplace_param("-DCK_CODE_GEN_RTC"); options.additional_src_files = ck_headers(); auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; options.set_launch_params(v, grid_size * block_size, block_size); diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index ad40d84161d..32d84a911fa 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -147,6 +147,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler ck::host::integer_divide_ceil(problem.O, n1_per_block); hip_compile_options options; + options.emplace_param("-DCK_CODE_GEN_RTC"); options.additional_src_files = ck_headers(); auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; options.set_launch_params(v, grid_size * block_size, block_size); From 51fea53e43232016e9c2c3ca9a425ea492deef61 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 26 Mar 2025 12:29:25 +0000 Subject: [PATCH 09/14] Use default arguments for prologue and epilogue in CK compilers. --- requirements.txt | 2 +- src/targets/gpu/jit/ck_gemm.cpp | 7 +++---- src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 7 +++---- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/requirements.txt b/requirements.txt index 485535b2dda..dc0452d2763 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@fd06ed926c0d8b4a8f758cfb9aaa4d0418ca80b6 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/composable_kernel@61e8dd730c48c2651a553903df82c58ff13ed87c --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@b6c726a324d2807494f770ed915af0e54cc6cb3f -DBUILD_FAT_LIBROCKCOMPILER=On diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 0704a412bb9..ca357f138f5 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -140,9 +140,8 @@ struct ck_gemm_compiler : compiler auto batch_count = get_batch_count(c_shape); auto problem = create_problem(inputs, v); - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = - problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); const auto& solution = solutions.at(tuning_value); const auto template_str = solution.ToTemplateString(); const auto block_size = solution.GetTemplateParameter("BlockSize"); @@ -227,7 +226,7 @@ struct ck_gemm_compiler : compiler tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); tc.solutions.resize(solutions.size()); std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 32d84a911fa..2ae4c0dd123 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -135,9 +135,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler auto batch_count = get_batch_count(c_shape); auto problem = create_problem(inputs, v); - const auto include_header = problem.GetIncludeHeader(); - const auto solutions = - problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + const auto include_header = problem.GetIncludeHeader(); + const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); const auto& solution = solutions.at(tuning_value); const auto template_str = solution.ToTemplateString(); const auto block_size = solution.GetTemplateParameter("BlockSize"); @@ -228,7 +227,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler tuning_config tc; auto shapes = to_shapes(ins->inputs()); auto problem = create_problem(shapes, create_settings(ins, op)); - auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name(), "", ""); + auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); tc.solutions.resize(solutions.size()); std::iota(tc.solutions.begin(), tc.solutions.end(), 0); std::vector gemm_shapes{shapes[0], shapes[1], shapes.back()}; From e15464f80de13bc72e42a12137c5be6cd410f1f7 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 26 Mar 2025 17:03:28 +0000 Subject: [PATCH 10/14] Update CK commit hash. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index dc0452d2763..b294f0fba4e 100755 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@61e8dd730c48c2651a553903df82c58ff13ed87c --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/composable_kernel@21e0ca197de46062ee72f4ed773696a4f266aa9f --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@b6c726a324d2807494f770ed915af0e54cc6cb3f -DBUILD_FAT_LIBROCKCOMPILER=On From ee9ae1475c5e5aee5aee73a9b6a97b80d8106261 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Wed, 16 Apr 2025 12:14:25 +0000 Subject: [PATCH 11/14] Update licenses. --- src/targets/gpu/CMakeLists.txt | 2 +- src/targets/gpu/fuse_ck.cpp | 2 +- src/targets/gpu/include/migraphx/gpu/ck.hpp | 2 +- src/targets/gpu/jit/ck_gemm.cpp | 2 +- src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index b5a8e37b793..d4d24e1f889 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -1,7 +1,7 @@ # #################################################################################### # The MIT License (MIT) # -# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index c60ca163a45..fa15250f2a6 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp index ea41b252547..55a08793417 100644 --- a/src/targets/gpu/include/migraphx/gpu/ck.hpp +++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index ca357f138f5..43b51e03ddd 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 2ae4c0dd123..97a8edc4bbb 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 8f8e77d6693cb1c508d48b6190d7c8fafe323263 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Fri, 30 May 2025 12:40:26 +0000 Subject: [PATCH 12/14] Remove manual targets include. --- src/targets/gpu/CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 38b25535e3a..83e8bff2318 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -58,10 +58,6 @@ endif() if(MIGRAPHX_USE_COMPOSABLEKERNEL) find_package(composable_kernel_host 1.0.0 REQUIRED) - if(NOT TARGET composable_kernel::ck_host) - # Manually including targets - include(${composable_kernel_host_TARGET_FILE}) - endif() endif() if(BUILD_DEV) From 202e84b0fdd51c491f701c6de8955a5e2bc6b253 Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Fri, 30 May 2025 13:18:50 +0000 Subject: [PATCH 13/14] Update CK commit hash. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d3456f79308..0d1e057bff7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@306f4c537e08e6ba5c16ee8a406ff7821db490cb --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/composable_kernel@0c91e209b8b4ad87c00b332131ef355b061bb7e4 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@dfb5b0881786ec78dff2a4294565bd8c4dc2d7b1 -DBUILD_FAT_LIBROCKCOMPILER=On From eb84233bd72b04356802b1851e1cbca0221aaefe Mon Sep 17 00:00:00 2001 From: Mirza Halilcevic Date: Fri, 30 May 2025 15:25:11 +0000 Subject: [PATCH 14/14] Update CK commit hash. --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 0d1e057bff7..94cf067d1f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@0c91e209b8b4ad87c00b332131ef355b061bb7e4 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/composable_kernel@fbce6c7bb6dad3750e33e999d438197cdc5c7fe8 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On ROCm/rocMLIR@dfb5b0881786ec78dff2a4294565bd8c4dc2d7b1 -DBUILD_FAT_LIBROCKCOMPILER=On