Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ee0cb4f
Integrate new CK API
Sep 2, 2024
2da1062
Merge remote-tracking branch 'upstream/develop' into ck_integration
Sep 2, 2024
944a8a5
Merge remote-tracking branch 'upstream/develop' into ck_integration
mirza-halilcevic Oct 15, 2024
4488269
Matcher for gemm_gemm.
mirza-halilcevic Oct 21, 2024
69305ae
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Oct 28, 2024
edaa048
Finalize gemm_gemm integration.
mirza-halilcevic Oct 31, 2024
5e1b0e0
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Nov 6, 2024
d3f1c29
Remove log lines.
mirza-halilcevic Nov 6, 2024
9a375b2
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Nov 9, 2024
3684677
Limit CK GEMMs to k <= 1024.
mirza-halilcevic Nov 13, 2024
ab11f51
Fix test.
mirza-halilcevic Nov 20, 2024
19ebed9
Merge remote-tracking branch 'origin/develop' into ck_integration_gem…
mirza-halilcevic Feb 17, 2025
31f68b5
Revert ck_gemm_gemm.
mirza-halilcevic Feb 18, 2025
1c1b778
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Feb 18, 2025
25eabd1
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Feb 21, 2025
2589f37
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Feb 28, 2025
170b227
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Mar 4, 2025
301f48e
Update commit id for composable_kernel and add CK_CODE_GEN_RTC flag to
mirza-halilcevic Mar 5, 2025
f5da942
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Mar 26, 2025
51fea53
Use default arguments for prologue and epilogue in CK compilers.
mirza-halilcevic Mar 26, 2025
e15464f
Update CK commit hash.
mirza-halilcevic Mar 26, 2025
294ddc8
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic Apr 16, 2025
ee9ae14
Update licenses.
mirza-halilcevic Apr 16, 2025
5a186e8
Merge remote-tracking branch 'upstream/develop' into ck_integration_g…
mirza-halilcevic May 30, 2025
8f8e77d
Remove manual targets include.
mirza-halilcevic May 30, 2025
202e84b
Update CK commit hash.
mirza-halilcevic May 30, 2025
eb84233
Update CK commit hash.
mirza-halilcevic May 30, 2025
92f55d8
Merge branch 'develop' into ck_integration_gemm_gemm
mirza-halilcevic Jun 4, 2025
7658dba
Merge branch 'develop' into ck_integration_gemm_gemm
mirza-halilcevic Jun 18, 2025
dc4da4d
Merge branch 'develop' into ck_integration_gemm_gemm
mirza-halilcevic Jul 11, 2025
cc546c1
Merge branch 'develop' into ck_integration_gemm_gemm
mirza-halilcevic Jul 16, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ ROCm/[email protected]
pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build
msgpack/[email protected] -DMSGPACK_BUILD_TESTS=Off
[email protected] -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@221ad7adb549ea2a39c8e15d07a7ada6e89cef88 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
ROCm/composable_kernel@fbce6c7bb6dad3750e33e999d438197cdc5c7fe8 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCm/rocMLIR@221ad7adb549ea2a39c8e15d07a7ada6e89cef88 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off
6 changes: 3 additions & 3 deletions src/targets/gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ else()
endif()

if(MIGRAPHX_USE_COMPOSABLEKERNEL)
find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library)
find_package(composable_kernel_host 1.0.0 REQUIRED)
endif()

if(BUILD_DEV)
Expand Down Expand Up @@ -121,7 +121,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_WAVEFRONTSIZE=64
target_include_directories(kernel_file_check PRIVATE $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/>)
target_link_libraries(kernel_file_check compile_for_gpu)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(kernel_file_check composable_kernel::jit_library)
target_link_libraries(kernel_file_check composable_kernel::ck_host)
endif()

rocm_clang_tidy_check(kernel_file_check)
Expand Down Expand Up @@ -389,7 +389,7 @@ else()
endif()
target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels)
if(MIGRAPHX_USE_COMPOSABLEKERNEL)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library)
target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host)
target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1)
endif()

Expand Down
17 changes: 9 additions & 8 deletions src/targets/gpu/fuse_ck.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -90,11 +90,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
if(not ck_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back();
auto k = a.lens().back();
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back();
auto k = a.lens().back();
auto batch_size = std::accumulate(
a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies<std::size_t>());
// Integer gemms must be divisible by 4 in ck
Expand All @@ -118,7 +118,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
}
return true;
}
return k <= 2048;
return k <= 1024;
}

struct find_ck_gemm_pointwise
Expand Down Expand Up @@ -207,7 +207,8 @@ struct find_ck_gemm_softmax_gemm

// Runs the CK fusion matchers over the module reachable through `mpm` by
// handing each matcher struct to match::find_matches.
// NOTE(review): this text appears to be a rendered diff hunk (-207,7 +207,8)
// with old and new lines shown together; the combined call below is presumably
// the removed line and the two single-matcher calls after it its replacement —
// confirm against the actual fuse_ck.cpp before relying on this ordering.
void fuse_ck::apply(module_pass_manager& mpm) const
{
// One find_matches call given both the softmax-gemm and the pointwise matcher.
match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{});
// Separate find_matches calls, one matcher each: softmax-gemm first, then pointwise.
match::find_matches(mpm, find_ck_gemm_softmax_gemm{});
match::find_matches(mpm, find_ck_gemm_pointwise{});
// The plain find_ck_gemm matcher is applied last.
match::find_matches(mpm, find_ck_gemm{});
}

Expand Down
7 changes: 4 additions & 3 deletions src/targets/gpu/include/migraphx/gpu/ck.hpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -30,8 +30,9 @@
#include <migraphx/stringutils.hpp>
#include <string_view>

#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down
17 changes: 11 additions & 6 deletions src/targets/gpu/jit/ck_gemm.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,6 +37,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <ck/host/utils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -136,17 +137,21 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 34);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);
auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v);

const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size;
const auto template_str = solution.ToTemplateString();
const auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
const auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
const auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
ck::host::integer_divide_ceil(problem.N, n_per_block);

hip_compile_options options;
options.emplace_param("-DCK_CODE_GEN_RTC");
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
Expand Down
13 changes: 9 additions & 4 deletions src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -37,6 +37,7 @@
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <ck/host/utils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -137,11 +138,15 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
const auto include_header = problem.GetIncludeHeader();
const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name());
const auto& solution = solutions.at(tuning_value);
const auto template_str = solution.template_str;
const auto blocks_per_batch = solution.grid_size;
const auto block_size = solution.block_size;
const auto template_str = solution.ToTemplateString();
const auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
const auto m_per_block = solution.GetTemplateParameter<std::size_t>("Gemm01MPerBlock");
const auto n1_per_block = solution.GetTemplateParameter<std::size_t>("Gemm1NPerBlock");
const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) *
ck::host::integer_divide_ceil(problem.O, n1_per_block);

hip_compile_options options;
options.emplace_param("-DCK_CODE_GEN_RTC");
options.additional_src_files = ck_headers();
auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch;
options.set_launch_params(v, grid_size * block_size, block_size);
Expand Down
Loading