
Commit 6cc06ad

wangyems and Your Name authored
GQA MLFloat16 cpu (#22102)
Co-authored-by: Your Name <[email protected]>
1 parent 5fa4505 commit 6cc06ad

File tree

7 files changed: +229 -135 lines

docs/OperatorKernels.md (+2 -2)
@@ -482,7 +482,7 @@ Do not modify directly.*
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
 |GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
-|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
 |MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -508,7 +508,7 @@ Do not modify directly.*
 |QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
 |QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
-|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
 |SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
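These two rows are the user-visible change: the CPU kernels for GroupQueryAttention and RotaryEmbedding now also register tensor(float16). As a rough illustration of what that enables, the sketch below binds fp16 inputs to a hypothetical single-node com.microsoft.RotaryEmbedding model; the file name, input names, and shapes are illustrative, not taken from this PR.

```cpp
#include <onnxruntime_cxx_api.h>

#include <vector>

int main() {
  constexpr int64_t batch = 1, seq = 8, hidden = 64, rot_half = 16;  // hypothetical dims

  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "rotary_fp16");
  Ort::SessionOptions opts;
  Ort::Session session(env, "rotary_fp16.onnx", opts);  // hypothetical model; path type is ORTCHAR_T (wchar_t on Windows)

  auto mem = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  // Zero-filled fp16 buffers; Ort::Float16_t maps to ONNX tensor(float16).
  std::vector<Ort::Float16_t> input(batch * seq * hidden), cos_cache(seq * rot_half), sin_cache(seq * rot_half);
  std::vector<int64_t> position_ids(batch * seq);  // M = tensor(int64) per the table row

  std::vector<int64_t> in_shape{batch, seq, hidden}, pos_shape{batch, seq}, cache_shape{seq, rot_half};

  std::vector<Ort::Value> inputs;
  inputs.push_back(Ort::Value::CreateTensor<Ort::Float16_t>(mem, input.data(), input.size(),
                                                            in_shape.data(), in_shape.size()));
  inputs.push_back(Ort::Value::CreateTensor<int64_t>(mem, position_ids.data(), position_ids.size(),
                                                     pos_shape.data(), pos_shape.size()));
  inputs.push_back(Ort::Value::CreateTensor<Ort::Float16_t>(mem, cos_cache.data(), cos_cache.size(),
                                                            cache_shape.data(), cache_shape.size()));
  inputs.push_back(Ort::Value::CreateTensor<Ort::Float16_t>(mem, sin_cache.data(), sin_cache.size(),
                                                            cache_shape.data(), cache_shape.size()));

  const char* in_names[] = {"input", "position_ids", "cos_cache", "sin_cache"};
  const char* out_names[] = {"output"};
  auto outputs = session.Run(Ort::RunOptions{nullptr}, in_names, inputs.data(), inputs.size(),
                             out_names, 1);
  return outputs.size() == 1 ? 0 : 1;
}
```

Ort::Float16_t carries raw IEEE fp16 bits; as the kernel hunks below show, the CPU implementation widens those inputs to fp32 internally rather than computing in half precision.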

onnxruntime/contrib_ops/cpu/bert/attention_utils.cc (+14 -6)
@@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv,  // Input: Q/K/V dat
   constexpr size_t element_size = sizeof(T);
   ProcessBroadcastSpanFuncs add_funcs{
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
+        per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
+        per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
+        per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
       }};  // For element-wise add

   // Allocate space for output of Q(BS, D) + bias(D)
@@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv,  // Input: Q/K/V data - query is
   constexpr size_t element_size = sizeof(T);
   ProcessBroadcastSpanFuncs add_funcs{
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
+        per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
+        per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
       },
       [](BroadcastHelper& per_iter_bh) {
-        per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
+        per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
       }};  // For element-wise add

   // Get Q's bias from combined bias
@@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context,
                                                       int batch_size, int num_heads, int sequence_length, int head_size,
                                                       const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

+template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator,
+                                                          int batch_size, int num_heads, int sequence_length, int head_size,
+                                                          const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);
+
 template <typename T>
 Status MaybeTransposeToBNSH(AllocatorPtr allocator,
                             int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator,
                                             int batch_size, int num_heads, int sequence_length, int head_size,
                                             const Tensor* in, OrtValue& out);

+template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator,
+                                                int batch_size, int num_heads, int sequence_length, int head_size,
+                                                const Tensor* in, OrtValue& out);
+
 }  // namespace contrib
 }  // namespace onnxruntime
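The last two hunks exist because these helper templates are defined in a .cc file: a caller instantiating them with MLFloat16 would otherwise hit a link-time undefined symbol. A minimal sketch of the explicit-instantiation pattern, with generic names standing in for the ORT helpers:

```cpp
// sum.h -- callers see only the declaration; the template body lives in sum.cc.
#include <cstddef>

template <typename T>
T Sum(const T* data, std::size_t n);

// sum.cc -- definition plus explicit instantiations. Each type used by
// callers needs its own `template ...` line here, which is why the hunks
// above add MLFloat16 instantiations next to the existing float ones.
template <typename T>
T Sum(const T* data, std::size_t n) {
  T acc{};
  for (std::size_t i = 0; i < n; ++i) acc = acc + data[i];
  return acc;
}

template float Sum<float>(const float*, std::size_t);
template double Sum<double>(const double*, std::size_t);  // stand-in for the MLFloat16 case
```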

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (+75 -21)
@@ -75,7 +75,7 @@ class GQAAttentionBase {
     int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

     // Compute the attention score.
-    size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
+    size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
     auto attention_probs = allocator->Alloc(bytes);
     BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));
@@ -87,16 +87,17 @@ class GQAAttentionBase {
     bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;

     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+    ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
                              sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
-                             present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp);
+                             present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
+                            seqlens_k->Data<int32_t>(),
                             batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
                             hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
-                            is_prompt, tp);
+                            is_prompt, tp, allocator);

     return Status::OK();
   }
@@ -106,7 +107,7 @@ class GQAAttentionBase {
   // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
   // attention_probs(B, N, S, T) = Softmax(attention_probs)
   template <typename T>
-  void ComputeAttentionProbs(T* attention_probs,  // output buffer with size BxNxSxT
+  void ComputeAttentionProbs(float* attention_probs,  // output buffer with size BxNxSxT
                              const T* Q,  // Q data. Its size is BxNxSxH
                              const T* K,  // k data. Its size is BxNxLxH
                              const int32_t* seqlens_k,  // total - 1 sequence lengths tensor
@@ -120,7 +121,8 @@ class GQAAttentionBase {
                             const bool past_present_share_buffer,  // whether present key and value share the same buffer
                             const bool packed_qkv,  // whether Q, K, V are packed
                             const bool is_prompt,  // whether it is prompt
-                            ThreadPool* tp) const {  // thread pool
+                            ThreadPool* tp,  // thread pool
+                            AllocatorPtr allocator) const {  // allocator for temporary buffer
     const ptrdiff_t packed_batch_stride =
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
@@ -131,7 +133,9 @@ class GQAAttentionBase {
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

     if (!past_present_share_buffer) {
-      memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+      memset((void*)present_key,
+             0,
+             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }

     const size_t loop_len = batch_size * num_heads_;
@@ -164,7 +168,7 @@ class GQAAttentionBase {
         const size_t past_chunk_length = past_seqlen * head_size;

         const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
-        T* output = attention_probs + output_offset;
+        float* output = attention_probs + output_offset;

         const T* k;
         if (packed_qkv) {
@@ -190,12 +194,28 @@ class GQAAttentionBase {
           q = Q + q_input_chunk_length * i;
         }

-        math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
-                                    static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/, output,
-                                    static_cast<int>(present_buffer_sequence_length), nullptr);
+        if constexpr (std::is_same<T, float>::value) {
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+                                          static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
+                                          output, static_cast<int>(present_buffer_sequence_length), nullptr);
+        } else {
+          size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+          auto q_k_fp32 = allocator->Alloc(bytes);
+          BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
+
+          float* q_fp32 = static_cast<float*>(q_k_fp32);
+          MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
+
+          float* k_fp32 = q_fp32 + head_size * sequence_length;
+          MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen);
+
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32,
+                                          static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*beta*/,
+                                          output, static_cast<int>(present_buffer_sequence_length), nullptr);
+        }

         // compute Softmax
-        T* output_softmax = output;
+        float* output_softmax = output;
         for (size_t seq = 0; seq < sequence_length; seq++) {
           size_t seq_causal_length = past_seqlen + seq + 1;
           if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
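With the scores held in fp32, the softmax pass that follows runs on float regardless of T. For reference, a numerically stable row softmax of the kind applied to each causal row looks roughly like this (an illustrative sketch, not the kernel's exact loop, which also handles the local-window and masked-tail cases):

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// In-place stable softmax over one row of fp32 attention scores.
void SoftmaxInplace(float* row, std::size_t n) {
  const float max_val = *std::max_element(row, row + n);  // subtract max to avoid overflow in exp
  float sum = 0.0f;
  for (std::size_t i = 0; i < n; ++i) {
    row[i] = std::exp(row[i] - max_val);
    sum += row[i];
  }
  const float inv_sum = 1.0f / sum;
  for (std::size_t i = 0; i < n; ++i) row[i] *= inv_sum;
}
```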
@@ -237,7 +257,7 @@ class GQAAttentionBase {

   template <typename T>
   void ComputeVxAttentionScore(T* output,  // buffer for the result with size BxSxNxH
-                               const T* attention_probs,  // Attention probs with size BxNxSxT
+                               const float* attention_probs,  // Attention probs with size BxNxSxT
                                const T* V,  // V value with size BxN_kvxSxH
                                const int32_t* seqlens_k,  // total - 1 sequence lengths tensor
                                const size_t batch_size,  // batch size
@@ -251,7 +271,8 @@ class GQAAttentionBase {
                                const bool past_present_share_buffer,  // whether present key and value share the same buffer
                                const bool packed_qkv,  // whether Q, K, V are packed
                                const bool is_prompt,  // whether it is prompt
-                               ThreadPool* tp) const {
+                               ThreadPool* tp,
+                               AllocatorPtr allocator) const {
     const ptrdiff_t packed_batch_stride =
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
@@ -261,7 +282,9 @@ class GQAAttentionBase {
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

     if (!past_present_share_buffer) {
-      memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+      memset((void*)present_value,
+             0,
+             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }

     const size_t loop_len = batch_size * num_heads_;
@@ -285,6 +308,13 @@ class GQAAttentionBase {
     unit_cost.bytes_loaded += bytes_to_copy_trans_all;
     unit_cost.bytes_stored += bytes_to_copy_trans_all;

+    size_t output_fp32_bytes = 0;
+    if constexpr (std::is_same<T, MLFloat16>::value) {
+      output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
+    }
+    auto output_fp32 = allocator->Alloc(output_fp32_bytes);
+    BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));
+
     ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
@@ -305,15 +335,39 @@ class GQAAttentionBase {
                            i / kv_num_heads_factor);
         }

-        T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
         ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

-        math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
-                                    attention_probs + attention_probs_offset,
-                                    static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
-                                    0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
+        if constexpr (std::is_same<T, float>::value) {
+          T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                          1.f, /*alpha*/ attention_probs + attention_probs_offset,
+                                          static_cast<int>(present_buffer_sequence_length), v,
+                                          static_cast<int>(head_size), 0.0f /*beta*/, output_current,
+                                          static_cast<int>(hidden_size), nullptr);
+        } else {
+          size_t bytes = head_size * total_seqlen * sizeof(float);
+          auto v_fp32 = allocator->Alloc(bytes);
+          BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
+
+          float* v_fp32_ptr = static_cast<float*>(v_fp32);
+          MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);
+
+          float* output_fp32_current = static_cast<float*>(output_fp32) +
+                                       (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                          1.f, /*alpha*/ attention_probs + attention_probs_offset,
+                                          static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
+                                          static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
+                                          static_cast<int>(hidden_size), nullptr);
+        }
       }
     });
+
+    if constexpr (std::is_same<T, MLFloat16>::value) {
+      MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
+                                   output,
+                                   SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
+    }
   }
 };
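Every MLFloat16 branch in this header follows the same widen-compute-narrow pattern: expand fp16 operands into fp32 scratch with MlasConvertHalfToFloatBuffer, run math::GemmEx<float, ThreadPool>, and shrink the result back with MlasConvertFloatToHalfBuffer, so both GEMMs and the softmax accumulate in fp32. The standalone sketch below mirrors that pattern with scalar conversions and a naive GEMM standing in for the MLAS routines; round-half-up and subnormal flushing in FloatToHalf are simplifications, not the library's exact rounding.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

// Scalar IEEE fp16 -> fp32 (stands in for the vectorized MlasConvertHalfToFloatBuffer).
static float HalfToFloat(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
  uint32_t exp = (h >> 10) & 0x1Fu;
  uint32_t mant = h & 0x3FFu;
  uint32_t bits;
  if (exp == 0) {
    if (mant == 0) {
      bits = sign;  // signed zero
    } else {  // fp16 subnormal: renormalize into the fp32 format
      uint32_t k = 0;
      while ((mant & 0x400u) == 0) { mant <<= 1; ++k; }
      bits = sign | ((113u - k) << 23) | ((mant & 0x3FFu) << 13);
    }
  } else if (exp == 31) {
    bits = sign | 0x7F800000u | (mant << 13);  // inf / NaN
  } else {
    bits = sign | ((exp + 112u) << 23) | (mant << 13);  // normal: rebias 15 -> 127
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Scalar fp32 -> fp16 (round-half-up; flushes fp16 subnormals to zero for brevity).
static uint16_t FloatToHalf(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  uint16_t sign = static_cast<uint16_t>((bits >> 16) & 0x8000u);
  int32_t exp = static_cast<int32_t>((bits >> 23) & 0xFFu) - 127 + 15;
  uint32_t mant = bits & 0x7FFFFFu;
  if (exp >= 31) {  // overflow, inf, or NaN
    uint16_t nan_bit = (((bits >> 23) & 0xFFu) == 0xFFu && mant != 0) ? 0x200u : 0u;
    return static_cast<uint16_t>(sign | 0x7C00u | nan_bit);
  }
  if (exp <= 0) return sign;  // too small for a normal fp16
  uint16_t h = static_cast<uint16_t>(sign | (exp << 10) | (mant >> 13));
  if (mant & 0x1000u) ++h;  // round up on the highest dropped bit; carry into the exponent is safe
  return h;
}

// C = A(MxK) * B(KxN): fp16 storage, fp32 scratch and accumulation,
// mirroring the MLFloat16 branches above (naive GEMM stands in for math::GemmEx).
void GemmF16ViaF32(const uint16_t* A, const uint16_t* B, uint16_t* C,
                   size_t M, size_t N, size_t K) {
  std::vector<float> a(M * K), b(K * N), c(M * N, 0.0f);  // fp32 scratch buffers
  for (size_t i = 0; i < M * K; ++i) a[i] = HalfToFloat(A[i]);  // widen
  for (size_t i = 0; i < K * N; ++i) b[i] = HalfToFloat(B[i]);
  for (size_t m = 0; m < M; ++m)  // compute in fp32
    for (size_t k = 0; k < K; ++k)
      for (size_t n = 0; n < N; ++n)
        c[m * N + n] += a[m * K + k] * b[k * N + n];
  for (size_t i = 0; i < M * N; ++i) C[i] = FloatToHalf(c[i]);  // narrow
}

int main() {
  const uint16_t one = 0x3C00;  // 1.0 in fp16
  std::vector<uint16_t> A(4, one), B(4, one), C(4);
  GemmF16ViaF32(A.data(), B.data(), C.data(), 2, 2, 2);
  return C[0] == 0x4000 ? 0 : 1;  // 1*1 + 1*1 = 2.0 -> 0x4000
}
```

Keeping attention_probs in fp32 from Q·Kᵀ through softmax to the probs×V GEMM also avoids an intermediate fp16 round trip, which is where most of the accuracy would otherwise be lost.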
