Skip to content

Commit 6cffd1a

Browse files
authored
Add RotaryEmbeddings(23) - CPU (microsoft#24980)
### Description <!-- Describe your changes. --> Adds the ONNX RotaryEmbedding(23) operator, following https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding. The PR uses the contrib op RotaryEmbedding implementation under the hood. The main difference between this op and the contrib op is that `position_ids` is optional in ONNX RotaryEmbedding. When it is not provided, `cos_cache` and `sin_cache` should be 3-D. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Fixes microsoft#24556. References microsoft#23507.
1 parent 7402b6c commit 6cffd1a

File tree

10 files changed

+1495
-11
lines changed

10 files changed

+1495
-11
lines changed

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,7 @@ Do not modify directly.*
371371
|ReverseSequence|*in* input:**T**<br> *in* sequence_lens:**tensor(int64)**<br> *out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
372372
|RoiAlign|*in* X:**T1**<br> *in* rois:**T1**<br> *in* batch_indices:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
373373
|||[10, 15]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int64)|
374+
|RotaryEmbedding|*in* X:**T**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**M**<br> *out* Y:**T**|23+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
374375
|Round|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(double), tensor(float), tensor(float16)|
375376
|||[11, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
376377
|STFT|*in* signal:**T1**<br> *in* frame_step:**T2**<br> *in* window:**T1**<br> *in* frame_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|

onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,20 @@ Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters paramete
5757
const int position_ids_format = parameters.position_ids_format;
5858
const int rotary_emb_dim = parameters.rotary_embedding_dim;
5959
const int half_rotary_emb_dim = rotary_emb_dim / 2;
60-
60+
// Parallel to calculate based on head_size
6161
const int loop_len = batch_size * sequence_length * n_heads;
62-
const double cost = static_cast<double>(rotary_emb_dim);
62+
// The cost is calculated as:
63+
// - head_size * sizeof(T) for reading input
64+
// - head_size * sizeof(T) for writing output
65+
// - rotary_emb_dim * 32 for the rotary embedding operations (32 is an approximation of the number of CPU cycles)
66+
const double cost = static_cast<double>(head_size * sizeof(T) * 2 + rotary_emb_dim * 32);
6367
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
6468
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
6569
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
6670
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
6771
const int n = static_cast<int>(ptr % n_heads);
68-
72+
// Identify the index of batch, sequence, and head (specific range) in the input/output tensor
73+
// for read/write
6974
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
7075

7176
const T* input_data = input + block_offset;

onnxruntime/core/mlas/lib/rotary_embedding_kernel_avx2.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,11 @@ RopeKernel_Avx2_fp32_Impl<true>(
235235
__m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0);
236236
float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec);
237237
float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec);
238-
float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2);
239-
float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2);
238+
// Use masked loads for sin/cos data to avoid reading beyond buffer bounds
239+
size_t cos_sin_rem = rem / 2;
240+
const __m256i cos_sin_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - cos_sin_rem));
241+
float32x8_t sin_val = _mm256_maskload_ps(sin_data + i / 2, cos_sin_mask);
242+
float32x8_t cos_val = _mm256_maskload_ps(cos_data + i / 2, cos_sin_mask);
240243
//Compute Real and Imaginary output values
241244
float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val));
242245
float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val));

onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,6 +1332,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, Si
13321332
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float, RMSNormalization);
13331333
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, double, RMSNormalization);
13341334
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16, RMSNormalization);
1335+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float, RotaryEmbedding);
1336+
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16, RotaryEmbedding);
13351337

13361338
// !!PLEASE READ BELOW!! Following that, add new entries above this comment
13371339

