@@ -75,7 +75,7 @@ class GQAAttentionBase {
     int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

     // Compute the attention score.
-    size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
+    size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
     auto attention_probs = allocator->Alloc(bytes);
     BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));

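The only change in this hunk is `sizeof(T)` → `sizeof(float)`: the attention-probability scratch buffer is now always allocated in single precision, even when the operator's element type `T` is `MLFloat16`. A minimal sketch of the resulting size, under assumed (purely illustrative) decode-step shapes:

    #include <cstddef>

    // Hypothetical shapes for illustration only; the real values come from the op's inputs.
    constexpr std::size_t kBatch = 1, kNumHeads = 32, kSeqLen = 1, kPresentKvLen = 4096;
    // Probabilities are staged in fp32 regardless of T, hence sizeof(float).
    constexpr std::size_t kProbsBytes = kBatch * kNumHeads * kSeqLen * kPresentKvLen * sizeof(float);
    static_assert(kProbsBytes == 512 * 1024, "1x32x1x4096 fp32 probs take 512 KiB");
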
@@ -87,16 +87,17 @@ class GQAAttentionBase {
     bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;

     const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
-    ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
+    ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
                              sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
-                             present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp);
+                             present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

     // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
     const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
-    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
+    ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
+                            seqlens_k->Data<int32_t>(),
                             batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
                             hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
-                            is_prompt, tp);
+                            is_prompt, tp, allocator);

     return Status::OK();
   }
@@ -106,7 +107,7 @@ class GQAAttentionBase {
   // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
   // attention_probs(B, N, S, T) = Softmax(attention_probs)
   template <typename T>
-  void ComputeAttentionProbs(T* attention_probs,       // output buffer with size BxNxSxT
+  void ComputeAttentionProbs(float* attention_probs,   // output buffer with size BxNxSxT
                              const T* Q,               // Q data. Its size is BxNxSxH
                              const T* K,               // k data. Its size is BxNxLxH
                              const int32_t* seqlens_k, // total - 1 sequence lengths tensor
@@ -120,7 +121,8 @@ class GQAAttentionBase {
                              const bool past_present_share_buffer, // whether present key and value share the same buffer
                              const bool packed_qkv,                // whether Q, K, V are packed
                              const bool is_prompt,                 // whether it is prompt
-                             ThreadPool* tp) const {               // thread pool
+                             ThreadPool* tp,                       // thread pool
+                             AllocatorPtr allocator) const {       // allocator for temporary buffer
     const ptrdiff_t packed_batch_stride =
         packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
                    : SafeInt<ptrdiff_t>(0);
@@ -131,7 +133,9 @@ class GQAAttentionBase {
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

     if (!past_present_share_buffer) {
-      memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+      memset((void*)present_key,
+             0,
+             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }

     const size_t loop_len = batch_size * num_heads_;
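The `memset` itself is unchanged apart from the `(void*)` cast on the destination, presumably because `T` can now be `MLFloat16`, a class type, and zeroing a class type directly through `memset` trips GCC/Clang's `-Wclass-memaccess` warning. A minimal sketch of the pattern, using a hypothetical stand-in for the half-precision wrapper:

    #include <cstddef>
    #include <cstring>

    // Stand-in for a non-trivial 16-bit float wrapper (hypothetical, illustration only).
    struct HalfLike {
      unsigned short bits;
      HalfLike() : bits(0) {}
    };

    void ZeroPresentCache(HalfLike* present_key, std::size_t element_count) {
      // memset on a non-trivial class type can warn under -Wclass-memaccess;
      // casting the destination to void* states the intent and keeps the build quiet.
      std::memset(static_cast<void*>(present_key), 0, element_count * sizeof(HalfLike));
    }
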
@@ -164,7 +168,7 @@ class GQAAttentionBase {
         const size_t past_chunk_length = past_seqlen * head_size;

         const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
-        T* output = attention_probs + output_offset;
+        float* output = attention_probs + output_offset;

         const T* k;
         if (packed_qkv) {
@@ -190,12 +194,28 @@ class GQAAttentionBase {
           q = Q + q_input_chunk_length * i;
         }

-        math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
-                                    static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/, output,
-                                    static_cast<int>(present_buffer_sequence_length), nullptr);
+        if constexpr (std::is_same<T, float>::value) {
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
+                                          static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
+                                          output, static_cast<int>(present_buffer_sequence_length), nullptr);
+        } else {
+          size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
+          auto q_k_fp32 = allocator->Alloc(bytes);
+          BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));
+
+          float* q_fp32 = static_cast<float*>(q_k_fp32);
+          MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);
+
+          float* k_fp32 = q_fp32 + head_size * sequence_length;
+          MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen);
+
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32,
+                                          static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*beta*/,
+                                          output, static_cast<int>(present_buffer_sequence_length), nullptr);
+        }

         // compute Softmax
-        T* output_softmax = output;
+        float* output_softmax = output;
         for (size_t seq = 0; seq < sequence_length; seq++) {
           size_t seq_causal_length = past_seqlen + seq + 1;
           if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
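This is the core of the change for the Q·Kᵀ step: on the `MLFloat16` path, `q` and `k` are first converted into a single fp32 scratch allocation (presumably because there is no half-precision `GemmEx` path), and the single-precision GEMM then runs on the copies. Note that the scratch is allocated and freed inside the per-head loop body, so each task works on its own copy. A self-contained sketch of the convert-then-multiply pattern; the trivial `Half` struct and the plain loops below are stand-ins for ORT's `MLFloat16`, `MlasConvertHalfToFloatBuffer`, and `math::GemmEx`:

    #include <cstddef>
    #include <vector>

    struct Half { float stored; };  // stand-in: pretend this is a 16-bit float

    static void ConvertHalfToFloat(const Half* src, float* dst, std::size_t n) {
      for (std::size_t i = 0; i < n; ++i) dst[i] = src[i].stored;
    }

    // probs(S x T) = alpha * Q(S x H) * K(T x H)^T, accumulated in fp32 (beta == 0).
    void QKProbsFp32(const Half* q, const Half* k, float* probs,
                     std::size_t S, std::size_t T, std::size_t H, float alpha) {
      // One scratch block holds both converted operands, mirroring the diff's single
      // Alloc of head_size * (sequence_length + total_seqlen) floats.
      std::vector<float> scratch(H * (S + T));
      float* q_fp32 = scratch.data();
      ConvertHalfToFloat(q, q_fp32, H * S);
      float* k_fp32 = q_fp32 + H * S;
      ConvertHalfToFloat(k, k_fp32, H * T);

      for (std::size_t s = 0; s < S; ++s) {
        for (std::size_t t = 0; t < T; ++t) {
          float acc = 0.0f;
          for (std::size_t h = 0; h < H; ++h) acc += q_fp32[s * H + h] * k_fp32[t * H + h];
          probs[s * T + t] = alpha * acc;  // overwrite, as beta == 0 in the GemmEx call
        }
      }
    }
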
@@ -237,7 +257,7 @@ class GQAAttentionBase {

   template <typename T>
   void ComputeVxAttentionScore(T* output,                    // buffer for the result with size BxSxNxH
-                               const T* attention_probs,     // Attention probs with size BxNxSxT
+                               const float* attention_probs, // Attention probs with size BxNxSxT
                                const T* V,                   // V value with size BxN_kvxSxH
                                const int32_t* seqlens_k,     // total - 1 sequence lengths tensor
                                const size_t batch_size,      // batch size
251
271
const bool past_present_share_buffer, // whether present key and value share the same buffer
252
272
const bool packed_qkv, // whether Q, K, V are packed
253
273
const bool is_prompt, // whether it is prompt
254
- ThreadPool* tp) const {
274
+ ThreadPool* tp,
275
+ AllocatorPtr allocator) const {
255
276
const ptrdiff_t packed_batch_stride =
256
277
packed_qkv ? SafeInt<ptrdiff_t >(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
257
278
: SafeInt<ptrdiff_t >(0 );
@@ -261,7 +282,9 @@ class GQAAttentionBase {
     const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size;  // T x H

     if (!past_present_share_buffer) {
-      memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
+      memset((void*)present_value,
+             0,
+             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
     }

     const size_t loop_len = batch_size * num_heads_;
@@ -285,6 +308,13 @@ class GQAAttentionBase {
     unit_cost.bytes_loaded += bytes_to_copy_trans_all;
     unit_cost.bytes_stored += bytes_to_copy_trans_all;

+    size_t output_fp32_bytes = 0;
+    if constexpr (std::is_same<T, MLFloat16>::value) {
+      output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
+    }
+    auto output_fp32 = allocator->Alloc(output_fp32_bytes);
+    BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));
+
     ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
       for (std::ptrdiff_t i = begin; i != end; ++i) {
         const size_t batch_index = i / num_heads_;
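The fp32 output staging buffer is sized only when `T` is `MLFloat16`; for `float` the requested size stays 0 and the allocation goes unused. Allocating it once, outside `TryParallelFor`, lets every task write its own disjoint slice and defers the fp32→fp16 conversion to a single pass after the loop. A minimal sketch of that allocate-only-when-needed pattern, with `std::unique_ptr` standing in for ORT's `AllocatorPtr`/`BufferUniquePtr` (an assumption for illustration):

    #include <cstddef>
    #include <memory>
    #include <type_traits>

    template <typename T>
    std::unique_ptr<float[]> MakeFp32OutputStaging(std::size_t batch, std::size_t seq,
                                                   std::size_t heads, std::size_t head_size) {
      std::size_t count = 0;
      if constexpr (!std::is_same<T, float>::value) {
        // Only the half-precision path needs an fp32 copy of the whole BxSxNxH output.
        count = batch * seq * heads * head_size;
      }
      return count != 0 ? std::make_unique<float[]>(count) : nullptr;
    }
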
@@ -305,15 +335,39 @@ class GQAAttentionBase {
                          i / kv_num_heads_factor);
         }

-        T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
         ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

-        math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
-                                    attention_probs + attention_probs_offset,
-                                    static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
-                                    0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
+        if constexpr (std::is_same<T, float>::value) {
+          T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                          1.f, /*alpha*/ attention_probs + attention_probs_offset,
+                                          static_cast<int>(present_buffer_sequence_length), v,
+                                          static_cast<int>(head_size), 0.0f /*beta*/, output_current,
+                                          static_cast<int>(hidden_size), nullptr);
+        } else {
+          size_t bytes = head_size * total_seqlen * sizeof(float);
+          auto v_fp32 = allocator->Alloc(bytes);
+          BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));
+
+          float* v_fp32_ptr = static_cast<float*>(v_fp32);
+          MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);
+
+          float* output_fp32_current = static_cast<float*>(output_fp32) +
+                                       (batch_index * sequence_length * num_heads_ + head_index) * head_size;
+          math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
+                                          1.f, /*alpha*/ attention_probs + attention_probs_offset,
+                                          static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
+                                          static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
+                                          static_cast<int>(hidden_size), nullptr);
+        }
       }
     });
+
+    if constexpr (std::is_same<T, MLFloat16>::value) {
+      MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
+                                   output,
+                                   SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
+    }
   }
 };
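After the parallel loop, the whole fp32 staging buffer is converted back to half precision in one call to `MlasConvertFloatToHalfBuffer`. A compact sketch of the overall round trip for the output tensor, again with stand-ins (`Half`, the conversion loop) for `MLFloat16` and the MLAS converters:

    #include <cstddef>
    #include <functional>
    #include <vector>

    struct Half { float stored; };  // stand-in: pretend this is a 16-bit float

    // Run the fp32 attention kernels into a staging buffer, then narrow once at the end.
    void ComputeInFp32StoreHalf(Half* output, std::size_t total_elems,
                                const std::function<void(float*, std::size_t)>& compute_fp32) {
      std::vector<float> staging(total_elems);    // fp32 staging for the whole BxSxNxH output
      compute_fp32(staging.data(), total_elems);  // all per-head GEMMs write fp32 results here
      for (std::size_t i = 0; i < total_elems; ++i) {
        output[i].stored = staging[i];            // single pass, mirroring MlasConvertFloatToHalfBuffer
      }
    }

Keeping the GEMMs and the softmax in fp32 and converting only at the boundaries trades extra scratch memory and conversion passes for single-precision accumulation on the fp16 path, which appears to be the intent of this change.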