From 79c234668a8fa8b552e8c6f5017f9f30332e3c1a Mon Sep 17 00:00:00 2001 From: Dino Music Date: Wed, 24 Sep 2025 12:09:32 +0000 Subject: [PATCH 1/2] Implement reference op for rotary embedding --- .gitignore | 1 + src/CMakeLists.txt | 1 + src/include/migraphx/op/rotary_embedding.hpp | 220 ++++++++ src/include/migraphx/operators.hpp | 1 + src/op/builder/rotary_embedding.cpp | 67 +++ test/ref/rotary_embedding.cpp | 502 +++++++++++++++++++ 6 files changed, 792 insertions(+) create mode 100644 src/include/migraphx/op/rotary_embedding.hpp create mode 100644 src/op/builder/rotary_embedding.cpp create mode 100644 test/ref/rotary_embedding.cpp diff --git a/.gitignore b/.gitignore index 20dead71ed3..da4e7553619 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ docs/html .idea/ cmake-build*/ build*/ +!/src/op/builder/ # Recommended location to install rbuild dependencies from README.md depend*/ diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 75cac154e61..3959baf7c60 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -277,6 +277,7 @@ register_migraphx_ops( rnn_last_hs_output rnn_var_sl_last_output roialign + rotary_embedding rsqrt run_on_target scalar diff --git a/src/include/migraphx/op/rotary_embedding.hpp b/src/include/migraphx/op/rotary_embedding.hpp new file mode 100644 index 00000000000..c519311c45a --- /dev/null +++ b/src/include/migraphx/op/rotary_embedding.hpp @@ -0,0 +1,220 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#ifndef MIGRAPHX_GUARD_OPERATORS_GQA_ROTARY_EMBEDDING_HPP +#define MIGRAPHX_GUARD_OPERATORS_GQA_ROTARY_EMBEDDING_HPP + +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { + +struct rotary_embedding +{ + size_t num_heads = 1; + size_t kv_num_heads = 1; + bool interleaved = false; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.num_heads, "num_heads"), + f(self.kv_num_heads, "kv_num_heads"), + f(self.interleaved, "interleaved")); + } + + std::string name() const { return "rotary_embedding"; } + + shape compute_shape(std::vector inputs) const { return inputs.front(); } + + struct rotary_parameters + { + size_t batch_size = 0; + size_t sequence_length = 0; + size_t head_size = 0; + size_t num_heads = 0; + size_t rotary_embedding_dim = 0; + size_t max_sequence_length = 0; // Sequence length used by cos/sin cache + size_t head_stride = 0; + size_t seq_stride = 0; + size_t batch_stride = 0; + bool position_ids_use_batch = false; + }; + + template + void run_rotary_embedding(T input, + T cos_cache, + T sin_cache, + T output, + const size_t* pos_ids, + rotary_parameters params) const + { + const size_t half_rotary_emb_dim = params.rotary_embedding_dim / 2; + + const size_t loop_len = params.batch_size * params.sequence_length * params.num_heads; + par_for(loop_len, [&](auto idx) { + const size_t b = (idx / params.num_heads) / params.sequence_length; + const size_t s = (idx / params.num_heads) % params.sequence_length; + const size_t n = idx % params.num_heads; + const size_t block_offset = + b * params.batch_stride + s * params.seq_stride + n * params.head_stride; + auto input_data = input + block_offset; + auto output_data = output + block_offset; + + const size_t position_id = params.position_ids_use_batch + ? pos_ids[b * params.sequence_length + s] + : pos_ids[0] + s; + const size_t cache_offset = position_id * half_rotary_emb_dim; + auto cos_data = cos_cache + cache_offset; + auto sin_data = sin_cache + cache_offset; + + size_t cache_idx = 0; + float sign = 0.0; + size_t j = 0; + for(size_t i = 0; i < params.rotary_embedding_dim; i++) + { + if(interleaved) + { + cache_idx = (i / 2) % half_rotary_emb_dim; + sign = (i % 2 == 0) ? -1.0 : 1.0; + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } + else + { + cache_idx = i % half_rotary_emb_dim; + sign = (i < half_rotary_emb_dim) ? -1.0 : 1.0; + j = (i + half_rotary_emb_dim) % params.rotary_embedding_dim; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + + sign * input_data[j] * sin_data[cache_idx]; + } + std::copy(input_data + params.rotary_embedding_dim, + input_data + params.head_size, + output_data + params.rotary_embedding_dim); + }); + } + + template + void pack_v_into_rotary_qkv(rotary_parameters params, const T input, T output) const + { + const size_t loop_len = params.batch_size * params.sequence_length * kv_num_heads; + par_for(loop_len, [&](const auto idx) { + const size_t b = (idx / kv_num_heads) / params.sequence_length; + const size_t s = (idx / kv_num_heads) % params.sequence_length; + const size_t n = idx % kv_num_heads; + const size_t block_offset = + b * params.batch_stride + s * params.seq_stride + n * params.head_stride; + const T input_data = input + block_offset; + T output_data = output + block_offset; + for(size_t i = 0; i < params.head_size; i++) + { + output_data[i] = input_data[i]; + } + }); + } + + // Args: + // 0 - packed QKV (batch_size, num_heads + 2 * kv_num_heads, sequence_length, head_size) + // 1 - seqlens_k (batch_size) + // 2 - cos cache (max_rotary_sequence_length, head_size / 2) + // 3 - sin cache (max_rotary_sequence_length, head_size / 2) + argument compute(const shape& output_shape, std::vector args) const + { + rotary_parameters params; + + const auto& qkv_lens = args[0].get_shape().lens(); + params.batch_size = qkv_lens[0]; + params.sequence_length = qkv_lens[2]; + params.head_size = qkv_lens[3]; + const auto& cache_lens = args[2].get_shape().lens(); + params.max_sequence_length = cache_lens[0]; + params.rotary_embedding_dim = cache_lens[1] * 2; + params.seq_stride = params.head_size; + params.head_stride = params.sequence_length * params.seq_stride; + params.batch_stride = + (num_heads + 2 * kv_num_heads) * params.sequence_length * params.head_size; + params.position_ids_use_batch = params.sequence_length == 1; + + argument result{output_shape}; + + visit_all(result, args[0], args[2], args[3])( + [&](auto output, auto qkv, auto cos_cache, auto sin_cache) { + visit_all(args[1])([&](auto seqlens_k) { + std::vector pos_ids(params.position_ids_use_batch ? params.batch_size + : 1); + if(params.position_ids_use_batch) + { + std::transform(seqlens_k.begin(), + seqlens_k.end(), + pos_ids.begin(), + [](auto len) { return len - 1; }); + } + else + { + pos_ids[0] = 0; + } + + auto q_input = qkv.begin(); + auto k_input = q_input + num_heads * params.head_stride; + auto q_rotary = output.begin(); + auto k_rotary = q_rotary + num_heads * params.head_stride; + + params.num_heads = num_heads; + run_rotary_embedding(q_input, + cos_cache.begin(), + sin_cache.begin(), + q_rotary, + pos_ids.data(), + params); + + params.num_heads = kv_num_heads; + run_rotary_embedding(k_input, + cos_cache.begin(), + sin_cache.begin(), + k_rotary, + pos_ids.data(), + params); + + auto v_input = k_input + kv_num_heads * params.head_stride; + auto v_rotary = k_rotary + kv_num_heads * params.head_stride; + params.num_heads = num_heads; + + pack_v_into_rotary_qkv(params, v_input, v_rotary); + }); + }); + + return result; + } +}; + +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/operators.hpp b/src/include/migraphx/operators.hpp index fa617517791..b8114366f30 100644 --- a/src/include/migraphx/operators.hpp +++ b/src/include/migraphx/operators.hpp @@ -115,6 +115,7 @@ #include #include #include +#include #include #include #include diff --git a/src/op/builder/rotary_embedding.cpp b/src/op/builder/rotary_embedding.cpp new file mode 100644 index 00000000000..78f7e0a119a --- /dev/null +++ b/src/op/builder/rotary_embedding.cpp @@ -0,0 +1,67 @@ +/* The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace op { +namespace builder { + +struct rotary_embedding : op_builder +{ + size_t num_heads = 1; + size_t kv_num_heads = 1; + bool interleaved = false; + + template + static auto reflect(Self& self, F f) + { + return pack(f(self.num_heads, "num_heads"), + f(self.kv_num_heads, "kv_num_heads"), + f(self.interleaved, "interleaved")); + } + + // For now just a wrapper around the ref op + // The goal is to remove the ref op and implement rotary embedding via other existing operators + std::vector + insert(module& m, instruction_ref ins, const std::vector& args) const + { + return {m.insert_instruction(ins, + make_op("rotary_embedding", + {{"num_heads", num_heads}, + {"kv_num_heads", kv_num_heads}, + {"interleaved", interleaved}}), + args)}; + } +}; + +} // namespace builder +} // namespace op +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/test/ref/rotary_embedding.cpp b/test/ref/rotary_embedding.cpp new file mode 100644 index 00000000000..371484f4738 --- /dev/null +++ b/test/ref/rotary_embedding.cpp @@ -0,0 +1,502 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include +#include +#include + +#include + +TEST_CASE(rotary_embedding_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + const size_t batch_size = 1; + const size_t sequence_length = 8; + const size_t num_heads = 4; + const size_t kv_num_heads = 2; + const size_t head_size = 16; + const size_t max_cache_sequence_length = 8; + const size_t total_sequence_length = 8; + const size_t max_rotary_seq_length = max_cache_sequence_length; + const size_t rotary_dim = head_size / 2; + const bool interleaved = false; + + migraphx::shape qkv_shape{ + migraphx::shape::float_type, + {batch_size, num_heads + 2 * kv_num_heads, sequence_length, head_size}}; + migraphx::shape key_total_sequence_lens_shape(migraphx::shape::int32_type, {batch_size}); + migraphx::shape cos_cache_shape(migraphx::shape::float_type, + {max_rotary_seq_length, rotary_dim}); + migraphx::shape sin_cache_shape(migraphx::shape::float_type, + {max_rotary_seq_length, rotary_dim}); + + auto qkv = mm->add_parameter("qkv", qkv_shape); + auto ktsl = mm->add_parameter("ktsl", key_total_sequence_lens_shape); + auto cos_cache = mm->add_parameter("cos_cache", cos_cache_shape); + auto sin_cache = mm->add_parameter("sin_cache", sin_cache_shape); + + auto rotary = mm->add_instruction(migraphx::make_op("rotary_embedding", + {{"num_heads", num_heads}, + {"kv_num_heads", kv_num_heads}, + {"interleaved", interleaved}}), + qkv, + ktsl, + cos_cache, + sin_cache); + mm->add_return({rotary}); + + std::vector qkv_val{ + 0.41749f, -0.69577f, -1.70273f, -0.79187f, 0.07310f, -0.27880f, -0.75174f, -0.72621f, + -2.06164f, -1.44138f, -0.01891f, -0.05486f, -0.50047f, 0.35353f, -0.76615f, -1.74248f, + 0.59954f, -1.00662f, 1.02043f, -0.07006f, -1.96865f, -0.42948f, 1.38265f, 1.39979f, + 1.61774f, 0.76564f, -0.06511f, 0.51657f, -0.78820f, -0.53020f, 1.00847f, 0.75522f, + 1.10552f, 0.33774f, -1.25448f, -0.74513f, 0.32448f, 0.18892f, -0.80532f, -0.47895f, + -0.15562f, 0.23953f, 1.14514f, -0.72504f, 0.60569f, -0.01937f, 1.17494f, 0.44646f, + 0.17420f, -1.29149f, 0.09795f, 1.30044f, 0.65743f, -0.43525f, 1.21967f, -0.26364f, + -1.17287f, 1.08942f, -0.20404f, -0.67642f, -1.04814f, -0.30387f, -0.34929f, -0.67358f, + -0.45570f, 1.01702f, -0.20984f, -0.66244f, -1.03716f, 1.67139f, 0.22007f, 0.94488f, + -0.15520f, 0.00094f, 0.93557f, -1.05383f, 0.36045f, 0.05115f, -0.19947f, 1.57586f, + 0.00395f, 0.20170f, -1.59494f, -2.09666f, 1.28563f, 0.10925f, -1.95444f, -0.10990f, + -1.39617f, 0.96968f, 0.43793f, 1.49594f, 0.69040f, -0.45398f, -0.43307f, 0.13872f, + -1.23702f, 0.27749f, -1.54765f, -0.33909f, 0.95985f, -0.80969f, -0.06333f, -0.35975f, + -1.03807f, 0.95710f, -0.72967f, 0.73937f, -0.01144f, 0.52790f, 0.81002f, -0.27235f, + 1.68369f, -1.10408f, -1.12985f, 0.08155f, -1.80029f, -1.07238f, -0.50818f, 0.48642f, + 1.65785f, 0.34942f, -0.29293f, -0.02263f, -0.28906f, 0.21972f, -0.31756f, -0.61890f, + -0.97612f, 0.70052f, 1.17989f, 1.03988f, 0.28753f, 0.44653f, 0.58795f, -0.75481f, + -0.86908f, -0.86280f, 0.46151f, -0.40664f, 0.93974f, -0.59152f, 0.62067f, -1.16862f, + 0.19509f, -2.08621f, -1.25910f, -0.69394f, -0.46405f, 1.65313f, -0.62303f, 0.57686f, + 0.12457f, -0.50076f, 0.52905f, -0.21097f, 0.13185f, -0.06517f, -0.81022f, -0.15991f, + -0.85608f, -0.66973f, 0.82188f, 0.84750f, 1.11381f, -0.27893f, 0.96663f, 0.17090f, + -0.58623f, 0.08810f, -0.64246f, 1.13702f, 0.78309f, 0.92382f, -0.37715f, 0.37516f, + 0.56550f, -0.14689f, -0.78602f, 1.58871f, -1.83865f, -0.05113f, -0.00113f, -0.20537f, + -0.25702f, -0.48554f, -0.42965f, -0.80981f, -0.77238f, -0.23455f, -1.04290f, 1.03438f, + -2.45761f, 0.44027f, 1.38838f, -1.69399f, -0.26094f, 1.95926f, 0.15914f, 0.56232f, + 1.48083f, 0.08590f, -0.19955f, 0.75460f, -0.52121f, -1.55349f, -0.34823f, -0.26756f, + -1.55488f, 0.92425f, -0.71535f, 0.61587f, -0.45286f, -0.92762f, -0.14223f, -1.34620f, + 1.03421f, 0.22707f, -0.90337f, 0.42504f, 1.18391f, 0.13978f, -0.16003f, -0.27895f, + -0.62004f, -0.68513f, 0.15930f, 0.81287f, 0.70400f, 0.62487f, 0.92692f, -1.02183f, + 0.67300f, 0.54589f, 0.81747f, -1.05813f, -0.44010f, -3.08890f, 0.98540f, 0.35170f, + -0.17988f, -0.37412f, 0.25646f, -0.60839f, 0.13135f, -0.06886f, 0.55302f, -0.49637f, + -0.41079f, 0.71376f, -0.92411f, -0.04047f, -0.49496f, -0.16855f, -0.23395f, 1.22807f, + 0.98338f, -0.72929f, -1.31395f, 1.29489f, 0.32306f, -0.69334f, 0.08734f, -0.15889f, + 1.89157f, 0.64903f, 0.08721f, 1.88299f, 0.90821f, 0.57134f, 0.29974f, -0.02099f, + -1.47984f, -1.19396f, -0.86907f, -0.46785f, -1.51203f, -0.48414f, -0.28719f, -0.35233f, + -1.87752f, 1.27135f, 0.65182f, -0.97655f, 0.51972f, -0.56604f, -0.78326f, -0.22718f, + 0.03281f, 0.60940f, -0.27093f, 1.08492f, -2.64958f, 0.47716f, -1.31534f, -1.52301f, + 1.54800f, -0.89364f, -1.19203f, -1.02952f, -0.89654f, 0.19871f, 0.22578f, -0.24013f, + 0.74516f, -0.76311f, -0.70077f, -0.05940f, 2.51484f, 0.45830f, -0.58173f, 0.88659f, + -0.75417f, -0.97556f, 1.05278f, 0.12707f, 0.86963f, -0.41560f, -1.37358f, 0.73604f, + 0.34192f, 0.80609f, 1.53944f, 0.40806f, 0.14858f, 0.28071f, 0.10681f, -0.88577f, + 1.25632f, -0.91592f, -1.26772f, 1.14000f, 2.31817f, 1.38137f, -0.50686f, -0.02952f, + -0.39148f, 0.32838f, -0.36859f, -1.35995f, 1.32267f, -1.54589f, -1.57510f, -0.66400f, + -0.19672f, -1.07867f, 0.70195f, -2.03537f, -0.45462f, -0.64457f, 0.31647f, 1.20802f, + -0.42087f, 0.68382f, 2.00249f, 1.45079f, -0.95763f, -1.11372f, 0.87328f, -0.39358f, + 1.41147f, -1.37817f, -0.47008f, -0.61774f, -0.05532f, 0.95466f, -1.10584f, 0.17766f, + -0.73075f, -0.69812f, -0.68570f, -0.07292f, 0.63213f, 1.15353f, -0.40322f, -0.05030f, + -0.83515f, -0.06614f, 1.02871f, 0.31555f, 0.13493f, -1.07473f, -1.43802f, 0.74787f, + -0.17787f, 2.72647f, 1.12792f, -0.87049f, 2.23661f, -0.72028f, -1.94251f, -0.82372f, + -0.39990f, -0.83586f, -0.15177f, -1.16006f, -1.00450f, 0.46886f, 1.12177f, 2.36713f, + -0.85803f, 1.14115f, -0.54052f, -0.33579f, 0.80441f, -0.03780f, 0.10341f, 0.08724f, + 0.32319f, 0.20989f, -0.11073f, 0.25509f, -0.29188f, -0.68462f, 0.27836f, -0.16048f, + -0.58644f, -0.75119f, -0.71220f, 0.36131f, -2.72436f, -0.05491f, 0.15442f, -0.42460f, + -0.45834f, -0.73080f, 0.64778f, 0.76119f, -2.56361f, 0.04359f, 0.71025f, 2.14647f, + 1.51538f, 0.01322f, 2.17842f, -0.20094f, 2.44536f, 0.01128f, -0.94017f, 2.85214f, + 0.61998f, 1.25180f, -0.71116f, -0.37180f, -1.22676f, 0.45919f, 0.23963f, 0.86249f, + -0.58306f, 0.27836f, -0.20477f, -1.56387f, -0.45388f, 0.40267f, 1.77693f, 0.41270f, + -2.19566f, 0.46689f, -1.00889f, 1.18128f, -1.69895f, 0.94235f, -0.48070f, 0.88864f, + 1.25445f, -0.85912f, 0.63483f, 2.48495f, 1.16718f, 0.94282f, 1.01831f, 0.97040f, + -0.48674f, -0.98615f, -0.53975f, -0.33098f, 0.80657f, 1.15605f, 0.09327f, -1.08562f, + -1.64618f, -0.29966f, 0.21517f, -2.44469f, -1.25757f, 0.51321f, 0.61697f, 0.31420f, + -0.84412f, -0.40086f, -0.28057f, -0.24301f, -1.28083f, -0.43024f, -0.65606f, -0.14499f, + -0.70121f, 0.70340f, 1.35571f, -0.25670f, 0.44737f, -1.85772f, -1.01866f, 0.57725f, + -2.49752f, 0.42235f, -0.23843f, -1.39177f, -0.76879f, 0.41183f, 2.10108f, 0.28865f, + 0.06077f, -1.14342f, 0.34651f, 0.67033f, -0.08636f, -0.84150f, 0.55234f, -0.52409f, + -0.84508f, 1.20556f, -0.98627f, 0.02287f, 2.32880f, -1.82350f, -0.65655f, 1.44936f, + 0.10410f, 2.05022f, 1.08761f, 1.57381f, -0.01312f, -0.66371f, -1.02632f, 0.24346f, + -1.47868f, 1.05582f, 0.38794f, -0.88618f, -0.48330f, -1.93097f, 0.38556f, 0.45613f, + -1.53924f, -1.21630f, 0.97219f, 0.68307f, 0.36223f, -1.26150f, -0.27909f, -1.29269f, + 0.81657f, -0.05342f, -1.31958f, -0.60319f, -1.04079f, 0.39469f, -0.75666f, 1.50323f, + 0.25399f, -0.10152f, -0.76702f, 1.83365f, 0.57462f, -0.77335f, -1.35593f, -0.46823f, + -0.04678f, -0.66350f, -0.24358f, -1.57921f, -0.69181f, -0.36496f, -0.01369f, -0.70503f, + -1.71507f, 1.58063f, -0.58568f, -1.55290f, -0.58359f, 1.35833f, 0.54550f, -0.70589f, + -0.24310f, -1.23408f, 0.42297f, -0.51576f, 1.30628f, -0.81409f, -0.35845f, 0.89670f, + 1.01546f, 0.39702f, -1.26967f, -0.36245f, 1.29258f, -0.49741f, -0.25744f, 0.92564f, + 0.59873f, 0.85128f, -0.86577f, 1.25455f, 0.36780f, 0.70003f, 1.60151f, -0.67805f, + 0.25288f, -1.44379f, 0.59452f, -0.74957f, 0.92938f, 0.93066f, 1.91045f, 0.58974f, + -0.33494f, 0.52111f, 0.73124f, -0.27959f, 0.81104f, 0.25312f, -1.38699f, 0.20307f, + 0.76345f, 0.14203f, 0.90887f, -0.83386f, -0.21649f, -0.13751f, 2.50241f, 0.38624f, + -0.67405f, -2.28033f, 1.00245f, 2.74601f, 0.78501f, 2.08665f, 0.72543f, 0.57949f, + -2.17947f, -0.52484f, -0.23829f, -1.05610f, 0.67673f, -0.81046f, 1.36806f, 0.27357f, + -0.68868f, 0.17381f, -1.02706f, 1.17405f, 1.05738f, -0.74259f, 0.57364f, 0.67129f, + 1.16308f, -0.78270f, -2.66280f, 0.29430f, -1.55534f, 0.02705f, -0.49464f, 1.54097f, + 0.58723f, 0.08179f, 0.39286f, -0.74650f, -0.38050f, -0.98030f, -0.29283f, -0.65102f, + 0.32762f, 0.84588f, -0.42554f, -0.15737f, -0.85467f, 0.05529f, 0.69491f, -1.04209f, + -0.85984f, -1.32812f, -0.29532f, 0.66120f, 0.11061f, 0.84870f, 0.86388f, -1.53822f, + 0.44723f, 0.13878f, -0.17476f, -0.41955f, -1.01415f, -1.36983f, 0.41508f, -1.88003f, + -1.13967f, -1.00110f, -0.31072f, -0.78613f, 0.69637f, -0.52052f, 1.20858f, -0.24648f, + 1.31078f, -1.65722f, 0.52834f, 1.24568f, 0.21253f, -0.40617f, -0.10599f, -0.49272f, + -0.12628f, 0.37853f, -0.03980f, 0.58942f, -1.50853f, 0.32210f, 0.40937f, -0.49893f, + 0.37507f, 0.19290f, -1.85422f, -0.03451f, 1.00867f, 0.74706f, 1.14846f, 0.29571f, + 0.77227f, -0.32998f, 0.14887f, 0.55427f, -0.71247f, 1.00684f, -1.00142f, 0.19765f, + -1.54164f, 0.75906f, -3.29353f, 0.62726f, -0.96910f, 1.45244f, 0.37286f, 0.70223f, + 0.13735f, 1.50773f, 0.91042f, 0.19532f, -0.16148f, -1.98474f, -0.99344f, 0.45995f, + -2.43026f, -0.08176f, -1.10964f, 0.18464f, -1.20368f, -1.53090f, -0.47588f, 0.36687f, + -0.30770f, 2.05622f, 0.80959f, 0.24895f, -0.55335f, 2.20205f, -0.59102f, 0.61142f, + 1.54316f, -0.70976f, 0.40827f, 0.09174f, -0.81059f, -0.94260f, 0.83612f, 1.54785f, + 0.52762f, 0.29894f, 0.51916f, 0.18861f, -0.48849f, 0.37812f, -0.55108f, 1.16715f, + 0.18136f, -2.05249f, -0.52255f, -3.17783f, -1.23747f, 0.96977f, 0.94546f, 0.53913f, + -0.87105f, -0.23456f, 0.94457f, 0.33633f, 1.67793f, -0.64343f, 1.64205f, -0.27973f, + -0.28263f, 0.59008f, 0.46478f, -0.67101f, -0.97408f, 1.09617f, -1.49934f, 1.65995f, + 0.48849f, 2.09345f, 1.12164f, 0.00400f, 0.41253f, -1.82862f, 0.02542f, 0.15923f, + -0.02319f, -1.10920f, 0.33385f, 0.44133f, 0.52444f, 0.17839f, 0.35989f, 1.74398f, + 3.12485f, 0.32066f, 1.83530f, 0.05643f, -1.48623f, -2.86640f, -0.98590f, 2.43114f, + -0.35453f, 1.19489f, 1.21466f, 0.43728f, 1.17409f, 1.10532f, -0.54684f, -0.27415f, + 2.10486f, 0.42502f, 0.31442f, 1.09029f, -1.39990f, 0.34544f, -0.04665f, 1.93631f, + -1.80739f, -0.72702f, 0.26711f, -0.49996f, -0.08280f, 1.45220f, -1.87458f, 2.56975f, + -0.35613f, -0.82097f, -0.73138f, 0.64083f, -0.82572f, 0.29276f, -1.20876f, 0.17214f, + -0.98109f, -0.93785f, 0.23218f, 0.74644f, -0.55791f, 0.76360f, -2.14018f, 0.41112f, + -1.00239f, 0.71319f, 0.23396f, 1.44287f, 0.08250f, 2.22372f, -0.53487f, 2.49612f, + -1.27615f, -0.19306f, -0.08782f, 1.84034f, 0.32768f, 0.56849f, 0.04441f, -0.65827f, + -0.23335f, -0.23385f, -0.99781f, 0.02744f, 0.27255f, -0.16298f, 0.16596f, -0.43781f, + 0.45376f, 1.61600f, 0.43510f, -0.15598f, 0.02663f, 0.20136f, 0.16456f, 2.32006f, + -0.17321f, -0.29193f, 0.91427f, 1.58508f, -0.46040f, -1.24638f, -0.67820f, -1.15898f, + 0.17969f, 1.22766f, 2.75006f, 0.42565f, -0.70687f, 0.32029f, 1.39965f, 0.56489f, + 0.71663f, 0.96956f, 0.11987f, 0.13721f, -1.87023f, -0.23010f, 0.06482f, 0.04463f, + 1.68528f, 0.76126f, 1.49722f, -0.38899f, 0.55481f, 1.01654f, 1.39907f, 0.52457f, + -1.95718f, 0.11925f, -0.93854f, -0.97164f, 0.36083f, 0.15714f, -1.08132f, 0.18311f, + 0.04688f, 0.84368f, -0.93179f, -0.51981f, -0.10838f, -0.21285f, -1.50430f, 0.11392f, + 1.62670f, -1.49010f, 0.52821f, -0.38828f, 0.48409f, -0.39124f, 0.59059f, -0.09530f, + -0.25620f, -1.30062f, 0.56251f, -0.36103f, 0.99661f, 1.14148f, 1.92076f, -0.42200f, + -0.98999f, 2.53498f, -0.86633f, 0.25785f, -1.04939f, 1.24230f, 0.19481f, -0.23041f, + 0.23570f, -1.02525f, 1.16260f, -1.04002f, 0.22364f, 0.28499f, -0.34476f, -0.67831f, + -0.55750f, -0.91398f, 1.37583f, 0.64503f, -2.02422f, -1.52848f, -0.27042f, 0.41021f, + -0.43892f, 0.75682f, 0.18781f, 0.92758f, 0.50460f, 0.73314f, 0.03367f, -0.27875f, + -0.63667f, 0.18394f, 1.42434f, 2.00770f, -0.88286f, -0.55983f, -1.12401f, 0.34193f, + -2.51687f, -1.04707f, -0.63970f, -0.70438f, 0.59782f, 0.74183f, 0.31749f, -0.28442f, + -1.95803f, -1.79381f, 0.46461f, -0.17142f, 0.41181f, 0.27836f, -0.02363f, 0.93865f}; + + std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length); + + std::vector cos_cache_val{ + 0.60305f, 0.94544f, 0.59646f, -0.94253f, -0.92642f, 0.19489f, -0.97555f, -0.99972f, + -0.35706f, -0.77041f, -0.87666f, 0.63573f, -0.13506f, 0.51438f, 0.44535f, -0.99741f, + 0.80078f, 0.27214f, 0.75584f, 0.82948f, 0.45954f, -0.68932f, -0.76928f, 0.81090f, + 0.75784f, 0.88164f, -0.93231f, 0.04034f, 0.57829f, 0.90856f, 0.55415f, -0.96171f, + 0.14195f, 0.61549f, 0.45128f, 0.91793f, 0.17327f, 0.99181f, 0.99467f, -0.99735f, + 0.99103f, 0.67559f, 0.98632f, -0.58485f, 0.21572f, -0.53796f, -0.78902f, 0.17645f, + 0.97012f, -0.39310f, 0.33858f, 0.82521f, -0.35662f, 0.39955f, 0.24059f, 0.32539f, + 0.82187f, 0.89318f, -0.54623f, 0.05817f, 0.80386f, -0.63893f, -0.19885f, -0.96778f}; + + std::vector sin_cache_val{ + -0.79771f, -0.32581f, -0.80264f, 0.33412f, 0.37650f, -0.98082f, -0.21977f, 0.02385f, + -0.93408f, 0.63755f, -0.48110f, -0.77191f, -0.99084f, -0.85756f, 0.89536f, 0.07192f, + 0.59896f, -0.96226f, 0.65475f, -0.55853f, 0.88816f, 0.72445f, -0.63891f, -0.58519f, + 0.65244f, 0.47192f, -0.36167f, -0.99919f, -0.81583f, 0.41776f, 0.83242f, 0.27406f, + -0.98987f, -0.78815f, -0.89238f, -0.39675f, 0.98487f, -0.12775f, 0.10315f, 0.07274f, + -0.13363f, 0.73728f, -0.16483f, 0.81114f, 0.97646f, -0.84297f, 0.61436f, -0.98431f, + -0.24262f, -0.91950f, -0.94094f, 0.56482f, 0.93425f, 0.91671f, -0.97063f, 0.94558f, + -0.56968f, -0.44970f, 0.83764f, 0.99831f, 0.59482f, 0.76926f, 0.98003f, 0.25181f}; + + p.compile(migraphx::make_target("ref")); + migraphx::parameter_map pm; + pm["qkv"] = migraphx::argument(qkv_shape, qkv_val.data()); + pm["ktsl"] = migraphx::argument(key_total_sequence_lens_shape, ktsl_val.data()); + pm["cos_cache"] = migraphx::argument(cos_cache_shape, cos_cache_val.data()); + pm["sin_cache"] = migraphx::argument(sin_cache_shape, sin_cache_val.data()); + + auto qkv_rotary = p.eval(pm).front(); + std::vector qkv_rotary_vals(qkv_shape.elements()); + qkv_rotary.visit([&](auto output) { qkv_rotary_vals.assign(output.begin(), output.end()); }); + + std::vector qkv_rotary_gold{ + -1.39282, -1.12742, -1.03079, 0.764691, 0.120706, 0.292414, 0.564983, + 0.767565, -1.57631, -1.13605, 1.3554, -0.212872, 0.491168, 0.342352, + 0.912628, 1.72467, 1.29703, 0.287376, -0.925895, 0.354206, -0.515094, + -0.675594, -0.287181, -1.45048, -1.13765, -1.23163, -0.43385, 0.382479, + 2.05707, 0.0955806, 1.68709, -0.652591, 0.978488, 0.322403, -1.69797, + -1.02303, -0.388838, -0.116194, 1.3702, -0.127117, 0.537545, -0.259808, + 0.0441719, -0.185229, 0.566529, 0.150215, -0.389331, 0.642311, 0.897243, + -1.65275, -0.165115, -0.623412, -0.474919, -0.268506, 0.966636, 0.438147, + -0.775193, 0.350996, 0.154803, -1.32667, -1.14248, -0.457914, 0.821719, + 0.575535, -0.218314, 0.626707, 0.740187, -1.02618, -0.534705, 1.66424, + 0.239472, -1.057, 0.429053, -0.800986, 0.609461, -0.704519, -0.959013, + -0.162789, -0.175707, -1.50295, -0.182656, -0.578659, -1.50094, 0.0128149, + -0.396812, -0.441464, 1.80815, 0.117152, -1.38417, 0.803815, 0.694833, + -2.57559, 1.4043, 0.152129, -0.859029, 0.132653, -1.45191, 0.770972, + -1.21058, -0.697431, -0.331614, -0.807443, 0.770993, 0.14047, -0.706927, + -0.631388, 1.20919, 0.418611, 0.90082, -0.531328, 0.256353, -0.428792, + 2.32822, -0.829008, 0.862528, 0.0273355, -1.27524, 0.516154, 0.41227, + -0.314902, 0.403373, 0.8086, -0.7864, 0.0800958, -1.30321, -0.965325, + -0.434885, 0.721444, -1.28192, 0.381191, 1.07418, -0.844252, -0.620186, + -0.49315, -0.43717, 0.78247, 0.254562, -1.04396, -0.671755, 0.730715, + -0.762339, -0.553247, -0.734708, 1.15029, 0.0466995, 1.9265, 1.35833, + -0.604008, 0.193317, 0.79445, 0.447972, -0.563865, -0.226709, -0.944273, + 0.141956, 0.401539, 0.441992, -1.45118, -0.918668, 0.200984, -0.334403, + -0.0974852, 1.04186, 1.33804, -0.183669, -0.476989, -0.984574, 0.358123, + -0.982199, 0.66843, 0.052529, 0.469781, 1.3491, -0.838878, -0.327456, + 0.204208, 0.596249, 0.099632, 0.577423, -0.745065, -1.6934, 0.0515309, + 0.867505, -0.0859758, 0.174175, -0.497392, 0.684847, -1.62009, 1.05337, + -0.234463, -0.578864, -1.05106, 1.11697, 0.338684, 0.448474, -1.25558, + 0.468111, 1.74476, 0.194212, -0.541368, 2.64292, -0.294128, -1.32902, + 1.36476, -0.347302, -1.79106, -0.329959, 0.307754, -1.40273, 0.457, + -0.854466, -0.704959, -1.25373, 0.616853, 0.210538, -0.51211, 1.23271, + 0.834837, -0.773101, 0.250972, -0.186807, 0.70676, 0.0388865, 1.27586, + -0.43823, 0.77127, 0.823126, 1.26844, 0.160103, 3.08129, 1.17947, + -0.665054, 0.803325, 0.415388, 0.126887, -0.414054, 0.81466, -0.661345, + -0.662619, -0.851782, -0.381857, -0.0131786, 0.633985, 0.00501156, 0.399999, + 0.173655, 0.11931, 0.171137, -0.235142, 0.805758, 0.719598, -0.609716, + -0.319749, 0.0547204, 0.588497, -1.31349, 2.10195, -0.478039, -0.71372, + -1.84962, -0.64123, 0.425257, -0.0193307, 0.159346, 0.356259, 0.851229, + 1.10665, -1.34213, -0.719752, 0.79139, -0.311606, 0.0171946, -1.22536, + 0.109289, 1.07547, -1.05123, 0.719174, -0.734445, 0.5734, 0.367756, + 2.05268, -1.74067, -0.153315, -0.259684, 1.42799, 0.124019, -0.605963, + 0.201252, -0.900916, -0.694072, 0.575702, 0.324902, -0.421317, -0.472871, + 1.15612, -1.37553, 1.25926, -0.829596, -1.07838, -1.45993, -2.76525, + 0.208704, 0.666696, 0.696529, 1.05676, -0.212402, 1.03409, 0.124571, + 2.16378, 0.590014, 0.82103, -1.05436, -0.085368, -1.22022, -0.72807, + 0.0644779, -1.54878, -0.186138, -1.24541, -0.464878, 1.29213, -0.225742, + -0.436569, 0.826866, -2.25735, 0.454881, 0.158523, 0.88557, -0.160122, + -1.19906, -1.94586, 0.884542, 0.548001, 1.3342, -0.493141, -0.0349891, + -0.414256, 1.01713, -0.247845, 2.44634, 0.729245, 0.288274, 1.04836, + 1.0719, -0.142642, -0.486631, 0.753102, 0.0872765, 1.19346, 1.64989, + -1.21738, 0.866737, -0.0658435, -1.53604, 0.235686, 1.54612, 0.393193, + -1.32013, -0.863259, -0.296059, 1.47141, -0.0870139, -2.04338, 0.30967, + -0.874938, -0.639524, -1.11369, -0.314353, -1.07635, -0.65329, -0.487139, + -0.319258, 0.427885, 0.0897219, 1.48948, -0.139642, -0.270091, 0.25487, + -1.13628, -0.0544412, 0.484468, 1.57404, -0.109217, -0.73644, -0.426269, + 2.30538, 0.550942, 1.20806, -1.69385, 0.319492, 2.14155, 0.767033, + -0.099271, -1.67857, -0.995838, 0.802543, 1.77267, 0.797841, -0.667437, + -2.38611, 0.608253, -1.01297, 0.42058, -0.0165653, -0.39785, -0.606546, + -0.203179, -0.0754723, 0.68607, 0.565839, 0.357117, 0.421368, -0.75762, + -0.319739, 0.216557, 0.166339, -0.195082, -0.907648, -0.962443, 0.724847, + 1.02494, 0.00627179, 0.334994, 0.911785, -0.718284, 0.52396, 0.0233051, + 0.429589, -3.59775, -0.069827, -0.645042, 1.98904, 0.743916, -0.579094, + -2.28817, -0.379605, 0.413299, -0.181583, -0.720468, -2.97931, 1.45854, + 1.10988, -0.124848, 0.185779, -2.70442, 0.421914, -0.649825, -0.0478077, + -2.25618, 0.539307, -0.992722, -0.96685, 1.5946, 0.519757, 1.81704, + -0.476246, 0.26548, 0.0679767, -0.272559, 1.7048, -0.74139, 0.883191, + -0.294848, -0.856265, 1.17815, 0.146656, 0.537179, -1.18485, -0.535799, + 0.467316, -0.860768, -0.897359, -0.650006, -1.29965, -0.637005, 2.20922, + 1.3137, -1.41668, 0.552017, -1.14673, -1.80179, -0.250794, -0.191147, + -1.88013, 1.64509, 0.599458, -0.488355, 0.239337, -0.419502, 0.433115, + -0.297457, -1.58134, -0.718115, 0.298562, -0.756691, 0.249923, -1.99909, + 0.818194, -0.540811, 1.37449, 0.816915, 0.870149, -1.85656, -0.631336, + -1.65317, 0.0609156, 1.26583, -0.337225, -0.351895, -1.6922, -1.41612, + -0.133992, -0.637481, -0.688252, -0.58494, -0.639447, -0.796788, -1.95253, + -0.683125, 0.489376, -0.558102, 1.51232, -0.866393, 0.202415, -2.18996, + 0.469978, 0.51911, -1.46145, -1.41838, -2.25265, -0.766826, 0.316467, + -0.477101, -1.99732, -0.802287, -0.275634, 0.43074, 0.493704, -0.863341, + -1.77821, 0.0782743, -0.424081, -0.747217, -0.437439, -1.72169, -0.382408, + 1.59882, 0.229693, 1.09085, 0.583644, -0.268739, -0.168567, -0.26805, + 1.15586, -0.36085, -0.881849, -0.156566, -1.18596, 0.760397, 1.97544, + 0.223005, 0.223615, 0.627005, -1.50396, -0.232102, -0.550169, -0.739993, + 0.643522, 0.130261, -0.632877, 0.5045, -1.89587, -0.868859, -0.654663, + -1.13629, 0.549711, -0.484092, 0.000221789, 0.113144, -1.63008, -1.38763, + 1.24321, 0.579567, 0.638793, 1.66319, -2.00534, 0.713527, 0.142681, + -0.348421, -0.980949, -0.300271, -0.94567, 1.08636, -0.359409, -1.39501, + -0.805637, -0.0803066, 0.857691, -0.780778, -0.504082, 0.457663, 0.867831, + -0.644647, -1.02772, 1.34149, 0.0427136, -1.42178, -1.03076, 0.164061, + 1.04671, 0.889346, -0.460635, -1.08915, 0.139808, -0.886619, -0.000123426, + -0.386286, 1.12272, -0.311824, -0.654093, 0.57904, 0.954279, -2.18804, + 0.623723, 0.243464, -0.898606, -1.33614, -2.78988, -0.640967, -1.51732, + -1.20855, -0.519717, -0.988904, -2.10062, 0.213738, -0.672715, 0.502266, + -1.439, 2.30819, -0.46356, -1.8637, -0.439576, -0.96649, 0.603132, + -1.02504, -0.886298, -1.20854, -0.289504, 1.32328, 0.335325, -0.421339, + -1.45944, -0.724789, 0.650192, -0.860273, -0.664577, 0.133231, 0.550855, + 2.52338, -0.389135, -0.16695, -0.826752, 0.0419003, -1.49016, -1.29609, + -0.562022, 0.936668, -0.701746, 1.59248, -0.527444, -0.573293, 0.76016, + 0.777361, -1.0478, -0.128279, 0.238765, -0.490994, -0.652953, 0.0173612, + -1.74518, -0.492311, -1.17539, -0.501837, 0.636348, -0.708254, -0.544971, + -1.10855, -0.637522, 1.0825, 0.594793, 0.0505524, -0.802418, -0.0183533, + -1.02712, -0.77603, 1.87559, -0.571897, -0.817117, 0.352893, 0.387498, + 1.23008, -1.04518, 1.01526, -0.278199, 0.0610645, -0.721664, 0.202913, + 1.3773, 1.52253, -0.361695, -0.147652, 0.527706, -1.31543, 1.53912, + -0.489441, 0.0468228, -0.0520686, 0.37135, 0.396255, 0.461767, 0.474904, + 0.373609, -1.80432, -0.429407, 0.913289, 0.446848, -0.290926, 0.246727, + 0.715222, -0.0807099, 0.452465, -0.352157, 0.831232, -1.17139, 1.49571, + -0.256195, -1.46225, 1.08797, -0.258473, 0.407301, 0.496463, 2.39975, + -0.874556, -0.206421, 0.507279, -1.29064, 3.40726, 0.515469, -0.847795, + 0.538463, -0.600921, 0.813678, -2.17265, 0.851656, -0.0720263, -0.237789, + -0.638447, -0.715811, 0.673846, -0.509011, 1.13158, 1.87334, -1.3717, + 0.198809, -1.16079, -2.58462, -0.348852, -0.499339, 1.54316, -0.70976, + 0.40827, 0.09174, -0.81059, -0.9426, 0.83612, 1.54785, 0.52762, + 0.29894, 0.51916, 0.18861, -0.48849, 0.37812, -0.55108, 1.16715, + 0.18136, -2.05249, -0.52255, -3.17783, -1.23747, 0.96977, 0.94546, + 0.53913, -0.87105, -0.23456, 0.94457, 0.33633, 1.67793, -0.64343, + 1.64205, -0.27973, -0.28263, 0.59008, 0.46478, -0.67101, -0.97408, + 1.09617, -1.49934, 1.65995, 0.48849, 2.09345, 1.12164, 0.004, + 0.41253, -1.82862, 0.02542, 0.15923, -0.02319, -1.1092, 0.33385, + 0.44133, 0.52444, 0.17839, 0.35989, 1.74398, 3.12485, 0.32066, + 1.8353, 0.05643, -1.48623, -2.8664, -0.9859, 2.43114, -0.35453, + 1.19489, 1.21466, 0.43728, 1.17409, 1.10532, -0.54684, -0.27415, + 2.10486, 0.42502, 0.31442, 1.09029, -1.3999, 0.34544, -0.04665, + 1.93631, -1.80739, -0.72702, 0.26711, -0.49996, -0.0828, 1.4522, + -1.87458, 2.56975, -0.35613, -0.82097, -0.73138, 0.64083, -0.82572, + 0.29276, -1.20876, 0.17214, -0.98109, -0.93785, 0.23218, 0.74644, + -0.55791, 0.7636, -2.14018, 0.41112, -1.00239, 0.71319, 0.23396, + 1.44287, 0.0825, 2.22372, -0.53487, 2.49612, -1.27615, -0.19306, + -0.08782, 1.84034, 0.32768, 0.56849, 0.04441, -0.65827, -0.23335, + -0.23385, -0.99781, 0.02744, 0.27255, -0.16298, 0.16596, -0.43781, + 0.45376, 1.616, 0.4351, -0.15598, 0.02663, 0.20136, 0.16456, + 2.32006, -0.17321, -0.29193, 0.91427, 1.58508, -0.4604, -1.24638, + -0.6782, -1.15898, 0.17969, 1.22766, 2.75006, 0.42565, -0.70687, + 0.32029, 1.39965, 0.56489, 0.71663, 0.96956, 0.11987, 0.13721, + -1.87023, -0.2301, 0.06482, 0.04463, 1.68528, 0.76126, 1.49722, + -0.38899, 0.55481, 1.01654, 1.39907, 0.52457, -1.95718, 0.11925, + -0.93854, -0.97164, 0.36083, 0.15714, -1.08132, 0.18311, 0.04688, + 0.84368, -0.93179, -0.51981, -0.10838, -0.21285, -1.5043, 0.11392, + 1.6267, -1.4901, 0.52821, -0.38828, 0.48409, -0.39124, 0.59059, + -0.0953, -0.2562, -1.30062, 0.56251, -0.36103, 0.99661, 1.14148, + 1.92076, -0.422, -0.98999, 2.53498, -0.86633, 0.25785, -1.04939, + 1.2423, 0.19481, -0.23041, 0.2357, -1.02525, 1.1626, -1.04002, + 0.22364, 0.28499, -0.34476, -0.67831, -0.5575, -0.91398, 1.37583, + 0.64503, -2.02422, -1.52848, -0.27042, 0.41021, -0.43892, 0.75682, + 0.18781, 0.92758, 0.5046, 0.73314, 0.03367, -0.27875, -0.63667, + 0.18394, 1.42434, 2.0077, -0.88286, -0.55983, -1.12401, 0.34193, + -2.51687, -1.04707, -0.6397, -0.70438, 0.59782, 0.74183, 0.31749, + -0.28442, -1.95803, -1.79381, 0.46461, -0.17142, 0.41181, 0.27836, + -0.02363, 0.93865}; + + EXPECT(migraphx::verify::verify_rms_range(qkv_rotary_vals, qkv_rotary_gold)); +} + +TEST_CASE(rotary_embedding_interleaved_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + const size_t batch_size = 2; + const size_t sequence_length = 1; + const size_t num_heads = 2; + const size_t kv_num_heads = 1; + const size_t head_size = 16; + const size_t max_cache_sequence_length = 8; + const size_t total_sequence_length = 8; + const size_t max_rotary_seq_length = max_cache_sequence_length; + const size_t rotary_dim = head_size / 2; + const bool interleaved = true; + + migraphx::shape qkv_shape{ + migraphx::shape::float_type, + {batch_size, num_heads + 2 * kv_num_heads, sequence_length, head_size}}; + migraphx::shape key_total_sequence_lens_shape(migraphx::shape::int32_type, {batch_size}); + migraphx::shape cos_cache_shape(migraphx::shape::float_type, + {max_rotary_seq_length, rotary_dim}); + migraphx::shape sin_cache_shape(migraphx::shape::float_type, + {max_rotary_seq_length, rotary_dim}); + + auto qkv = mm->add_parameter("qkv", qkv_shape); + auto ktsl = mm->add_parameter("ktsl", key_total_sequence_lens_shape); + auto cos_cache = mm->add_parameter("cos_cache", cos_cache_shape); + auto sin_cache = mm->add_parameter("sin_cache", sin_cache_shape); + + auto rotary = mm->add_instruction(migraphx::make_op("rotary_embedding", + {{"num_heads", num_heads}, + {"kv_num_heads", kv_num_heads}, + {"interleaved", interleaved}}), + qkv, + ktsl, + cos_cache, + sin_cache); + mm->add_return({rotary}); + + std::vector qkv_val{ + -0.65048f, -0.73475f, -1.16252f, 1.23505f, 0.30815f, 1.41725f, -0.99702f, 1.83288f, + 0.17508f, -0.44192f, -0.60220f, 0.57942f, -1.13502f, -0.21030f, -0.21183f, -0.59764f, + 0.03369f, -2.07573f, 0.26817f, 0.79531f, 0.82783f, 1.75045f, -0.13390f, 0.55881f, + -0.68510f, -0.22383f, -1.07129f, -0.37183f, 0.59560f, -0.24106f, 0.72188f, 0.96579f, + 1.10218f, 0.71842f, -0.13842f, -1.18598f, -0.97063f, 0.34577f, -0.09583f, -0.12853f, + -0.86645f, -0.41244f, 0.26598f, -0.01910f, -0.14762f, -0.01239f, -0.42813f, -0.25926f, + 0.51649f, -1.49558f, -0.64219f, -1.09694f, 0.17579f, -0.52930f, 0.99243f, -0.48142f, + 2.87901f, -0.34344f, 1.67369f, -1.01097f, 1.18906f, 0.79148f, 0.03848f, -0.08710f, + 1.06663f, 0.53670f, 0.18055f, 0.06149f, -1.27977f, 1.04653f, -0.09765f, 0.66428f, + 1.37472f, -0.90719f, 2.09439f, -0.54025f, 0.67836f, -0.12357f, -1.05392f, 1.01185f, + 1.20504f, -0.34832f, 1.38105f, -0.43522f, -0.08815f, -0.12122f, 0.66614f, -0.24025f, + 0.78835f, -1.20457f, -1.46959f, 1.04148f, -0.15917f, -0.11583f, 0.32542f, -0.82666f, + -1.77097f, -0.02825f, 0.03732f, 0.36346f, -1.30739f, 0.38991f, -0.11768f, 0.23001f, + -1.09870f, -0.80748f, 1.09745f, 0.33018f, -0.80205f, 0.10119f, -1.22517f, -0.54121f, + 0.50709f, -0.39303f, -0.94137f, 0.54072f, -0.17975f, 0.04328f, 0.37207f, 2.18807f, + -0.53601f, -0.44769f, 2.41322f, -1.96112f, -0.13698f, 0.57829f, -1.85719f, 0.77514f}; + + std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length); + + std::vector cos_cache_val{ + 0.15911f, -0.05395f, -0.59862f, 0.98028f, -0.55443f, -0.81122f, -0.47045f, 0.91331f, + -0.84324f, 0.98846f, -0.71105f, -0.79251f, -0.57709f, -0.82186f, -0.07956f, 0.78094f, + -0.65349f, 0.99776f, 0.91546f, -0.05603f, -0.37510f, 0.05328f, -0.90493f, -0.50526f, + -0.67425f, 0.99716f, -0.51956f, 0.52490f, -0.45744f, 0.85886f, 0.87088f, 0.20695f, + 0.10653f, -0.79970f, 0.00293f, -0.81702f, 0.38169f, 0.86232f, 0.99098f, -0.75832f, + 0.14130f, -0.97297f, 0.44135f, 0.97386f, -0.08836f, -0.99652f, -0.29374f, 0.67556f, + -0.24402f, 0.51681f, 0.01236f, -0.04189f, 0.89800f, -0.76051f, -0.50112f, -0.96679f, + -0.76057f, -0.74721f, -0.56633f, 0.05373f, 0.51381f, 0.76501f, 0.19223f, -0.36373f}; + + std::vector sin_cache_val{ + -0.98726f, 0.99854f, 0.80103f, 0.19763f, 0.83223f, -0.58474f, -0.88243f, 0.40727f, + -0.53754f, -0.15150f, 0.70314f, -0.60985f, 0.81668f, 0.56969f, -0.99683f, 0.62461f, + -0.75693f, -0.06687f, 0.40241f, 0.99843f, -0.92699f, -0.99858f, 0.42557f, -0.86297f, + -0.73850f, -0.07530f, -0.85444f, 0.85116f, 0.88924f, 0.51221f, -0.49150f, -0.97835f, + -0.99431f, -0.60041f, -1.00000f, 0.57661f, 0.92429f, 0.50637f, 0.13403f, 0.65189f, + -0.98997f, -0.23095f, 0.89734f, -0.22716f, 0.99609f, 0.08341f, 0.95588f, -0.73730f, + -0.96977f, 0.85610f, -0.99992f, -0.99912f, 0.44000f, -0.64933f, -0.86538f, 0.25557f, + 0.64926f, 0.66459f, 0.82418f, -0.99856f, 0.85790f, -0.64402f, -0.98135f, 0.93151f}; + + p.compile(migraphx::make_target("ref")); + migraphx::parameter_map pm; + pm["qkv"] = migraphx::argument(qkv_shape, qkv_val.data()); + pm["ktsl"] = migraphx::argument(key_total_sequence_lens_shape, ktsl_val.data()); + pm["cos_cache"] = migraphx::argument(cos_cache_shape, cos_cache_val.data()); + pm["sin_cache"] = migraphx::argument(sin_cache_shape, sin_cache_val.data()); + + auto qkv_rotary = p.eval(pm).front(); + std::vector qkv_rotary_vals(qkv_shape.elements()); + qkv_rotary.visit([&](auto output) { qkv_rotary_vals.assign(output.begin(), output.end()); }); + + std::vector qkv_rotary_gold{ + 0.971779, 0.136498, 0.0478448, -1.69544, -1.34258, -0.54866, 1.77667, + 1.09406, 0.469081, -0.0768618, -0.0875309, 0.831091, -0.424563, 1.07343, + 0.633757, 0.0200578, 1.32206, 1.60061, -0.728934, -0.416041, -1.91151, + -0.309051, 0.550811, 0.163732, -0.159987, -0.702753, -1.05901, 0.405479, + -0.122072, -0.630831, -1.16221, 0.321152, -1.30473, 0.169193, 0.891619, + 0.794184, 0.26472, -0.995794, -0.133494, 0.0887861, -0.0913584, -0.955243, + 0.191177, -0.185908, -0.0405359, 0.142485, 0.397227, -0.304507, 0.51649, + -1.49558, -0.64219, -1.09694, 0.17579, -0.5293, 0.99243, -0.48142, + 2.87901, -0.34344, 1.67369, -1.01097, 1.18906, 0.79148, 0.03848, + -0.0871, -1.1597, 0.284322, -0.175774, 0.0740458, -0.137757, -1.64744, + 0.658077, 0.133201, 1.48462, 0.713249, 1.2543, -1.76213, 0.00913571, + -0.689462, -0.559206, -1.34978, -0.690367, 1.04731, -0.742692, 1.24303, + 0.149829, -0.00400095, -0.204112, -0.678089, 1.43846, 0.0574054, -0.453517, + 1.74319, -0.144267, 0.133935, 0.651677, 0.603813, 1.36529, -1.12833, + -0.269438, -0.246778, 0.419058, -1.29834, 0.223356, 0.129869, 0.128214, + -1.35747, 1.0522, -0.454189, -0.0548753, 0.806544, 0.949774, -0.944404, + 0.50709, -0.39303, -0.94137, 0.54072, -0.17975, 0.04328, 0.37207, + 2.18807, -0.53601, -0.44769, 2.41322, -1.96112, -0.13698, 0.57829, + -1.85719, 0.77514}; + + EXPECT(migraphx::verify::verify_rms_range(qkv_rotary_vals, qkv_rotary_gold)); +} From 861c97339e8744c0e8f65290873fb17f7c98d460 Mon Sep 17 00:00:00 2001 From: Dino Music Date: Fri, 10 Oct 2025 09:56:04 +0000 Subject: [PATCH 2/2] Remove seqlens_k decrementing --- src/include/migraphx/op/rotary_embedding.hpp | 5 +---- test/ref/rotary_embedding.cpp | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/include/migraphx/op/rotary_embedding.hpp b/src/include/migraphx/op/rotary_embedding.hpp index c519311c45a..6cbdb481c38 100644 --- a/src/include/migraphx/op/rotary_embedding.hpp +++ b/src/include/migraphx/op/rotary_embedding.hpp @@ -170,10 +170,7 @@ struct rotary_embedding : 1); if(params.position_ids_use_batch) { - std::transform(seqlens_k.begin(), - seqlens_k.end(), - pos_ids.begin(), - [](auto len) { return len - 1; }); + std::copy(seqlens_k.begin(), seqlens_k.end(), pos_ids.begin()); } else { diff --git a/test/ref/rotary_embedding.cpp b/test/ref/rotary_embedding.cpp index 371484f4738..55d9c1b9761 100644 --- a/test/ref/rotary_embedding.cpp +++ b/test/ref/rotary_embedding.cpp @@ -201,7 +201,7 @@ TEST_CASE(rotary_embedding_test) -2.51687f, -1.04707f, -0.63970f, -0.70438f, 0.59782f, 0.74183f, 0.31749f, -0.28442f, -1.95803f, -1.79381f, 0.46461f, -0.17142f, 0.41181f, 0.27836f, -0.02363f, 0.93865f}; - std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length); + std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length - 1); std::vector cos_cache_val{ 0.60305f, 0.94544f, 0.59646f, -0.94253f, -0.92642f, 0.19489f, -0.97555f, -0.99972f, @@ -444,7 +444,7 @@ TEST_CASE(rotary_embedding_interleaved_test) 0.50709f, -0.39303f, -0.94137f, 0.54072f, -0.17975f, 0.04328f, 0.37207f, 2.18807f, -0.53601f, -0.44769f, 2.41322f, -1.96112f, -0.13698f, 0.57829f, -1.85719f, 0.77514f}; - std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length); + std::vector ktsl_val(key_total_sequence_lens_shape.elements(), total_sequence_length - 1); std::vector cos_cache_val{ 0.15911f, -0.05395f, -0.59862f, 0.98028f, -0.55443f, -0.81122f, -0.47045f, 0.91331f,