diff --git a/requirements.txt b/requirements.txt index 5d55195d544..961b19bdc71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,5 +27,5 @@ ROCm/half@rocm-5.6.0 pybind/pybind11@3e9dfa2866941655c56877882565e7577de6fc7b --build msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off sqlite3@3.43.2 -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/composable_kernel@b7775add2d28251674d81e220cd4a857b90b997a -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On -ROCm/rocMLIR@221ad7adb549ea2a39c8e15d07a7ada6e89cef88 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off +ROCm/composable_kernel@fbce6c7bb6dad3750e33e999d438197cdc5c7fe8 --cmake subdir -DCMAKE_DIR=codegen -DCMAKE_POSITION_INDEPENDENT_CODE=On +ROCm/rocMLIR@221ad7adb549ea2a39c8e15d07a7ada6e89cef88 -DBUILD_FAT_LIBROCKCOMPILER=On -DLLVM_INCLUDE_TESTS=Off \ No newline at end of file diff --git a/src/targets/gpu/CMakeLists.txt b/src/targets/gpu/CMakeLists.txt index 497ac403cec..83e8bff2318 100644 --- a/src/targets/gpu/CMakeLists.txt +++ b/src/targets/gpu/CMakeLists.txt @@ -57,7 +57,7 @@ else() endif() if(MIGRAPHX_USE_COMPOSABLEKERNEL) - find_package(composable_kernel 1.0.0 REQUIRED COMPONENTS jit_library) + find_package(composable_kernel_host 1.0.0 REQUIRED) endif() if(BUILD_DEV) @@ -121,7 +121,7 @@ target_compile_definitions(kernel_file_check PRIVATE -DMIGRAPHX_WAVEFRONTSIZE=64 target_include_directories(kernel_file_check PRIVATE $) target_link_libraries(kernel_file_check compile_for_gpu) if(MIGRAPHX_USE_COMPOSABLEKERNEL) - target_link_libraries(kernel_file_check composable_kernel::jit_library) + target_link_libraries(kernel_file_check composable_kernel::ck_host) endif() rocm_clang_tidy_check(kernel_file_check) @@ -389,7 +389,7 @@ else() endif() target_link_libraries(migraphx_gpu PRIVATE migraphx_kernels) if(MIGRAPHX_USE_COMPOSABLEKERNEL) - target_link_libraries(migraphx_gpu PRIVATE composable_kernel::jit_library) + target_link_libraries(migraphx_gpu PRIVATE composable_kernel::ck_host) target_compile_definitions(migraphx_gpu PRIVATE MIGRAPHX_USE_COMPOSABLEKERNEL=1) endif() diff --git a/src/targets/gpu/fuse_ck.cpp b/src/targets/gpu/fuse_ck.cpp index bf9a269f3e1..fa15250f2a6 100644 --- a/src/targets/gpu/fuse_ck.cpp +++ b/src/targets/gpu/fuse_ck.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -90,11 +90,11 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) return false; if(not ck_gemm::is_ck_supported_type(ins->get_shape().type())) return false; - auto a = ins->inputs().front()->get_shape(); - auto b = ins->inputs().back()->get_shape(); - auto m = a.lens()[a.lens().size() - 2]; - auto n = b.lens().back(); - auto k = a.lens().back(); + auto a = ins->inputs().front()->get_shape(); + auto b = ins->inputs().back()->get_shape(); + auto m = a.lens()[a.lens().size() - 2]; + auto n = b.lens().back(); + auto k = a.lens().back(); auto batch_size = std::accumulate( a.lens().rbegin() + 2, a.lens().rend(), std::size_t{1}, std::multiplies()); // Integer gemms must be divisible by 4 in ck @@ -118,7 +118,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) } return true; } - return k <= 2048; + return k <= 1024; } struct find_ck_gemm_pointwise @@ -207,7 +207,8 @@ struct find_ck_gemm_softmax_gemm void fuse_ck::apply(module_pass_manager& mpm) const { - match::find_matches(mpm, find_ck_gemm_softmax_gemm{}, find_ck_gemm_pointwise{}); + match::find_matches(mpm, find_ck_gemm_softmax_gemm{}); + match::find_matches(mpm, find_ck_gemm_pointwise{}); match::find_matches(mpm, find_ck_gemm{}); } diff --git a/src/targets/gpu/include/migraphx/gpu/ck.hpp b/src/targets/gpu/include/migraphx/gpu/ck.hpp index 18d4dce25a2..55a08793417 100644 --- a/src/targets/gpu/include/migraphx/gpu/ck.hpp +++ b/src/targets/gpu/include/migraphx/gpu/ck.hpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -30,8 +30,9 @@ #include #include -#include "ck/host/device_gemm_multiple_d.hpp" -#include "ck/host/device_batched_gemm_softmax_gemm.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp" namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { diff --git a/src/targets/gpu/jit/ck_gemm.cpp b/src/targets/gpu/jit/ck_gemm.cpp index 392eaa0c67b..43b51e03ddd 100644 --- a/src/targets/gpu/jit/ck_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -136,17 +137,21 @@ struct ck_gemm_compiler : compiler { const auto& c_shape = inputs.back(); auto tuning_value = v.get("tuning_value", 34); - auto batch_count = get_batch_count(c_shape); - auto problem = create_problem(inputs, v); + auto batch_count = get_batch_count(c_shape); + auto problem = create_problem(inputs, v); const auto include_header = problem.GetIncludeHeader(); const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.template_str; - const auto blocks_per_batch = solution.grid_size; - const auto block_size = solution.block_size; + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + const auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.N, n_per_block); hip_compile_options options; + options.emplace_param("-DCK_CODE_GEN_RTC"); options.additional_src_files = ck_headers(); auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; options.set_launch_params(v, grid_size * block_size, block_size); diff --git a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp index 693153d0982..97a8edc4bbb 100644 --- a/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp +++ b/src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp @@ -1,7 +1,7 @@ /* * The MIT License (MIT) * - * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -37,6 +37,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -137,11 +138,15 @@ struct ck_gemm_softmax_gemm_compiler : compiler const auto include_header = problem.GetIncludeHeader(); const auto solutions = problem.GetSolutions(ctx.get_current_device().get_gfx_name()); const auto& solution = solutions.at(tuning_value); - const auto template_str = solution.template_str; - const auto blocks_per_batch = solution.grid_size; - const auto block_size = solution.block_size; + const auto template_str = solution.ToTemplateString(); + const auto block_size = solution.GetTemplateParameter("BlockSize"); + const auto m_per_block = solution.GetTemplateParameter("Gemm01MPerBlock"); + const auto n1_per_block = solution.GetTemplateParameter("Gemm1NPerBlock"); + const auto blocks_per_batch = ck::host::integer_divide_ceil(problem.M, m_per_block) * + ck::host::integer_divide_ceil(problem.O, n1_per_block); hip_compile_options options; + options.emplace_param("-DCK_CODE_GEN_RTC"); options.additional_src_files = ck_headers(); auto grid_size = can_fold_batch(inputs) ? blocks_per_batch : batch_count * blocks_per_batch; options.set_launch_params(v, grid_size * block_size, block_size);