|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "core/providers/cpu/llm/rotary_embedding.h" |
| 5 | +#include "core/providers/cpu/llm/rotary_embedding_helper.h" |
| 6 | + |
| 7 | +#include "core/mlas/inc/mlas.h" |
| 8 | +#include "core/platform/threadpool.h" |
| 9 | + |
| 10 | +using onnxruntime::concurrency::ThreadPool; |
| 11 | +using namespace onnxruntime::rotary_embedding_helper; |
| 12 | + |
| 13 | +namespace onnxruntime { |
| 14 | + |
// Registers the ONNX RotaryEmbedding operator (opset 23) for element type T
// on the CPU execution provider. Constraint "T" is the input/cache/output
// element type; constraint "M" pins the optional position_ids input to int64.
#define REGISTER_ONNX_KERNEL_TYPED(T)                                   \
  ONNX_CPU_OPERATOR_TYPED_KERNEL(                                       \
      RotaryEmbedding,                                                  \
      23,                                                               \
      T,                                                                \
      KernelDefBuilder()                                                \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())        \
          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
      RotaryEmbedding<T>);

// float and half-precision kernels are registered.
REGISTER_ONNX_KERNEL_TYPED(float)
REGISTER_ONNX_KERNEL_TYPED(MLFloat16)
| 27 | + |
| 28 | +template <typename T> |
| 29 | +RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { |
| 30 | + num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0)); |
| 31 | + rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0)); |
| 32 | + interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); // Turn 0/1 into bool |
| 33 | + |
| 34 | + if (rotary_embedding_dim > 0) { |
| 35 | + ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); |
| 36 | + } |
| 37 | +} |
| 38 | + |
// TODO: rotary embedding in place

// Applies rotary position embedding to `input`, writing to `output`.
// For each (batch, seq, head) triple, the first `rotary_embedding_dim`
// elements of the head are rotated via MlasRotaryEmbedOneRow using the
// cos/sin caches; any tail beyond the rotary dim is copied through unchanged.
// Layout is described entirely by the strides in `parameters`
// (batch_stride/seq_stride/head_stride), so both BSNH and BNSH-style inputs
// can be handled by the caller choosing strides appropriately — TODO confirm
// against CheckInputs in rotary_embedding_helper.
//
// position_ids_format == 0 means no/implicit position ids: the caches are
// assumed laid out contiguously per (batch, seq) position. Otherwise
// `position_ids` holds one int64 position per (batch, seq) which selects a
// row of the (max_position, rotary_dim/2) caches — assumed cache shape;
// verify against the helper's input checks.
template <typename T>
Status RunRotaryEmbedding(concurrency::ThreadPool* tp, RotaryParameters parameters, const T* input,
                          const int64_t* position_ids, const T* cos_cache, const T* sin_cache, T* output,
                          bool interleaved) {
  const int batch_size = parameters.batch_size;
  const int sequence_length = parameters.sequence_length;
  const int n_heads = parameters.num_heads;
  const int head_size = parameters.head_size;
  const int head_stride = parameters.head_stride;
  const int seq_stride = parameters.seq_stride;
  const int batch_stride = parameters.batch_stride;
  const int position_ids_format = parameters.position_ids_format;
  const int rotary_emb_dim = parameters.rotary_embedding_dim;
  const int half_rotary_emb_dim = rotary_emb_dim / 2;
  // Parallel to calculate based on head_size
  const int loop_len = batch_size * sequence_length * n_heads;
  // The cost is calculated as:
  // - head_size * sizeof(T) for reading input
  // - head_size * sizeof(T) for writing output
  // - rotary_emb_dim * 32 for the rotary embedding operations (32 is an approximation of the number of CPU cycles)
  const double cost = static_cast<double>(head_size * sizeof(T) * 2 + rotary_emb_dim * 32);
  ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
      // Decompose the flat work index into (batch b, sequence s, head n);
      // head varies fastest, batch slowest.
      const int b = static_cast<int>((ptr / n_heads) / sequence_length);
      const int s = static_cast<int>((ptr / n_heads) % sequence_length);
      const int n = static_cast<int>(ptr % n_heads);
      // Identify the index of batch, sequence, and head (specific range) in the input/output tensor
      // for read/write
      const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
      const T* input_data = input + block_offset;
      T* output_data = output + block_offset;

      const T* cos_data;
      const T* sin_data;
      int cache_offset;
      if (position_ids_format == 0) {
        // No explicit position ids: cache rows are indexed directly by the
        // flattened (batch, seq) position.
        cache_offset = (b * sequence_length + s) * half_rotary_emb_dim;
      } else {
        // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
        const int position_id = static_cast<int>(position_ids[b * sequence_length + s]);
        cache_offset = position_id * half_rotary_emb_dim;
      }
      cos_data = cos_cache + cache_offset;
      sin_data = sin_cache + cache_offset;

      // Rotate the first rotary_emb_dim elements of this head in one call.
      MlasRotaryEmbedOneRow<T>(input_data, sin_data, cos_data, rotary_emb_dim, interleaved, output_data);

      // Partial rotary embedding: pass the un-rotated tail through untouched.
      if (rotary_emb_dim < head_size) {
        std::memcpy(output_data + rotary_emb_dim,
                    input_data + rotary_emb_dim,
                    (head_size - rotary_emb_dim) * sizeof(T));
      }
    }
  });

  return Status::OK();
}

