Commit

fix tests
Your Name committed Sep 24, 2024
1 parent a9a07e1 commit 3c55948
Showing 4 changed files with 141 additions and 191 deletions.
77 changes: 0 additions & 77 deletions onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -10,21 +10,6 @@ using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {

namespace {
template <typename T>
struct EigenType;

template <>
struct EigenType<float> {
using Type = float;
};

template <>
struct EigenType<MLFloat16> {
using Type = Eigen::half;
};
}

// Reshape Q/K/V from BxSxD to BxSxNxH
inline Status Reshape_BSD_to_BSNH(Tensor* qkv,
int batch_size,
@@ -64,42 +49,12 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
// auto num_elements = per_iter_bh.NumOutputElements();

// const auto* input_1 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput1<T>().data());
// ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_1_vec_map(input_1, num_elements);

// auto* output = reinterpret_cast<typename EigenType<T>::Type*>(per_iter_bh.OutputEigen<T>().data());
// EigenVectorArrayMap<typename EigenType<T>::Type> output_vec_map(output, num_elements);

// output_vec_map = input_1_vec_map + static_cast<typename EigenType<T>::Type>(per_iter_bh.ScalarInput0<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
// auto num_elements = per_iter_bh.NumOutputElements();

// const auto* input_0 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput0<T>().data());
// ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_0_vec_map(input_0, num_elements);

// auto* output = reinterpret_cast<typename EigenType<T>::Type*>(per_iter_bh.OutputEigen<T>().data());
// EigenVectorArrayMap<typename EigenType<T>::Type> output_vec_map(output, num_elements);

// output_vec_map = input_0_vec_map + static_cast<typename EigenType<T>::Type>(per_iter_bh.ScalarInput1<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
// auto num_elements = per_iter_bh.NumOutputElements();

// const auto* input_0 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput0<T>().data());
// ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_0_vec_map(input_0, num_elements);

// const auto* input_1 = reinterpret_cast<const typename EigenType<T>::Type*>(per_iter_bh.EigenInput1<T>().data());
// ConstEigenVectorArrayMap<typename EigenType<T>::Type> input_1_vec_map(input_1, num_elements);

// auto* output = reinterpret_cast<typename EigenType<T>::Type*>(per_iter_bh.OutputEigen<T>().data());
// EigenVectorArrayMap<typename EigenType<T>::Type> output_vec_map(output, num_elements);

// output_vec_map = input_0_vec_map + input_1_vec_map;
}}; // For element-wise add

// Allocate space for output of Q(BS, D) + bias(D)
@@ -159,7 +114,6 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
return Status::OK();
}


// Add bias + reshape for each of Q/K/V
// This is used in decoder_with_past when the sequence length is 1
template <typename T>
@@ -175,47 +129,16 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
//using eigen_type = typename EigenType<T>::Type;
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
// auto num_elements = per_iter_bh.NumOutputElements();

// const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
// ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

// auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
// EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

// output_vec_map = input_1_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput0<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
// auto num_elements = per_iter_bh.NumOutputElements();

// const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
// ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

// auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
// EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

// output_vec_map = input_0_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput1<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
// auto num_elements = per_iter_bh.NumOutputElements();

// const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
// ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

// const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
// ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

// auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
// EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

// output_vec_map = input_0_vec_map + input_1_vec_map;
}}; // For element-wise add

// Get Q's bias from combined bias
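The add_funcs triple in this file follows ONNX Runtime's element-wise broadcast pattern: one functor handles a scalar broadcast against the second input, one handles the first input broadcast against a scalar, and one handles the plain element-wise case. The standalone sketch below illustrates those three cases with plain Eigen arrays only; it does not use the real BroadcastHelper API, and the variable names are illustrative.

    #include <Eigen/Core>
    #include <iostream>

    // Standalone sketch of the three broadcast cases covered by add_funcs
    // (scalar + tensor, tensor + scalar, tensor + tensor). Plain Eigen arrays
    // are used for illustration; the real code goes through ORT's BroadcastHelper.
    int main() {
      Eigen::ArrayXf bias(4);
      bias << 1.f, 2.f, 3.f, 4.f;
      Eigen::ArrayXf q(4);
      q << 10.f, 20.f, 30.f, 40.f;
      const float scalar = 5.f;

      Eigen::ArrayXf case0 = bias + scalar;  // ScalarInput0 broadcast against EigenInput1
      Eigen::ArrayXf case1 = q + scalar;     // EigenInput0 broadcast against ScalarInput1
      Eigen::ArrayXf case2 = q + bias;       // element-wise EigenInput0 + EigenInput1

      std::cout << case2.transpose() << "\n";  // 11 22 33 44
      return 0;
    }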
63 changes: 41 additions & 22 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@@ -93,7 +93,8 @@ class GQAAttentionBase {

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v, seqlens_k->Data<int32_t>(),
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp, allocator);
@@ -132,7 +133,9 @@ class GQAAttentionBase {
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset((void*)present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
      memset((void*)present_key,
             0,
             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));

Check warning on line 136 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h — [cpplint] reported by reviewdog 🐶: Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]
}

const size_t loop_len = batch_size * num_heads_;
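As an aside, the cast style the cpplint warning asks for would look roughly like the sketch below. This is a hypothetical rewrite, not part of the commit; the helper name and the buffer passed in are stand-ins.

    #include <cstddef>
    #include <cstring>

    // Hypothetical helper showing the cpplint-preferred cast:
    // reinterpret_cast<void*> instead of a C-style (void*) cast.
    template <typename T>
    void ZeroPresentBuffer(T* present_buffer, size_t element_count) {
      memset(reinterpret_cast<void*>(present_buffer), 0, element_count * sizeof(T));
    }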
@@ -193,8 +196,8 @@

if constexpr (std::is_same<T, float>::value) {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
                                        static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/, output,
static_cast<int>(present_buffer_sequence_length), nullptr);
                                        static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*beta*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
@@ -254,7 +257,7 @@ class GQAAttentionBase {

template <typename T>
void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
const float* attention_probs, // Attention probs with size BxNxSxT
const float* attention_probs, // Attention probs with size BxNxSxT
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const size_t batch_size, // batch size
@@ -279,7 +282,9 @@
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset((void*)present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
      memset((void*)present_value,
             0,
             batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));

Check warning on line 285 in onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h — [cpplint] reported by reviewdog 🐶: Using C-style cast. Use reinterpret_cast<void*>(...) instead [readability/casting] [4]
}

const size_t loop_len = batch_size * num_heads_;
@@ -303,6 +308,13 @@
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;

size_t output_fp32_bytes = 0;
if constexpr (std::is_same<T, MLFloat16>::value) {
output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
}
auto output_fp32 = allocator->Alloc(output_fp32_bytes);
BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const size_t batch_index = i / num_heads_;
@@ -323,32 +335,39 @@
i / kv_num_heads_factor);
}

T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

if constexpr (std::is_same<T, float>::value) {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
1.f, /*alpha*/ attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v,
static_cast<int>(head_size), 0.0f /*beta*/, output_current,
static_cast<int>(hidden_size), nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto v_o_fp32 = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(v_o_fp32, BufferDeleter(allocator));
size_t bytes = head_size * total_seqlen * sizeof(float);
auto v_fp32 = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));

float* v_fp32_ptr = static_cast<float*>(v_o_fp32);
float* v_fp32_ptr = static_cast<float*>(v_fp32);
MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);

float* output_fp32 = v_fp32_ptr + head_size * total_seqlen;
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v_fp32_ptr, static_cast<int>(head_size),
0.0f /*beta*/, output_fp32, static_cast<int>(hidden_size), nullptr);

MlasConvertFloatToHalfBuffer(output_fp32, output_current, head_size * sequence_length);
float* output_fp32_current = static_cast<float*>(output_fp32) +
(batch_index * sequence_length * num_heads_ + head_index) * head_size;
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
1.f, /*alpha*/ attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
static_cast<int>(hidden_size), nullptr);
}
}
});

if constexpr (std::is_same<T, MLFloat16>::value) {
MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
output,
SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
}
}
};

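The fp16 branch above is restructured so that a single fp32 scratch buffer covering the whole BxSxNxH output is allocated before the parallel loop, each head's GEMM writes into its offset within that scratch, and one MlasConvertFloatToHalfBuffer call runs after the loop instead of a per-head conversion inside it. The sketch below renders that flow in plain floats with a naive GEMM, purely to show the buffer layout and where the conversions sit; it omits GQA head sharing (kv_num_heads < num_heads), treats the present KV buffer length as equal to total_seqlen, and is not the real kernel.

    #include <cstddef>
    #include <vector>

    // Illustrative rendering of the restructured ComputeVxAttentionScore fp16 path:
    // one fp32 output scratch for the whole tensor, per-head GEMMs into BSNH offsets,
    // and a single float->half conversion afterwards (done by MlasConvertFloatToHalfBuffer
    // in the real code).
    void VxAttentionScoreSketch(const std::vector<float>& attention_probs,  // (B*N) x S x T
                                const std::vector<float>& v,                // (B*N) x T x H
                                std::vector<float>& output_fp32,            // B x S x N x H
                                size_t batch_size, size_t num_heads,
                                size_t seq_len, size_t total_seqlen, size_t head_size) {
      const size_t hidden_size = num_heads * head_size;
      output_fp32.assign(batch_size * seq_len * hidden_size, 0.f);  // allocated before the loop

      for (size_t i = 0; i < batch_size * num_heads; ++i) {  // TryParallelFor in the real code
        const size_t b = i / num_heads;
        const size_t n = i % num_heads;
        const float* probs = attention_probs.data() + i * seq_len * total_seqlen;
        const float* v_head = v.data() + i * total_seqlen * head_size;
        // Output rows for this head are strided by hidden_size (BSNH layout),
        // matching output_fp32_current in the diff.
        float* out = output_fp32.data() + (b * seq_len * num_heads + n) * head_size;

        // out(S, H) = probs(S, T) x V(T, H), naive GEMM standing in for math::GemmEx.
        for (size_t s = 0; s < seq_len; ++s) {
          for (size_t h = 0; h < head_size; ++h) {
            float acc = 0.f;
            for (size_t t = 0; t < total_seqlen; ++t) {
              acc += probs[s * total_seqlen + t] * v_head[t * head_size + h];
            }
            out[s * hidden_size + h] = acc;
          }
        }
      }
      // Real code: MlasConvertFloatToHalfBuffer(output_fp32, output, B*S*N*H) runs here, once.
    }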
22 changes: 11 additions & 11 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -22,17 +22,17 @@ namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GroupQueryAttention, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()), \
GroupQueryAttention<T>);
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GroupQueryAttention, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
KernelDefBuilder() \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()), \
GroupQueryAttention<T>);

REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
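For readers skimming the macro, REGISTER_KERNEL_TYPED(float) above expands (roughly, modulo whitespace) to a single typed kernel registration:

    // Approximate expansion of REGISTER_KERNEL_TYPED(float) from the macro above.
    ONNX_OPERATOR_TYPED_KERNEL_EX(
        GroupQueryAttention,
        kMSDomain,
        1,
        float,
        kCpuExecutionProvider,
        KernelDefBuilder()
            .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
            .TypeConstraint("M", DataTypeImpl::GetTensorType<int32_t>()),
        GroupQueryAttention<float>);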
