Commits (64 total; changes shown from 37)
2a2ce5e
Trial
hariharans29 Jan 23, 2026
e7d9b70
Trial 2
hariharans29 Jan 24, 2026
b117de5
Trial 3
hariharans29 Jan 24, 2026
1fb9ceb
Trial 4
hariharans29 Jan 24, 2026
2519272
Trial 4 fix
hariharans29 Jan 24, 2026
29e3740
Trial 5
hariharans29 Jan 24, 2026
bc6528a
Update onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc
hariharans29 Jan 24, 2026
b9b9773
Update onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h
hariharans29 Jan 24, 2026
007f3be
Update onnxruntime/contrib_ops/cpu/bert/attention.cc
hariharans29 Jan 24, 2026
c13c293
Update onnxruntime/contrib_ops/cpu/bert/attention_base.h
hariharans29 Jan 24, 2026
6ebb592
Update onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
hariharans29 Jan 24, 2026
615b31e
Update onnxruntime/contrib_ops/cpu/word_conv_embedding.h
hariharans29 Jan 24, 2026
f17ef4d
Update onnxruntime/contrib_ops/cpu/transformers/generation_device_hel…
hariharans29 Jan 24, 2026
8c15dac
Update onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h
hariharans29 Jan 24, 2026
e0240c2
Update onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
hariharans29 Jan 24, 2026
990f185
Update onnxruntime/contrib_ops/cpu/word_conv_embedding.h
hariharans29 Jan 24, 2026
d03c491
Update onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
hariharans29 Jan 24, 2026
a475a66
Copilot comments
hariharans29 Jan 24, 2026
c2d0d2a
Update onnxruntime/contrib_ops/cpu/bert/attention.cc
hariharans29 Jan 24, 2026
14994d9
Silence some build failures
hariharans29 Jan 24, 2026
93bd45a
Build failures
hariharans29 Jan 24, 2026
840766d
Fix builds
hariharans29 Jan 24, 2026
e0affbb
Fix builds
hariharans29 Jan 24, 2026
9329b0d
Fix builds
hariharans29 Jan 24, 2026
bb4e57b
Update onnxruntime/contrib_ops/cpu/cdist.h
hariharans29 Jan 24, 2026
f32c0a9
Update onnxruntime/contrib_ops/cpu/cdist.h
hariharans29 Jan 24, 2026
fda3cb0
Fix builds
hariharans29 Jan 24, 2026
9579426
Merge branch 'hari/kleidiai_opt_out' of https://github.com/microsoft/…
hariharans29 Jan 24, 2026
e04d660
Fix builds
hariharans29 Jan 24, 2026
ba87f69
Fix builds
hariharans29 Jan 24, 2026
fca4503
Fix CUDA builds
hariharans29 Jan 24, 2026
7bd8b6f
Fix cuda builds
hariharans29 Jan 25, 2026
d37b146
Fix builds
hariharans29 Jan 25, 2026
b68029c
Fix builds
hariharans29 Jan 25, 2026
e398270
Fix builds
hariharans29 Jan 25, 2026
1a8b523
Fix builds
hariharans29 Jan 25, 2026
ed1dc8a
Fix builds
hariharans29 Jan 25, 2026
03a3eb2
Update onnxruntime/contrib_ops/cpu/nchwc_ops.h
hariharans29 Jan 25, 2026
b325dc9
Update onnxruntime/contrib_ops/cpu/transformers/generation_device_hel…
hariharans29 Jan 25, 2026
57d20a0
Update onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
hariharans29 Jan 25, 2026
0a98bb4
Update onnxruntime/contrib_ops/cpu/transformers/generation_device_hel…
hariharans29 Jan 25, 2026
1071268
Fix builds
hariharans29 Jan 25, 2026
961c6bb
Merge branch 'hari/kleidiai_opt_out' of https://github.com/microsoft/…
hariharans29 Jan 25, 2026
243bb02
Update onnxruntime/contrib_ops/cpu/transformers/generation_device_hel…
hariharans29 Jan 25, 2026
d9e6e6d
Update onnxruntime/contrib_ops/cpu/transformers/generation_device_hel…
hariharans29 Jan 25, 2026
d3a862e
Fix builds
hariharans29 Jan 25, 2026
8002f01
Merge branch 'hari/kleidiai_opt_out' of https://github.com/microsoft/…
hariharans29 Jan 25, 2026
14277c1
Copilot comments
hariharans29 Jan 25, 2026
db7d4a8
Fix builds
hariharans29 Jan 25, 2026
5c9f593
Fix some TODOs
hariharans29 Jan 26, 2026
1dc6525
Merge remote-tracking branch 'origin' into hari/kleidiai_opt_out
hariharans29 Jan 26, 2026
3bcfaa1
Format changes
hariharans29 Jan 26, 2026
c6cc01b
More fixes
hariharans29 Jan 26, 2026
e4438e0
More fixes
hariharans29 Jan 26, 2026
69c0931
More fixes
hariharans29 Jan 26, 2026
1c0606e
Fix training builds
hariharans29 Jan 26, 2026
3b0b3ea
More fixes
hariharans29 Jan 27, 2026
485ce6e
Plumb through kernel selector logic to MLAS from QNBitGemm + Copilot …
hariharans29 Jan 31, 2026
fa823e8
More QNBitGemm fixes
hariharans29 Jan 31, 2026
3d2dc08
More fixes
hariharans29 Jan 31, 2026
f75d9f6
More fixes
hariharans29 Jan 31, 2026
66cbce8
Test fixes
hariharans29 Jan 31, 2026
035bca2
Fix
hariharans29 Jan 31, 2026
bc43948
Fix
hariharans29 Jan 31, 2026
@@ -380,6 +380,12 @@ static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas
// - "1": Use LUT based GEMM when available.
static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm";

// Controls whether MLAS may use KleidiAI kernels when they are available.
// Option values:
// - "0": Use KleidiAI kernels when available. [DEFAULT]
// - "1": Disable KleidiAI kernels even if available.
static const char* const kOrtSessionOptionsMlasDisableKleidiai = "mlas.disable_kleidiai";

// When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option.
// Refer to MatMulNBits op schema for more details.
// If not provided, default is 4.
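For reference, the new option is set like any other session config entry. Below is a minimal sketch using the ONNX Runtime C++ API; the key string comes from the diff above, while the model path is a placeholder:

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // "1" disables KleidiAI kernels in MLAS even when they are available;
  // "0" (or leaving the entry unset) keeps the default behavior.
  session_options.AddConfigEntry("mlas.disable_kleidiai", "1");
  Ort::Session session(env, "model.onnx", session_options);  // placeholder model path
  return 0;
}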
10 changes: 6 additions & 4 deletions onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.cc
@@ -19,7 +19,8 @@ template <typename T>
AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger& logger,
int batch_size, int attn_context_depth, int attn_layer_depth,
int inner_cell_hidden_size, bool has_attn_layer,
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool)
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config)
: allocator_(alloc),
logger_(logger),
batch_size_(batch_size),
@@ -28,7 +29,8 @@ AttentionWrapper<T>::AttentionWrapper(AllocatorPtr alloc, const logging::Logger&
inner_cell_hidden_size_(inner_cell_hidden_size),
has_attn_layer_(has_attn_layer),
attention_mechanism_(attention_mechanism),
ttp_(threadpool) {
ttp_(threadpool),
mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config) {
auto mem_max_steps = attention_mechanism_.GetMaxMemorySteps();
prev_alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, prev_alignments_ptr_, true);
alignments_ = Allocate(allocator_, batch_size_ * mem_max_steps, alignments_ptr_, true);
@@ -45,7 +47,7 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
batch_size_, attn_layer_depth_, inner_cell_hidden_size_, T{1.0},
rnn_cell_output.data(), inner_cell_hidden_size_,
attn_layer_cell_weights_.data(), attn_layer_depth_, T{0.0},
attn_states_.data(), attn_layer_depth_, ttp_);
attn_states_.data(), attn_layer_depth_, ttp_, mlas_backend_kernel_selector_config_);
}

// Get the context which is calculated within attention mechanism.
@@ -62,7 +64,7 @@ void AttentionWrapper<T>::ProcessOutput(const gsl::span<const T>& rnn_cell_outpu
batch_size_, attn_layer_depth_, attn_context_depth_, T{1.0},
attn_context_.data(), attn_context_depth_,
attn_layer_attn_weights_.data(), attn_layer_depth_, T{1.0},
attn_states_.data(), attn_layer_depth_, ttp_);
attn_states_.data(), attn_layer_depth_, ttp_, mlas_backend_kernel_selector_config_);
}
}

6 changes: 5 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/attention_wrapper.h
@@ -9,6 +9,7 @@
#include "core/common/logging/logging.h"
#include "core/framework/allocator.h"
#include "core/platform/threadpool.h"
#include "core/mlas/inc/mlas.h"

namespace onnxruntime {
namespace contrib {
@@ -23,7 +24,8 @@ class AttentionWrapper {
int attn_layer_depth,
int inner_cell_hidden_size,
bool has_attn_layer,
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool);
const IAttentionMechanism<T>& attention_mechanism, concurrency::ThreadPool* threadpool,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config);

virtual ~AttentionWrapper() = default;

@@ -71,6 +73,8 @@

const IAttentionMechanism<T>& attention_mechanism_;
concurrency::ThreadPool* ttp_;

const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config_;
};

} // namespace contrib
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.cc
@@ -19,8 +19,9 @@ namespace contrib {
template <typename T>
BahdanauAttention<T>::BahdanauAttention(AllocatorPtr allocator, const logging::Logger& logger,
int batch_size, int max_memory_step, int memory_depth,
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool)
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool) {
int query_depth, int attn_depth, bool normalize, concurrency::ThreadPool* threadpool,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config)
: allocator_(allocator), logger_(logger), batch_size_(batch_size), max_memory_steps_(max_memory_step), memory_depth_(memory_depth), query_depth_(query_depth), attn_depth_(attn_depth), normalize_(normalize), ttp_(threadpool), mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config) {
values_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * memory_depth_, values_ptr_, true);
keys_ = Allocate(allocator_, batch_size_ * max_memory_steps_ * attn_depth_, keys_ptr_, true);
processed_query_ = Allocate(allocator_, batch_size_ * attn_depth_, processed_query_ptr_, true);
@@ -80,7 +81,7 @@ void BahdanauAttention<T>::PrepareMemory(
batch_size_ * max_memory_steps_, attn_depth_, memory_depth_, T{1.0},
memory.data(), memory_depth_,
memory_layer_weights_.data(), attn_depth_, T{0.0},
keys_.data(), attn_depth_, ttp_);
keys_.data(), attn_depth_, ttp_, mlas_backend_kernel_selector_config_);
}

template <typename T>
@@ -123,7 +124,7 @@ void BahdanauAttention<T>::Compute(
batch_size_, attn_depth_, query_depth_, T{1.0},
queries.data(), query_depth_,
query_layer_weights_.data(), attn_depth_, T{0.0},
processed_query_.data(), attn_depth_, ttp_);
processed_query_.data(), attn_depth_, ttp_, mlas_backend_kernel_selector_config_);

std::fill(aligns.begin(), aligns.end(), T{});

@@ -154,7 +155,7 @@ void BahdanauAttention<T>::Compute(
1, memory_depth_, max_memory_steps_, T{1.0},
alignments, max_memory_steps_,
values.data(), memory_depth_, T{0.0},
outspan.data(), memory_depth_, ttp_);
outspan.data(), memory_depth_, ttp_, mlas_backend_kernel_selector_config_);
}
}

7 changes: 6 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/bahdanau_attention.h
@@ -8,6 +8,8 @@

#include "attention_mechanism.h"

#include "core/mlas/inc/mlas.h"

namespace onnxruntime {
namespace contrib {

@@ -23,7 +25,8 @@ class BahdanauAttention : public IAttentionMechanism<T> {
int memory_depth,
int query_depth,
int attn_depth,
bool normalize, concurrency::ThreadPool* threadpool);
bool normalize, concurrency::ThreadPool* threadpool,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config);

void SetWeights(
const gsl::span<const T>& attn_weights,
@@ -78,6 +81,8 @@ class BahdanauAttention : public IAttentionMechanism<T> {

bool normalize_;
concurrency::ThreadPool* ttp_;

const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config_;
};

} // namespace contrib
18 changes: 9 additions & 9 deletions onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.cc
@@ -248,7 +248,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
memory_depth,
query_depth,
am_attn_size,
false, thread_pool);
false, thread_pool, &mlas_backend_kernel_selector_config_);

fam.SetWeights(
FirstHalfSpan(am_v_weights.DataAsSpan<T>()),
@@ -264,7 +264,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
attn_layer_depth,
hidden_size_,
has_attention_layer,
fam, thread_pool);
fam, thread_pool, &mlas_backend_kernel_selector_config_);
faw.SetWeights(FirstHalfSpan(attn_layer_weights_span));

UniDirectionalAttnLstm<T> fw(
@@ -275,7 +275,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, thread_pool);
clip_, thread_pool, &mlas_backend_kernel_selector_config_);

BahdanauAttention<T> bam(
alloc,
Expand All @@ -285,7 +285,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
memory_depth,
query_depth,
am_attn_size,
false, thread_pool);
false, thread_pool, &mlas_backend_kernel_selector_config_);
bam.SetWeights(
SecondHalfSpan(am_v_weights.DataAsSpan<T>()),
SecondHalfSpan(am_query_layer_weights.DataAsSpan<T>()),
@@ -300,7 +300,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
attn_layer_depth,
hidden_size_,
has_attention_layer,
bam, thread_pool);
bam, thread_pool, &mlas_backend_kernel_selector_config_);
baw.SetWeights(SecondHalfSpan(attn_layer_weights_span));

UniDirectionalAttnLstm<T> bw(
@@ -311,7 +311,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[3],
activation_funcs_.Entries()[4],
activation_funcs_.Entries()[5],
clip_, thread_pool);
clip_, thread_pool, &mlas_backend_kernel_selector_config_);

fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2, last_cell_2);
@@ -325,7 +325,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
memory_depth,
query_depth,
am_attn_size,
false, thread_pool);
false, thread_pool, &mlas_backend_kernel_selector_config_);

fam.SetWeights(
am_v_weights.DataAsSpan<T>(),
@@ -341,7 +341,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
attn_layer_depth,
hidden_size_,
has_attention_layer,
fam, thread_pool);
fam, thread_pool, &mlas_backend_kernel_selector_config_);

faw.SetWeights(attn_layer_weights_span);

@@ -353,7 +353,7 @@ Status DeepCpuAttnLstmOp::ComputeImpl(OpKernelContext& context) const {
activation_funcs_.Entries()[0],
activation_funcs_.Entries()[1],
activation_funcs_.Entries()[2],
clip_, thread_pool);
clip_, thread_pool, &mlas_backend_kernel_selector_config_);

fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1, hidden_output_1, last_cell_1);
}
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cpu/attnlstm/deep_cpu_attn_lstm.h
@@ -10,6 +10,8 @@
#include "core/common/narrow.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/rnn/rnn_helpers.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

namespace onnxruntime {
namespace contrib {
@@ -58,6 +60,9 @@ class DeepCpuAttnLstmOp final : public OpKernel {
activation_funcs_ = ActivationFuncs(activation_func_names,
activation_func_alphas,
activation_func_betas);

mlas_backend_kernel_selector_config_.use_kleidiai =
info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasDisableKleidiai) != "1";
}

Status Compute(OpKernelContext* context) const override;
@@ -92,6 +97,8 @@ class DeepCpuAttnLstmOp final : public OpKernel {
bool input_forget_ = false;

ActivationFuncs activation_funcs_;

MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
};

} // namespace contrib
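The header above shows the pattern this PR applies across CPU kernels: read the session option once at kernel construction, cache it in a MLAS_BACKEND_KERNEL_SELECTOR_CONFIG member, and pass a pointer to that member into every downstream MLAS/GEMM call. A minimal sketch of the same pattern for a hypothetical kernel follows; only the config key, the struct name, and the GetConfigEntry expression are taken from this diff, and the kernel itself is illustrative:

#include "core/framework/op_kernel.h"
#include "core/mlas/inc/mlas.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

namespace onnxruntime {
namespace contrib {

// Hypothetical kernel illustrating the KleidiAI opt-out plumbing.
class MyCpuKernel final : public OpKernel {
 public:
  explicit MyCpuKernel(const OpKernelInfo& info) : OpKernel(info) {
    // "1" disables KleidiAI; any other value (or no entry) keeps the default.
    mlas_backend_kernel_selector_config_.use_kleidiai =
        info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasDisableKleidiai) != "1";
  }

  Status Compute(OpKernelContext* /*context*/) const override {
    // Pass &mlas_backend_kernel_selector_config_ into MLAS helpers here,
    // e.g. MlasGemmPackBSize(..., &mlas_backend_kernel_selector_config_).
    return Status::OK();
  }

 private:
  MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;
};

}  // namespace contrib
}  // namespace onnxruntime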
12 changes: 7 additions & 5 deletions onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.cc
@@ -51,7 +51,8 @@ UniDirectionalAttnLstm<T>::UniDirectionalAttnLstm(AllocatorPtr allocator,
const ActivationFuncs::Entry& activation_func_g,
const ActivationFuncs::Entry& activation_func_h,
const float clip,
onnxruntime::concurrency::ThreadPool* ttp)
onnxruntime::concurrency::ThreadPool* ttp,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config)
: allocator_(allocator),
logger_(logger),
seq_length_(seq_length),
@@ -64,7 +65,8 @@ UniDirectionalAttnLstm<T>::UniDirectionalAttnLstm(AllocatorPtr allocator,
use_bias_(!bias.empty()),
use_peepholes_(!peephole_weights.empty()),
attention_wrapper_(attention_wrapper),
ttp_(ttp) {
ttp_(ttp),
mlas_backend_kernel_selector_config_(mlas_backend_kernel_selector_config) {
activation_f_ = {deepcpu::ActivationFuncByName(activation_func_f.name),
activation_func_f.alpha,
activation_func_f.beta};
@@ -260,7 +262,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.begin(), input_weights.end(), // W[iofc]^T
input_size_ + attention_size_, T{0.0},
output_iofc_.begin(), output_iofc_.end(),
hidden_size_x4, ttp_);
hidden_size_x4, ttp_, mlas_backend_kernel_selector_config_);

DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);

@@ -298,7 +300,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
input_weights.begin() + input_size_, input_weights.end(), // WA[iofc]
input_size_ + attention_size_, T{1.0},
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, ttp_);
hidden_size_x4, ttp_, mlas_backend_kernel_selector_config_);

// calculate Xt*(W[iofc]^T) + Ht-1*R[iofc]
ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, T{1.0},
Expand All @@ -307,7 +309,7 @@ void UniDirectionalAttnLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
recurrent_weights.begin(), recurrent_weights.end(), // R[iofc]
hidden_size_, T{1.0},
step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, ttp_);
hidden_size_x4, ttp_, mlas_backend_kernel_selector_config_);

span_T_iter batched_output, batched_output_end;
if (output_sequence) {
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cpu/attnlstm/uni_dir_attn_lstm.h
@@ -51,7 +51,8 @@ class UniDirectionalAttnLstm {
const ActivationFuncs::Entry& activation_func_g,
const ActivationFuncs::Entry& activation_func_h,
const float clip,
onnxruntime::concurrency::ThreadPool* ttp);
onnxruntime::concurrency::ThreadPool* ttp,
const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config);

void Compute(const gsl::span<const T>& inputs,
const gsl::span<const int>& sequence_lengths,
@@ -152,6 +153,8 @@ class UniDirectionalAttnLstm {
AttentionWrapper<T>& attention_wrapper_;

onnxruntime::concurrency::ThreadPool* ttp_;

const MLAS_BACKEND_KERNEL_SELECTOR_CONFIG* mlas_backend_kernel_selector_config_;
};

} // namespace detail
37 changes: 19 additions & 18 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -71,7 +71,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
const T* weights_data,
size_t weight_matrix_col_size,
/*out*/ PrePackedWeights* prepacked_weights) {
size_t packb_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size);
size_t packb_size = MlasGemmPackBSize(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, &mlas_backend_kernel_selector_config_);
if (packb_size == 0) {
return false;
}
@@ -87,7 +87,7 @@ bool Attention<T>::IsPackWeightsSuccessful(int qkv_index,
memset(packed_weights_data, 0, packed_weights_data_size);

for (size_t i = 0; i < loop_len; i++) {
MlasGemmPackB(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data);
MlasGemmPackB(CblasNoTrans, CblasNoTrans, head_size, input_hidden_size, weights_data, weight_matrix_col_size, packed_weights_data, &mlas_backend_kernel_selector_config_);
packed_weights_data += packb_size;
weights_data += head_size;
}
@@ -310,24 +310,25 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr); // use single-thread
nullptr, // use single-thread
&mlas_backend_kernel_selector_config_); // BackendKernelSelectorConfig
} else {
math::GemmEx<float, ThreadPool>(
CblasNoTrans, // TransA = no
CblasNoTrans, // TransB = no
sequence_length, // M = S
head_size, // N = H
input_hidden_size, // K = D
1.0f, // alpha
input_data + input_offset, // A
input_hidden_size, // lda = D
weights_data + weights_offset, // B
qkv_hidden_size, // ldb = D + D + D_v
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr // use single-thread
);
CblasNoTrans, // TransA = no
CblasNoTrans, // TransB = no
sequence_length, // M = S
head_size, // N = H
input_hidden_size, // K = D
1.0f, // alpha
input_data + input_offset, // A
input_hidden_size, // lda = D
weights_data + weights_offset, // B
qkv_hidden_size, // ldb = D + D + D_v
1.0f, // beta
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr, // use single-thread
&mlas_backend_kernel_selector_config_); // BackendKernelSelectorConfig
}
}
});
3 changes: 3 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -8,6 +8,7 @@
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/bert/attention_common.h"
#include "contrib_ops/cpu/bert/attention_parameters.h"
#include "core/mlas/inc/mlas.h"

namespace onnxruntime {
namespace contrib {
@@ -32,6 +33,8 @@ class AttentionBase {
int& past_sequence_length) const;

protected:
MLAS_BACKEND_KERNEL_SELECTOR_CONFIG mlas_backend_kernel_selector_config_;

AttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);