eigen MLFloat16
Your Name committed Sep 16, 2024
1 parent 0b60013 commit a5a6ae0
Showing 1 changed file with 91 additions and 6 deletions.
97 changes: 91 additions & 6 deletions onnxruntime/contrib_ops/cpu/bert/attention_utils.cc
@@ -10,6 +10,21 @@ using onnxruntime::concurrency::ThreadPool;
namespace onnxruntime {
namespace contrib {

namespace {
template <typename T>
struct EigenType;

template <>
struct EigenType<float> {
using Type = float;
};

template <>
struct EigenType<MLFloat16> {
using Type = Eigen::half;
};
}  // namespace
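
[Editor's aside — illustrative sketch, not part of this commit.] The trait above selects the element type Eigen computes in: float maps to itself, while MLFloat16 maps to Eigen::half, which shares the same 16-bit IEEE-754 layout, so a raw MLFloat16 buffer can be reinterpreted and processed with Eigen's half-precision arithmetic. A minimal self-contained version of the pattern (EigenTypeSketch and AddScalarInPlace are hypothetical names invented for this sketch; the real trait is the one defined above):

#include <Eigen/Core>
#include <cstddef>

// Hypothetical stand-in for the commit's trait; in the commit,
// EigenType<MLFloat16>::Type is Eigen::half.
template <typename T> struct EigenTypeSketch { using Type = T; };

template <typename T>
void AddScalarInPlace(T* data, std::size_t n, T scalar) {
  using E = typename EigenTypeSketch<T>::Type;
  // Map the raw buffer as Eigen's compute type and add the scalar element-wise.
  Eigen::Map<Eigen::Array<E, Eigen::Dynamic, 1>> v(
      reinterpret_cast<E*>(data), static_cast<Eigen::Index>(n));
  v += static_cast<E>(scalar);
}

int main() {
  float buf[4] = {1.f, 2.f, 3.f, 4.f};
  AddScalarInPlace(buf, 4, 0.5f);  // buf -> {1.5, 2.5, 3.5, 4.5}
  return 0;
}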

// Reshape Q/K/V from BxSxD to BxSxNxH
inline Status Reshape_BSD_to_BSNH(Tensor* qkv,
int batch_size,
@@ -48,13 +63,43 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
constexpr size_t element_size = sizeof(T);
using eigen_type = typename EigenType<T>::Type;
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_1_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput0<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput1<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + input_1_vec_map;
}}; // For element-wise add
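
[Editor's aside.] ProcessBroadcastSpanFuncs bundles three lambdas for the three shapes the broadcast loop can hand a binary op: scalar input0 + span input1, span input0 + scalar input1, and span + span. A simplified, self-contained dispatcher showing the same three cases (AddWithBroadcast and BroadcastCase are hypothetical names for this sketch; this is not ORT's broadcast machinery):

#include <Eigen/Core>

enum class BroadcastCase { kScalar0, kScalar1, kNone };

template <typename E>
void AddWithBroadcast(BroadcastCase c, const E* in0, const E* in1,
                      E* out, Eigen::Index n) {
  using Arr = Eigen::Array<E, Eigen::Dynamic, 1>;
  Eigen::Map<Arr> o(out, n);
  if (c == BroadcastCase::kScalar0) {        // scalar in0 broadcast over in1
    Eigen::Map<const Arr> b(in1, n);
    o = b + in0[0];
  } else if (c == BroadcastCase::kScalar1) { // scalar in1 broadcast over in0
    Eigen::Map<const Arr> a(in0, n);
    o = a + in1[0];
  } else {                                   // same-length element-wise add
    Eigen::Map<const Arr> a(in0, n), b(in1, n);
    o = a + b;
  }
}

int main() {
  const float x[3] = {1.f, 2.f, 3.f}, bias[1] = {10.f};
  float y[3];
  AddWithBroadcast(BroadcastCase::kScalar1, x, bias, y, 3);  // y = {11, 12, 13}
  return 0;
}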

// Allocate space for output of Q(BS, D) + bias(D)
@@ -114,6 +159,7 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
return Status::OK();
}
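
[Editor's aside.] Both Reshape_BSD_to_BSNH above and the reshape inside AddBiasReshape below rely on D == num_heads * head_size, which makes the (B, S, D) -> (B, S, N, H) reshape metadata-only: the row-major buffer never moves, because the flat offsets coincide. A tiny self-contained check of that index identity (illustrative, not ORT code):

#include <cassert>

int main() {
  const int B = 2, S = 3, N = 4, H = 5, D = N * H;
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int n = 0; n < N; ++n)
        for (int h = 0; h < H; ++h)
          // (b, s, n, h) in a BxSxNxH layout lands at the same flat offset
          // as (b, s, n*H + h) in a BxSxD layout when D == N*H.
          assert(((b * S + s) * N + n) * H + h ==
                 (b * S + s) * D + (n * H + h));
  return 0;
}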


// Add bias + reshape for each of Q/K/V
// This is used in decoder_with_past when the sequence length is 1
template <typename T>
@@ -129,16 +175,47 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is
OpKernelContext* context) {
// Note: the comments below will refer to Q's dimensions for simplicity
auto element_type = DataTypeImpl::GetType<T>();
using eigen_type = typename EigenType<T>::Type;
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_1_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput0<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + static_cast<eigen_type>(per_iter_bh.ScalarInput1<T>());
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
// per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
auto num_elements = per_iter_bh.NumOutputElements();

const auto* input_0 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput0<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_0_vec_map(input_0, num_elements);

const auto* input_1 = reinterpret_cast<const eigen_type*>(per_iter_bh.EigenInput1<T>().data());
ConstEigenVectorArrayMap<eigen_type> input_1_vec_map(input_1, num_elements);

auto* output = reinterpret_cast<eigen_type*>(per_iter_bh.OutputEigen<T>().data());
EigenVectorArrayMap<eigen_type> output_vec_map(output, num_elements);

output_vec_map = input_0_vec_map + input_1_vec_map;
}}; // For element-wise add

// Get Q's bias from combined bias
Expand Down Expand Up @@ -219,6 +296,10 @@ template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

template <typename T>
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +323,9 @@ template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);

template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);
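
[Editor's aside.] The `template Status ...<float>` and `...<MLFloat16>` lines here are explicit instantiations: the template definitions live in this .cc file, so each supported element type must be instantiated in this translation unit for callers elsewhere to link against. The pattern in miniature (Twice is a hypothetical example, not ORT code):

// twice.h — declaration only; the definition is not visible to includers.
template <typename T> T Twice(T v);

// twice.cc — definition plus explicit instantiations for supported types.
template <typename T> T Twice(T v) { return v + v; }
template float Twice<float>(float);    // emits Twice<float> from this TU
template double Twice<double>(double); // emits Twice<double> from this TU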

} // namespace contrib
} // namespace onnxruntime
