diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
index fdada83cc6582..87184bf8bb3cf 100644
--- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
+++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp
@@ -6,12 +6,20 @@
 #include "kai_ukernel_interface.h"
 #include "mlasi.h"
+#include "kleidiai/mlasi_kleidiai.h"

 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h"

+#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h"
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h"
+
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod =
     {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
      kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod,
@@ -64,6 +72,56 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp
      kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm,
      kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm};

+const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme =
+    {kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla,
+     kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x8vl_sme_mla};
+
+const kai_matmul_clamp_f32_f32_f32p_ukernel sgemm_gemv_sme2 =
+    {kai_get_m_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_n_step_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_nr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_kr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_sr_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_lhs_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_dst_offset_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_get_dst_size_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla,
+     kai_run_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla};
+
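+// Each table below binds one KleidiAI micro-kernel variant to the shared
+// function-pointer interface. The geometry getters (m_step, n_step, mr/nr/kr/sr)
+// and run_matmul must all come from the same variant so that packing layout and
+// execution agree.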
+const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme =
+    {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa,
+     kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa};
+
+const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 =
+    {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa,
+     kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa};
+
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() {
     if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) {
         return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm;
@@ -79,3 +137,19 @@ const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() {
         return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod;
     }
 }
+
+const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() {
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
+        return sgemm_gemm_sme2;
+    } else {
+        return sgemm_gemm_sme;
+    }
+}
+
+const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() {
+    if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) {
+        return sgemm_gemv_sme2;
+    } else {
+        return sgemm_gemv_sme;
+    }
+}
diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
index 1a6f111d1c794..e69c72329d64b 100644
--- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
+++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h
@@ -8,5 +8,12 @@

 #include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h"

+#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p_f32p_interface.h"
+
+#include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h"
+
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel();
 const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel();
+
+const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel();
+const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel();
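For orientation, the new dispatch helpers are consumed like the existing quantized ones: fetch the CPU-appropriate variant once, then drive everything through the table. A minimal caller-side sketch (the wrapper `RunGemvRow` and its argument layout are illustrative assumptions, not part of this patch; the `run_matmul` signature follows the call made in `MlasGemvBatch` in sgemm_kleidiai.cpp below, and `rhs_packed` is assumed to come from the matching SME RHS pack routine):

    #include <cstddef>
    #include <limits>
    #include "kai_ukernel_interface.h"

    // y(1 x N) = x(1 x K) * B(K x N), with B already packed for this variant.
    void RunGemvRow(size_t N, size_t K, const float* x, const void* rhs_packed, float* y) {
        const kai_matmul_clamp_f32_f32_f32p_ukernel& uk = GetKleidiAISGemvUKernel();
        uk.run_matmul(1, N, K,
                      x, K * sizeof(float),                // LHS: one contiguous K-length row
                      rhs_packed,
                      y,
                      N * sizeof(float),                   // dst row stride (bytes)
                      sizeof(float),                       // dst column stride (bytes)
                      -std::numeric_limits<float>::max(),  // clamping disabled
                      std::numeric_limits<float>::max());
    }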
diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
index 216eb35a9b6cc..ca81b9fa426ee 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
+++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h
@@ -77,6 +77,19 @@ MlasGemmPackB(
     void* PackedB
 );

+bool
+MLASCALL
+MlasGemvBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_SGEMM_DATA_PARAMS* Data,
+    size_t BatchSize
+);
+
 bool
 MLASCALL
 MlasGemmBatch(
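The declaration above mirrors `MlasGemmBatch`, so existing callers can try the GEMV route first and fall back on `false`. A caller-side sketch for the M == 1 case (the function `GemvRowExample` and its buffers are hypothetical; the `MLAS_SGEMM_DATA_PARAMS` field names are the standard MLAS ones):

    #include "mlas.h"
    #include "mlasi_kleidiai.h"

    // Sketch: c(1 x N) = a(1 x K) * B(K x N) through the new entry point.
    bool GemvRowExample(size_t N, size_t K, const float* a, const float* b, float* c) {
        MLAS_SGEMM_DATA_PARAMS params;
        params.A = a;            // 1 x K row vector
        params.lda = K;
        params.B = b;            // K x N matrix
        params.ldb = N;
        params.C = c;            // 1 x N output
        params.ldc = N;
        params.alpha = 1.0f;
        params.beta = 0.0f;
        params.BIsPacked = false;
        // false means the caller should fall back to the general GEMM path.
        return ArmKleidiAI::MlasGemvBatch(CblasNoTrans, CblasNoTrans,
                                          /*M*/ 1, N, K, &params, /*BatchSize*/ 1);
    }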
diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
index 848372d71e314..250b5d076475d 100644
--- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
+++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp
@@ -12,7 +12,9 @@
 #include "kai/ukernels/matmul/pack/kai_lhs_pack_f32p2vlx1_f32_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme.h"
 #include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme.h"
+#include "mlas.h"
 #include "mlasi_kleidiai.h"
+#include "kai_ukernel_interface.h"

 // Thread-local reusable buffers to reduce allocation overhead across tiles.
@@ -21,9 +23,214 @@ struct KaiTlsBuffers {
     std::vector<float> bias_zero;
     std::vector<std::byte> rhs_packed;
     std::vector<std::byte> lhs_packed;
+    std::vector<float> gemv_lhs_row_tmp;
 };
 static thread_local KaiTlsBuffers g_kai_tls;

+const kai_matmul_clamp_f32_f32p_f32p_ukernel& sgemm_gemm = GetKleidiAISGemmUKernel();
+const kai_matmul_clamp_f32_f32_f32p_ukernel& sgemm_gemv = GetKleidiAISGemvUKernel();
+
+// Helpers for GEMV
+/*++
+
+Routine Description:
+
+    Apply alpha/beta scaling to a 1-D vector with arbitrary destination stride.
+
+Arguments:
+
+    src - Pointer to the temporary A*B results (length num_elements).
+
+    num_elements - Number of elements.
+
+    alpha - Scale for the computed product (A*B).
+
+    beta - Scale for the existing C values.
+
+    dst - Pointer to the destination in C.
+
+    dst_stride - Stride, in elements, between successive outputs in C.
+
+    allow_memcpy - If true, allows the memcpy path when alpha==1, beta==0, and dst_stride==1.
+
+Notes:
+
+    Uses a memcpy path when alpha==1, beta==0, allow_memcpy is true, and dst_stride==1.
+
+--*/
+static inline void ApplyAlphaBetaStrided(const float* src, size_t num_elements, float alpha, float beta, float* dst, size_t dst_stride, bool allow_memcpy) {
+    if (alpha == 1.0f && beta == 0.0f && allow_memcpy && dst_stride == 1) {
+        std::memcpy(dst, src, num_elements * sizeof(float));
+        return;
+    }
+    for (size_t i = 0; i < num_elements; ++i) {
+        const float ab = src[i];
+        float& d = dst[i * dst_stride];
+        const float c_orig = d;
+        if (alpha == 1.0f && beta == 0.0f) {
+            d = ab;
+        } else if (alpha == 1.0f) {
+            d = ab + beta * c_orig;
+        } else if (beta == 0.0f) {
+            d = alpha * ab;
+        } else {
+            d = alpha * ab + beta * c_orig;
+        }
+    }
+}
+
+/*++
+
+Routine Description:
+
+    Apply alpha/beta scaling to a 2-D tile (rows x cols).
+
+Arguments:
+
+    src - Pointer to the temporary A*B results (row-major, rows x cols).
+
+    rows - Number of rows in the tile.
+
+    cols - Number of columns in the tile.
+
+    alpha - Scale for the computed product (A*B).
+
+    beta - Scale for the existing C values.
+
+    dst - Pointer to the destination tile in C (row-major with leading dimension ldc).
+
+    ldc - Leading dimension of C (in elements).
+
+Notes:
+
+    Uses a memcpy path when alpha==1, beta==0, ldc==cols, and rows/cols are non-zero.
+    Otherwise applies per-row scaling via ApplyAlphaBetaStrided.
+
+--*/
+static inline void ApplyAlphaBeta2D(const float* src, size_t rows, size_t cols,
+                                    float alpha, float beta,
+                                    float* dst, size_t ldc) {
+    if (alpha == 1.0f && beta == 0.0f && ldc == cols && rows != 0 && cols != 0) {
+        std::memcpy(dst, src, rows * cols * sizeof(float));
+        return;
+    }
+    for (size_t i = 0; i < rows; ++i) {
+        const float* src_row = src + i * cols;
+        float* dst_row = dst + i * ldc;
+        ApplyAlphaBetaStrided(src_row, cols, alpha, beta, dst_row, 1, /*allow_memcpy*/ (ldc == cols));
+    }
+}
+
+/*++
+
+Routine Description:
+
+    Execute GEMV using the SME/SME2 1xN microkernel for degenerate GEMM shapes:
+      - M == 1 (row-vector times matrix)
+      - N == 1 (matrix times column-vector)
+
+    N == 1 mapping (y = A(MxK) * b(Kx1)):
+      The 1xN microkernel computes a single LHS row against multiple RHS columns.
+      To reuse it for N == 1, we present A as the "RHS" by transpose-packing it
+      so that each of A's M rows becomes a "column" for the kernel:
+        - rhsBase := A, rhsShape := M, rhs_ld := lda, tb := CblasTrans
+        - lhsBase := B (the vector b), length K
+      The kernel expects the LHS vector to be a contiguous K-length row:
+        - If TransB == CblasNoTrans, b is stored as a Kx1 column with stride ldb.
+          We gather it into a thread-local contiguous buffer when ldb != 1.
+        - If TransB == CblasTrans, b is a 1xK row and is already contiguous.
+
+Unsupported:
+
+    When N == 1 and Data->BIsPacked is true (except M == N == 1), this path is
+    disabled because we need to pack A (as the RHS) and pass B as an unpacked vector.
+
+Post-processing:
+
+    The kernel produces M outputs into a temporary buffer. We apply alpha/beta
+    and write to C using ldc as the destination stride.
+
+Return Value:
+
+    true  - A GEMV path was executed (M == 1 or N == 1).
+    false - Fall back to the general GEMM path.
+
+--*/
+bool
+MLASCALL
+ArmKleidiAI::MlasGemvBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_SGEMM_DATA_PARAMS* Data,
+    size_t BatchSize
+)
+{
+    // Only two paths: M-path (M == 1, which also covers M == N == 1) or N-path (N == 1).
+    if (M != 1 && N != 1) {
+        return false;
+    }
+
+    const bool m_path = (M == 1);
+
+    // We cannot support cases where N == 1 and B is already packed.
+    // When both are 1, we route through the M-path, so this naturally doesn't trigger.
+    if (!m_path && Data->BIsPacked) {
+        return false;
+    }
+
+    // Decide the RHS and its transposition once based on the path.
+    CBLAS_TRANSPOSE tb = m_path ? TransB : CblasTrans;
+    size_t rhs_shape = m_path ? N : M;
+
+    for (size_t b = 0; b < BatchSize; ++b) {
+        size_t rhs_ld = m_path ? Data[b].ldb : Data[b].lda;
+        // LHS is the vector row we feed to the GEMV microkernel:
+        //   - M-path: LHS is A, stride = lda
+        //   - N-path: LHS is B, stride = ldb
+        size_t lhs_ld = m_path ? Data[b].lda : Data[b].ldb;
+
+        const float* rhs_base = m_path ? static_cast<const float*>(Data[b].B)
+                                       : static_cast<const float*>(Data[b].A);
+        const float* lhs_base = m_path ? static_cast<const float*>(Data[b].A)
+                                       : static_cast<const float*>(Data[b].B);
+
+        // Prepare the packed RHS if needed.
+        const void* rhs_packed_ptr = nullptr;
+
+        // This branch can only be taken when M == 1: a prepacked B with N == 1 was
+        // rejected above. When N == 1 we pack the A matrix as the RHS using
+        // tb = CblasTrans, after which rhs_packed_ptr points to packed A, not B;
+        // rhs_packed_ptr = Data[b].B only when M == 1.
+        if (Data[b].BIsPacked) {
+            rhs_packed_ptr = Data[b].B;
+        } else {
+            const size_t rhs_size = ArmKleidiAI::MlasGemmPackBSize(TransA, tb, rhs_shape, K);
+            if (rhs_size == 0) {
+                return false;
+            }
+            g_kai_tls.rhs_packed.resize(rhs_size);
+
+            ArmKleidiAI::MlasGemmPackB(
+                TransA, tb, rhs_shape, K,
+                rhs_base,
+                rhs_ld,
+                g_kai_tls.rhs_packed.data());
+            rhs_packed_ptr = g_kai_tls.rhs_packed.data();
+        }
+
+        // Ensure the LHS is a contiguous K-length row for the GEMV microkernel.
+        // Compute once whether we need to gather based on which side is the LHS.
+        const bool needs_gather = m_path ? (TransA == CblasTrans) : (TransB == CblasNoTrans);
+        if (needs_gather) {
+            g_kai_tls.gemv_lhs_row_tmp.resize(K);
+            for (size_t k = 0; k < K; ++k) {
+                g_kai_tls.gemv_lhs_row_tmp[k] = lhs_base[k * lhs_ld];
+            }
+            lhs_base = g_kai_tls.gemv_lhs_row_tmp.data();
+        }
+
+        // Temporary buffer for the output row.
+        g_kai_tls.output_tile.resize(rhs_shape);
+
+        // Run the specialized 1xN-by-K kernel.
+        sgemm_gemv.run_matmul(
+            1,                             // 1 for the M == 1 case; plays the role of N when N == 1
+            rhs_shape,                     // N for the M == 1 case; M when N == 1
+            K,                             // K
+            lhs_base,                      // lhs
+            K * sizeof(float),             // lhs stride (bytes)
+            rhs_packed_ptr,                // packed rhs
+            g_kai_tls.output_tile.data(),  // output
+            rhs_shape * sizeof(float),     // dst row stride (bytes)
+            sizeof(float),                 // dst col stride (bytes)
+            -std::numeric_limits<float>::max(),
+            std::numeric_limits<float>::max()
+        );
+
+        // Apply alpha/beta to the destination row of C.
+        bool allowMemCopy = m_path ? (Data[b].ldc == N) : (Data[b].ldc == 1);
+        size_t destStride = m_path ? 1 : Data[b].ldc;
+        ApplyAlphaBetaStrided(g_kai_tls.output_tile.data(), rhs_shape, Data[b].alpha, Data[b].beta, Data[b].C, destStride, allowMemCopy);
+    }
+    return true;
+}
+
 size_t
 MLASCALL
 ArmKleidiAI::MlasGemmPackBSize(
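To make the N == 1 remapping concrete, here is a tiny worked example (numbers chosen purely for illustration):

    // y = A(3x2) * b(2x1), i.e. M = 3, N = 1, K = 2:
    //   A = [[1, 2],        b = [ 10,
    //        [3, 4],              100]
    //        [5, 6]]
    // The N-path packs A with tb = CblasTrans (rhs_shape = M = 3), so each row
    // of A becomes one "column" for the 1xN kernel, and feeds b as the
    // contiguous K = 2 element "LHS row". The kernel's 1x3 output
    //   [1*10 + 2*100, 3*10 + 4*100, 5*10 + 6*100] = [210, 430, 610]
    // is exactly y transposed, so ApplyAlphaBetaStrided scatters it back into C
    // with stride ldc.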
@@ -125,12 +332,10 @@ Return Value:
     }

     if (TransA == CblasNoTrans) {
-        const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                                  : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
-        const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                                  : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
-        const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                                  : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+
+        const size_t nr = sgemm_gemm.get_nr();
+        const size_t kr = sgemm_gemm.get_kr();
+        const size_t sr = sgemm_gemm.get_sr();

         // Ensure size and zero the used span.
         g_kai_tls.bias_zero.resize(N, 0.0f);
@@ -174,7 +379,7 @@ ArmKleidiAI::MlasGemmBatch(

 Routine Description:

-    This routine performs a batched matrix multiplication (GEMM) operation using KleidiAI kernels.
+    This routine performs a batched matrix multiplication (GEMM or GEMV) operation using KleidiAI kernels.
     It handles both packed and unpacked inputs and manages tiling and kernel selection depending on SME2 availability.
     If packing is needed, it prepares the required buffers and invokes the appropriate left-hand side (LHS)
     and right-hand side (RHS) pack functions.
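Since the step and geometry queries now route through the dispatch table, the tiling logic downstream can be read as the following shape. A simplified single-threaded sketch (the function `TileWalkSketch` is illustrative; the real loop below distributes tiles over the thread pool and handles batches):

    #include <cstddef>
    #include "kai_ukernel_interface.h"

    // Simplified tile walk: per-tile packed offsets come from the same table
    // that later runs the tile, so layout and execution stay in sync.
    void TileWalkSketch(size_t M, size_t N, size_t K,
                        const std::byte* lhs_packed, const std::byte* rhs_packed) {
        const kai_matmul_clamp_f32_f32p_f32p_ukernel& uk = GetKleidiAISGemmUKernel();
        const size_t m_step = uk.get_m_step();  // rows of C produced per tile
        const size_t n_step = uk.get_n_step();  // cols of C produced per tile
        for (size_t m = 0; m < M; m += m_step) {
            for (size_t n = 0; n < N; n += n_step) {
                const void* a_tile = lhs_packed + uk.get_lhs_packed_offset(m, K);
                const void* b_tile = rhs_packed + uk.get_rhs_packed_offset(n, K);
                // uk.run_matmul(...) on this tile, then alpha/beta post-processing.
                (void)a_tile;
                (void)b_tile;
            }
        }
    }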
@@ -226,23 +431,27 @@ Return Value:
         return true;
     }

-    const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                              : kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
-    const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                              : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
-    const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                              : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    // Attempt GEMV (M == 1 or N == 1).
+    if (M == 1 || N == 1) {
+        // TODO: Investigate passing the threadpool through and multithreading the GEMV op.
+        if (ArmKleidiAI::MlasGemvBatch(TransA, TransB, M, N, K, Data, BatchSize)) {
+            return true;
+        }
+    }

-    size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                            : kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
-    size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
-                            : kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
+    size_t m_step = sgemm_gemm.get_m_step();
+    size_t n_step = sgemm_gemm.get_n_step();

     if ((M < m_step || N < n_step) && !Data->BIsPacked) {
         // Fallback to MLAS
         return false;
     }

+    const size_t mr = sgemm_gemm.get_mr();
+    const size_t kr = sgemm_gemm.get_kr();
+    const size_t sr = sgemm_gemm.get_sr();
+
     size_t LhsPackedStride = 0;
     std::byte* LhsPackedData = nullptr;

@@ -335,9 +544,7 @@ Return Value:
             ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

             // Get rhs tile, B
-            const size_t rhs_packed_offset =
-                UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
-                        : kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);
+            const size_t rhs_packed_offset = sgemm_gemm.get_rhs_packed_offset(NIdx * n_step, K);

             const std::byte* B_base = Data[0].BIsPacked
                                           ? reinterpret_cast<const std::byte*>(Data[BIdx].B)
                                           : RhsPackedData + RhsPackedStride * BIdx;
             auto BTile = reinterpret_cast<const float*>(B_base + rhs_packed_offset);

             // Get lhs tile, A
-            const size_t lhs_packed_offset =
-                UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
-                        : kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);
+            const size_t lhs_packed_offset = sgemm_gemm.get_lhs_packed_offset(MIdx * m_step, K);

             const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
             auto ATile = reinterpret_cast<const float*>(A_base + lhs_packed_offset);
@@ -368,31 +573,15 @@ Return Value:
             g_kai_tls.output_tile.resize(tile_elems);
             float* temp_tile = g_kai_tls.output_tile.data();
-            std::fill_n(temp_tile, tile_elems, 0.0f);
-
-            if (UseSME2) {
-                KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa"
-                                    << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
-                kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
-                    TileSizeM,
-                    TileSizeN,
-                    K,
-                    ATile, BTile, temp_tile,
-                    TileSizeN * sizeof(float), sizeof(float),
-                    -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-                );
-            } else {
-                KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa"
-                                    << " M=" << TileSizeM << " N=" << TileSizeN << " K=" << K);
-                kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
-                    TileSizeM,
-                    TileSizeN,
-                    K,
-                    ATile, BTile, temp_tile,
-                    TileSizeN * sizeof(float), sizeof(float),
-                    -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
-                );
-            }
+
+            sgemm_gemm.run_matmul(
+                TileSizeM,
+                TileSizeN,
+                K,
+                ATile, BTile, temp_tile,
+                TileSizeN * sizeof(float), sizeof(float),
+                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
+            );

             // Final output tile pointer
             float* dst_tile = reinterpret_cast<float*>(CTile);
@@ -417,25 +606,7 @@ Return Value:
             float beta = Data[BIdx].beta;
             size_t ldc = Data[BIdx].ldc;

-            for (size_t i = 0; i < TileSizeM; ++i) {
-                for (size_t j = 0; j < TileSizeN; ++j) {
-                    const size_t temp_idx = i * TileSizeN + j;
-                    const size_t dst_idx = i * ldc + j;
-
-                    float ab = temp_tile[temp_idx];
-                    float c_orig = dst_tile[dst_idx];
-
-                    if (alpha == 1.0f && beta == 0.0f) {
-                        dst_tile[dst_idx] = ab;
-                    } else if (alpha == 1.0f) {
-                        dst_tile[dst_idx] = ab + beta * c_orig;
-                    } else if (beta == 0.0f) {
-                        dst_tile[dst_idx] = alpha * ab;
-                    } else {
-                        dst_tile[dst_idx] = alpha * ab + beta * c_orig;
-                    }
-                }
-            }
+            ApplyAlphaBeta2D(temp_tile, TileSizeM, TileSizeN, alpha, beta, dst_tile, ldc);
             return;
         });
     return true;
diff --git a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h
index c832ca69dbb31..cc40ce9217489 100644
--- a/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h
+++ b/onnxruntime/test/mlas/unittest/test_fgemm_fixture.h
@@ -71,6 +71,19 @@ class FgemmShortExecuteTest : public MlasTestFixture
+    // GEMV tests: M == 1 or N == 1 with K > 1 to validate LHS gather handling.
+    test_registered += RegisterTestTransposeABProduct(1, 1, 16, 1, 1.0f, 0.0f);
+    test_registered += RegisterTestTransposeABProduct(1, 1, 31, 1, 1.0f, 0.0f);
+    test_registered += RegisterTestTransposeABProduct(1, 1, 64, 1, 1.0f, 0.0f);
+    test_registered += RegisterTestTransposeABProduct(1, 1, 16, 3, 1.0f, 0.0f);
+
     return test_registered;
   }
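As a quick way to exercise the new GEMV route end to end, a standalone cross-check along these lines can also be useful. This is a sketch, not part of the patch: the function `CheckGemvM1` and its tolerance are illustrative, and it assumes that on SME-capable hardware `MlasGemm` dispatches into the KleidiAI path while elsewhere it runs the generic MLAS kernels, so the check is meaningful either way:

    #include <cmath>
    #include <cstddef>
    #include <vector>
    #include "mlas.h"

    // Cross-check C(1 x N) = A(1 x K) * B(K x N) against a naive reference.
    bool CheckGemvM1(size_t N, size_t K) {
        std::vector<float> a(K, 1.0f), b(K * N, 0.5f), c(N, 0.0f), ref(N, 0.0f);
        MlasGemm(CblasNoTrans, CblasNoTrans, /*M*/ 1, N, K,
                 /*alpha*/ 1.0f, a.data(), /*lda*/ K, b.data(), /*ldb*/ N,
                 /*beta*/ 0.0f, c.data(), /*ldc*/ N, /*ThreadPool*/ nullptr);
        for (size_t n = 0; n < N; ++n) {
            for (size_t k = 0; k < K; ++k) {
                ref[n] += a[k] * b[k * N + n];
            }
        }
        for (size_t n = 0; n < N; ++n) {
            if (std::fabs(c[n] - ref[n]) > 1e-4f) {
                return false;
            }
        }
        return true;
    }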