@@ -3318,6 +3320,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
33183320
RMSNormalization)>,
33193321
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16,
33203322
RMSNormalization)>,
3323+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, float,
3324+
RotaryEmbedding)>,
3325+
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 23, MLFloat16,
3326+
RotaryEmbedding)>,
33213327
};
33223328
for (auto& function_table_entry : function_table) {
33233329
KernelCreateInfo info = function_table_entry();
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/cpu/llm/rotary_embedding.h"
5+
#include "core/providers/cpu/llm/rotary_embedding_helper.h"
6+
7+
#include "core/mlas/inc/mlas.h"
8+
#include "core/platform/threadpool.h"
9+
10+
using onnxruntime::concurrency::ThreadPool;
11+
using namespace onnxruntime::rotary_embedding_helper;
12+
13+
namespace onnxruntime {
14+
15+
// Registers RotaryEmbedding<T> as a typed CPU kernel for operator "RotaryEmbedding", version 23.
// Type constraint "T" is bound to tensors of the element type T, and constraint "M" to int64 tensors.
// The macro is instantiated below for float and MLFloat16.
#define REGISTER_ONNX_KERNEL_TYPED(T) \
16+
ONNX_CPU_OPERATOR_TYPED_KERNEL( \
17+
RotaryEmbedding, \
18+
23, \
19+
T, \
20+
KernelDefBuilder() \
21+
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
22+
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
23+
RotaryEmbedding<T>);
24+
25+
REGISTER_ONNX_KERNEL_TYPED(float)
26+
REGISTER_ONNX_KERNEL_TYPED(MLFloat16)
27+
28+
template <typename T>
29+
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
30+
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
31+
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
32+
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); // Turn 0/1 into bool
33+
34+
if (rotary_embedding_dim > 0) {
35+
ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
36+
}
37+
}
38+
39+
// TODO: rotary embedding in place
40+
template <typename T>
41+
Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input,
42+
const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
43+
bool interleaved) {
44+
const int batch_size = parameters.batch_size;
45+
const int sequence_length = parameters.sequence_length;
46+
const int n_heads = parameters.num_heads;
47+
const int head_size = parameters.head_size;
48+
const int head_stride = parameters.head_stride;
49+
const int seq_stride = parameters.seq_stride;
50+
const int batch_stride = parameters.batch_stride;
51+
const int position_ids_format = parameters.position_ids_format;
52+
const int rotary_emb_dim = parameters.rotary_embedding_dim;
53+
const int half_rotary_emb_dim = rotary_emb_dim / 2;
54+
// Parallel to calculate based on head_size
55+
const int loop_len = batch_size * sequence_length * n_heads;
56+
// The cost is calculated as:
57+
// - head_size * sizeof(T) for reading input
58+
// - head_size * sizeof(T) for writing output
59+
// - rotary_emb_dim * 32 for the rotary embedding operations (32 is an approximation of the number of CPU cycles)
60+
const double cost = static_cast<double>(head_size * sizeof(T) * 2 + rotary_emb_dim * 32);
61+
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
62+
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
63+
const int b = static_cast<int>((ptr / n_heads) / sequence_length);
64+
const int s = static_cast<int>((ptr / n_heads) % sequence_length);
65+
const int n = static_cast<int>(ptr % n_heads);
66+
// Identify the index of batch, sequence, and head (specific range) in the input/output tensor
67+
// for read/write
68+
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
69+
const T* input_data = input + block_offset;
70+
T* output_data = output + block_offset;
71+
72+
const T* cos_data;
73+
const T* sin_data;
74+
int cache_offset;
75+
if (position_ids_format == 0) {
76+
cache_offset = (b * sequence_length + s) * half_rotary_emb_dim;
77+
} else {
78+
// Cache is (M, H/2) or (M, rotary_embedding_dim/2)
79+
const int position_id = static_cast<int>(position_ids[b * sequence_length + s]);
80+
cache_offset = position_id * half_rotary_emb_dim;
81+
}
82+
cos_data = cos_cache + cache_offset;
83+
sin_data = sin_cache + cache_offset;
84+
85+
MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);
86+
87+
if (rotary_emb_dim < head_size) {
88+
std::memcpy(output_data + rotary_emb_dim,
89+
input_data + rotary_emb_dim,
90+
(head_size - rotary_emb_dim) * sizeof(T));
91+
}
92+
}
93+
});
94+
95+
return Status::OK();
96+
}
97+
98+
// Explicit instantiations of RunRotaryEmbedding for float and MLFloat16, the two element types
// this file registers kernels for. The header only declares the template, so these are
// presumably what lets other translation units link against it - TODO confirm.
template Status RunRotaryEmbedding<float>(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input,
99+
const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output,
100+
bool interleaved);
101+
102+
template Status RunRotaryEmbedding<MLFloat16>(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input,
103+
const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache,
104+
MLFloat16* output, bool interleaved);
105+
106+
template <typename T>
107+
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
108+
const Tensor* X = context->Input<Tensor>(0);
109+
const Tensor* cos_cache = context->Input<Tensor>(1);
110+
const Tensor* sin_cache = context->Input<Tensor>(2);
111+
// Optional position_ids input, can be nullptr
112+
const Tensor* position_ids = context->Input<Tensor>(3);
113+
114+
RotaryParameters parameters = {};
115+
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(X,
116+
position_ids,
117+
cos_cache,
118+
sin_cache,
119+
num_heads,
120+
rotary_embedding_dim,
121+
&parameters));
122+
123+
Tensor* output = context->Output(0, X->Shape());
124+
125+
const T* x_src = X->Data<T>();
126+
const int64_t* pos_ids_data = (nullptr == position_ids) ? nullptr : position_ids->Data<int64_t>();
127+
const T* cos_cache_data = cos_cache->Data<T>();
128+
const T* sin_cache_data = sin_cache->Data<T>();
129+
T* output_dest = output->MutableData<T>();
130+
131+
AllocatorPtr allocator;
132+
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
133+
auto* tp = context->GetOperatorThreadPool();
134+
135+
return RunRotaryEmbedding<T>(tp, parameters, x_src, pos_ids_data, cos_cache_data, sin_cache_data, output_dest,
136+
interleaved);
137+
}
138+
139+
} // namespace onnxruntime
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "core/common/common.h"
6+
#include "core/framework/op_kernel.h"
7+
#include "core/providers/cpu/llm/rotary_embedding_helper.h"
8+
9+
namespace onnxruntime {
10+
11+
// Applies the rotary embedding to `input` and writes the result to `output`.
// `parameters` carries the sizes, strides, rotary_embedding_dim and position_ids_format used to
// address each row; `position_ids` is passed as nullptr by the kernel when the optional input is
// absent. `cos_cache` / `sin_cache` hold the cached cosine / sine values, and `interleaved`
// selects the interleaved layout. Work is distributed over `tp`.
// Only declared here; the definition lives in the corresponding .cc file.
template <typename T>
12+
Status RunRotaryEmbedding(onnxruntime::concurrency::ThreadPool* tp, rotary_embedding_helper::RotaryParameters parameters, const T* input,
13+
const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
14+
bool interleaved);
15+
16+
template <typename T>
17+
class RotaryEmbedding final : public OpKernel {
18+
public:
19+
RotaryEmbedding(const OpKernelInfo& info);
20+
Status Compute(OpKernelContext* context) const override;
21+
22+
protected:
23+
int num_heads;
24+
int rotary_embedding_dim;
25+
int interleaved;
26+
};
27+
28+
} // namespace onnxruntime

0 commit comments

Comments
 (0)