// Explicit instantiations for the two registered element types.
template Status RunRotaryEmbedding<float>(concurrency::ThreadPool* tp, RotaryParameters parameters, const float* input,
                                          const int64_t* position_ids, const float* cos_cache, const float* sin_cache, float* output,
                                          bool interleaved);

template Status RunRotaryEmbedding<MLFloat16>(concurrency::ThreadPool* tp, RotaryParameters parameters, const MLFloat16* input,
                                              const int64_t* position_ids, const MLFloat16* cos_cache, const MLFloat16* sin_cache,
                                              MLFloat16* output, bool interleaved);
| 105 | + |
| 106 | +template <typename T> |
| 107 | +Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const { |
| 108 | + const Tensor* X = context->Input<Tensor>(0); |
| 109 | + const Tensor* cos_cache = context->Input<Tensor>(1); |
| 110 | + const Tensor* sin_cache = context->Input<Tensor>(2); |
| 111 | + // Optional position_ids input, can be nullptr |
| 112 | + const Tensor* position_ids = context->Input<Tensor>(3); |
| 113 | + |
| 114 | + RotaryParameters parameters = {}; |
| 115 | + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(X, |
| 116 | + position_ids, |
| 117 | + cos_cache, |
| 118 | + sin_cache, |
| 119 | + num_heads, |
| 120 | + rotary_embedding_dim, |
| 121 | + ¶meters)); |
| 122 | + |
| 123 | + Tensor* output = context->Output(0, X->Shape()); |
| 124 | + |
| 125 | + const T* x_src = X->Data<T>(); |
| 126 | + const int64_t* pos_ids_data = (nullptr == position_ids) ? nullptr : position_ids->Data<int64_t>(); |
| 127 | + const T* cos_cache_data = cos_cache->Data<T>(); |
| 128 | + const T* sin_cache_data = sin_cache->Data<T>(); |
| 129 | + T* output_dest = output->MutableData<T>(); |
| 130 | + |
| 131 | + AllocatorPtr allocator; |
| 132 | + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); |
| 133 | + auto* tp = context->GetOperatorThreadPool(); |
| 134 | + |
| 135 | + return RunRotaryEmbedding<T>(tp, parameters, x_src, pos_ids_data, cos_cache_data, sin_cache_data, output_dest, |
| 136 | + interleaved); |
| 137 | +} |
| 138 | + |
| 139 | +} // namespace onnxruntime |