diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index d7dcde945e6d7..59abea26e4f60 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -327,6 +327,12 @@ function (setup_arm_neon_nchwc)
     ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h
     ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp
     ${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp
+    # Hand-written AArch64 micro-kernels for the NCHW convolution path. Using
+    # separate assembly files allows tighter control over register allocation
+    # and avoids the overhead of C++/intrinsics-based code generation.
+    ${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S
+    ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
+    ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
   )
   list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC)
   set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE)
diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S
new file mode 100644
index 0000000000000..6521c50eb40a2
--- /dev/null
+++ b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S
@@ -0,0 +1,238 @@
+/*++
+SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates
+SPDX-License-Identifier: MIT
+
+Module Name:
+
+    SconvDepthwiseKernelNeon.S
+
+Abstract:
+
+    Optimized AArch64 assembly implementation of the depthwise convolution
+    micro-kernel used by the NCHWc single precision path.
+
+    This kernel applies the following optimisations:
+      * Provide a fast path for interior output positions where all input
+        accesses are guaranteed to be in-bounds and can be loaded with a
+        single four-register ld1.
+      * When an output position touches padding, each 4-wide vector is checked
+        individually and loaded only when fully in-bounds; the rest are zeroed.
+        This mirrors the behavior of the C++ helper LoadInputVectorWithBounds.
+      * Keep the multiply/accumulate operations tightly scheduled to hide the
+        load latency.
+
+    The kernel produces a complete output row (left padding, interior and
+    right padding) for a 16-channel block on each invocation by the high
+    level dispatch code.
+
+--*/
+
+#include "asmmacro.h"
+
+        .text
+
+// Offsets for stack based parameters. AArch64 passes the first eight
+// arguments in registers (x0-x7). The remaining parameters are read from the
+// stack directly. The layout is defined by the C compiler so use constant
+// offsets here.
+
+        .equ    .Ldw_InputBase, 0
+        .equ    .Ldw_InputWidth, 8
+        .equ    .Ldw_DilatedInputWidth, 16
+        .equ    .Ldw_OutputCountLeftPad, 24
+        .equ    .Ldw_OutputCount, 32
+        .equ    .Ldw_OutputCountRightPad, 40
+        .equ    .Ldw_Bias, 48
+        .equ    .Ldw_Flags, 56
+
+// Prototype
+//
+// void
+// MlasConvDepthwiseFloatKernelNeonAsm(
+//     const float* Input,              // x0
+//     const float* Filter,             // x1
+//     float* Output,                   // x2
+//     size_t StrideWidth,              // x3 (bytes)
+//     size_t DilationWidth,            // x4 (bytes)
+//     size_t InputStride,              // x5 (unused)
+//     size_t KernelHeight,             // x6
+//     size_t KernelWidth,              // x7
+//     const float* InputBase,          // [sp + 0]
+//     size_t InputWidth,               // [sp + 8] (bytes)
+//     size_t DilatedInputWidth,        // [sp + 16] (bytes)
+//     size_t OutputCountLeftPad,       // [sp + 24]
+//     size_t OutputCount,              // [sp + 32]
+//     size_t OutputCountRightPad,      // [sp + 40]
+//     const float* Bias,               // [sp + 48]
+//     unsigned KernelFlags);           // [sp + 56]
+//
+
+        FUNCTION_ENTRY MlasConvDepthwiseFloatKernelNeonAsm
+
+        // Load the stack parameters used in the hot loops.
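+        // Among these, the .Ldw_Flags word uses the same bit assignments as
+        // the .LPWFlag_* equates in SconvPointwiseKernelNeon.S; as a C-style
+        // sketch (the enumerator names here are purely descriptive):
+        //
+        //   enum : unsigned {
+        //       KernelFlagAccumulateOutput = 1u << 0,   // tested via tbz #0
+        //       KernelFlagBiasAddition     = 1u << 1,   // tested via tbz #1
+        //       KernelFlagReluActivation   = 1u << 2,   // tested via tbz #2
+        //   };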
+        ldr     x8,[sp,#.Ldw_InputBase]          // base of valid input row
+        ldr     x9,[sp,#.Ldw_InputWidth]         // row width in bytes
+        ldr     x10,[sp,#.Ldw_DilatedInputWidth] // stride between rows
+        ldr     x11,[sp,#.Ldw_OutputCountLeftPad]
+        ldr     x12,[sp,#.Ldw_OutputCount]
+        ldr     x13,[sp,#.Ldw_OutputCountRightPad]
+        ldr     x14,[sp,#.Ldw_Bias]
+        ldr     w15,[sp,#.Ldw_Flags]
+
+        // Preserve callee-saved registers used by this routine.
+        stp     x29,x30,[sp,#-16]!
+        stp     x27,x28,[sp,#-16]!
+        stp     x25,x26,[sp,#-16]!
+        stp     x23,x24,[sp,#-16]!
+        stp     x21,x22,[sp,#-16]!
+        stp     x19,x20,[sp,#-16]!
+        stp     d12,d13,[sp,#-16]!
+        stp     d14,d15,[sp,#-16]!
+
+        // Compute total number of output elements to produce.
+        add     x16,x11,x12
+        add     x16,x16,x13
+
+        // Load bias vectors when required; otherwise all zeros are used to
+        // initialize the accumulators.
+        eor     v20.16b,v20.16b,v20.16b
+        eor     v21.16b,v21.16b,v21.16b
+        eor     v22.16b,v22.16b,v22.16b
+        eor     v23.16b,v23.16b,v23.16b
+        tbz     w15,#1,1f                        // skip when no bias is supplied
+        ldp     q20,q21,[x14],#32
+        ldp     q22,q23,[x14]
+1:
+        // Constant zero used by ReLU handling.
+        eor     v24.16b,v24.16b,v24.16b
+
+        mov     x17,#0                           // output index
+
+// ---------------------------------------------------------------------------
+// Loop over output elements. Each iteration computes one output position for
+// 16 channels.
+// ---------------------------------------------------------------------------
+.Ldw_OutputLoop:
+        // Start accumulators from bias or zeros.
+        mov     v0.16b,v20.16b
+        mov     v1.16b,v21.16b
+        mov     v2.16b,v22.16b
+        mov     v3.16b,v23.16b
+
+        mov     x20,x1                           // reset filter pointer
+        mov     x21,#0                           // kh = 0
+
+        // Base pointer for this output index across the input.
+        madd    x19,x17,x3,x0                    // Input + out_idx*StrideWidth
+
+.Ldw_HeightLoop:
+        // Compute [row_start,row_end) for the current kernel row.
+        madd    x22,x21,x10,x8                   // row_start
+        add     x27,x22,x9                       // row_end
+        sub     x29,x27,#64                      // row_end - 64 (fast path)
+        sub     x28,x27,#16                      // row_end - 16
+
+        // Base address for the first kw element on this row.
+        madd    x26,x21,x10,x19                  // input for this row
+
+        mov     x25,x7                           // kw remaining
+
+.Ldw_WidthLoop:
+        // Fast path: the 16-lane load fits completely within the row.
+        cmp     x26,x22
+        b.lo    .Ldw_SlowPath
+        cmp     x26,x29
+        b.hi    .Ldw_SlowPath
+        // Load 16 input values for the current position.
+        ld1     {v16.4s,v17.4s,v18.4s,v19.4s},[x26]
+        b       .Ldw_DoFma
+
+.Ldw_SlowPath:
+        // Zero registers and conditionally load each 4-wide vector when it is
+        // entirely within bounds (unsigned address compares). This matches the
+        // behavior of the C++ helper LoadInputVectorWithBounds.
+        eor     v16.16b,v16.16b,v16.16b
+        eor     v17.16b,v17.16b,v17.16b
+        eor     v18.16b,v18.16b,v18.16b
+        eor     v19.16b,v19.16b,v19.16b
+
+        mov     x23,x26
+        cmp     x23,x22
+        b.lo    2f
+        cmp     x23,x28
+        b.hi    2f
+        ldr     q16,[x23]
+2:
+        add     x23,x26,#16
+        cmp     x23,x22
+        b.lo    3f
+        cmp     x23,x28
+        b.hi    3f
+        ldr     q17,[x23]
+3:
+        add     x23,x26,#32
+        cmp     x23,x22
+        b.lo    4f
+        cmp     x23,x28
+        b.hi    4f
+        ldr     q18,[x23]
+4:
+        add     x23,x26,#48
+        cmp     x23,x22
+        b.lo    5f
+        cmp     x23,x28
+        b.hi    5f
+        ldr     q19,[x23]
+5:
+
+.Ldw_DoFma:
+        // Load filter block and update accumulators.
+        ld1     {v12.4s,v13.4s,v14.4s,v15.4s},[x20],#64
+        fmla    v0.4s,v16.4s,v12.4s
+        fmla    v1.4s,v17.4s,v13.4s
+        fmla    v2.4s,v18.4s,v14.4s
+        fmla    v3.4s,v19.4s,v15.4s
+
+        add     x26,x26,x4                       // advance to next kw
+        subs    x25,x25,#1
+        b.ne    .Ldw_WidthLoop
+
+        add     x21,x21,#1
+        cmp     x21,x6
+        b.lt    .Ldw_HeightLoop
+
+        // Compute destination pointer for this output element.
+        add     x23,x2,x17,lsl #6                // 16 floats per output
+
+        // Accumulate existing output when requested.
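+        // As a C-style sketch, the tail below computes, per 4-wide vector
+        // (bias was already folded into the accumulators at initialization):
+        //
+        //   float32x4_t r = acc;
+        //   if (flags & 1) r = vaddq_f32(r, vld1q_f32(out));     // accumulate
+        //   if (flags & 4) r = vmaxq_f32(r, vdupq_n_f32(0.0f));  // ReLU
+        //   vst1q_f32(out, r);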
+ tbz w15,#0,6f + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x23] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s +6: + // Optional ReLU activation. + tbz w15,#2,8f + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s +8: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x23] + + add x17,x17,#1 + cmp x17,x16 + blt .Ldw_OutputLoop + + ldp d14,d15,[sp],#16 + ldp d12,d13,[sp],#16 + ldp x19,x20,[sp],#16 + ldp x21,x22,[sp],#16 + ldp x23,x24,[sp],#16 + ldp x25,x26,[sp],#16 + ldp x27,x28,[sp],#16 + ldp x29,x30,[sp],#16 + + ret + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S new file mode 100644 index 0000000000000..643f537834663 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S @@ -0,0 +1,836 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + SconvKernelNeon.S + +Abstract: + Hand written AArch64 vectorised kernel used by the convolution path + (NCHW activations and NCHWc weights). The kernel computes one output + row and processes up to four 16-wide filter blocks (FilterCount <= 4). + The hot interior loop avoids all bounds checks and the two-output path + amortizes filter loads when possible. +--*/ + +#include "asmmacro.h" + + // Stack layout for parameters passed on the stack. The first eight + // integer parameters follow the AArch64 calling convention and are in + // x0-x7. Remaining parameters are spilled by the C++ caller. + + .equ .LFrame_SavedRegs, (5*16) + .equ .LFrame_x19_x20, 0 + .equ .LFrame_x21_x22, 16 + .equ .LFrame_x23_x24, 32 + .equ .LFrame_x25_x26, 48 + .equ .LFrame_x27_x28, 64 + + .equ .KO_OutputStride, (0 + .LFrame_SavedRegs) + .equ .KO_KernelHeight, (8 + .LFrame_SavedRegs) + .equ .KO_KernelWidth, (16 + .LFrame_SavedRegs) + .equ .KO_InputBase, (24 + .LFrame_SavedRegs) + .equ .KO_InputWidth, (32 + .LFrame_SavedRegs) + .equ .KO_DilatedInputWidth, (40 + .LFrame_SavedRegs) + .equ .KO_OutputCountLeftPad, (48 + .LFrame_SavedRegs) + .equ .KO_OutputCount, (56 + .LFrame_SavedRegs) + .equ .KO_OutputCountRightPad, (64 + .LFrame_SavedRegs) + .equ .KO_Bias, (72 + .LFrame_SavedRegs) + .equ .KO_Flags, (80 + .LFrame_SavedRegs) + + .text + + // --------------------------------------------------------------------- + // Helper macros. + // --------------------------------------------------------------------- + + .macro CLEAR_ACCUM N + eor v0.16b,v0.16b,v0.16b + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b +.if \N >= 2 + eor v4.16b,v4.16b,v4.16b + eor v5.16b,v5.16b,v5.16b + eor v6.16b,v6.16b,v6.16b + eor v7.16b,v7.16b,v7.16b +.endif +.if \N >= 3 + eor v8.16b,v8.16b,v8.16b + eor v9.16b,v9.16b,v9.16b + eor v10.16b,v10.16b,v10.16b + eor v11.16b,v11.16b,v11.16b +.endif +.if \N >= 4 + eor v12.16b,v12.16b,v12.16b + eor v13.16b,v13.16b,v13.16b + eor v14.16b,v14.16b,v14.16b + eor v15.16b,v15.16b,v15.16b +.endif + .endm + + // Post processing for a single output element working on N filter + // blocks. x27 points at output base for set0, x8 holds OutputStride + // and x17 contains the Bias pointer for set0. The zero vector is in + // v31 and KernelFlags in w6. 
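+        // As a C-style addressing sketch for filter set f (0 <= f < N),
+        // matching the loads and stores in the macro body:
+        //
+        //   float*       out_f  = OutputBase + f * (OutputStride / sizeof(float));
+        //   const float* bias_f = Bias + f * 16;    // 16 floats per filter set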
+ .macro POSTPROCESS_STORE N + // accumulate from existing output if requested + tbz w6,#0,1f + ldr q16,[x27] + ldr q17,[x27,#16] + ldr q18,[x27,#32] + ldr q19,[x27,#48] + fadd v0.4s,v0.4s,v16.4s + fadd v1.4s,v1.4s,v17.4s + fadd v2.4s,v2.4s,v18.4s + fadd v3.4s,v3.4s,v19.4s +.if \N >= 2 + add x24,x27,x8 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v4.4s,v4.4s,v16.4s + fadd v5.4s,v5.4s,v17.4s + fadd v6.4s,v6.4s,v18.4s + fadd v7.4s,v7.4s,v19.4s +.endif +.if \N >= 3 + add x24,x27,x8,lsl #1 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v8.4s,v8.4s,v16.4s + fadd v9.4s,v9.4s,v17.4s + fadd v10.4s,v10.4s,v18.4s + fadd v11.4s,v11.4s,v19.4s +.endif +.if \N >= 4 + add x24,x27,x8 + add x24,x24,x8,lsl #1 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v12.4s,v12.4s,v16.4s + fadd v13.4s,v13.4s,v17.4s + fadd v14.4s,v14.4s,v18.4s + fadd v15.4s,v15.4s,v19.4s +.endif +1: + // add bias + tbz w6,#1,2f + ldr q16,[x17] + ldr q17,[x17,#16] + ldr q18,[x17,#32] + ldr q19,[x17,#48] + fadd v0.4s,v0.4s,v16.4s + fadd v1.4s,v1.4s,v17.4s + fadd v2.4s,v2.4s,v18.4s + fadd v3.4s,v3.4s,v19.4s +.if \N >= 2 + add x24,x17,#64 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v4.4s,v4.4s,v16.4s + fadd v5.4s,v5.4s,v17.4s + fadd v6.4s,v6.4s,v18.4s + fadd v7.4s,v7.4s,v19.4s +.endif +.if \N >= 3 + add x24,x17,#128 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v8.4s,v8.4s,v16.4s + fadd v9.4s,v9.4s,v17.4s + fadd v10.4s,v10.4s,v18.4s + fadd v11.4s,v11.4s,v19.4s +.endif +.if \N >= 4 + add x24,x17,#192 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v12.4s,v12.4s,v16.4s + fadd v13.4s,v13.4s,v17.4s + fadd v14.4s,v14.4s,v18.4s + fadd v15.4s,v15.4s,v19.4s +.endif +2: + // optional ReLU + tbz w6,#2,3f + fmax v0.4s,v0.4s,v31.4s + fmax v1.4s,v1.4s,v31.4s + fmax v2.4s,v2.4s,v31.4s + fmax v3.4s,v3.4s,v31.4s +.if \N >= 2 + fmax v4.4s,v4.4s,v31.4s + fmax v5.4s,v5.4s,v31.4s + fmax v6.4s,v6.4s,v31.4s + fmax v7.4s,v7.4s,v31.4s +.endif +.if \N >= 3 + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s +.endif +.if \N >= 4 + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s +.endif +3: + // store results + str q0,[x27] + str q1,[x27,#16] + str q2,[x27,#32] + str q3,[x27,#48] +.if \N >= 2 + add x24,x27,x8 + str q4,[x24] + str q5,[x24,#16] + str q6,[x24,#32] + str q7,[x24,#48] +.endif +.if \N >= 3 + add x24,x27,x8,lsl #1 + str q8,[x24] + str q9,[x24,#16] + str q10,[x24,#32] + str q11,[x24,#48] +.endif +.if \N >= 4 + add x24,x27,x8 + add x24,x24,x8,lsl #1 + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] +.endif + .endm + + // Two-output helpers -------------------------------------------------- + .macro CLEAR_ACCUM2 N + eor v0.16b,v0.16b,v0.16b + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + eor v4.16b,v4.16b,v4.16b + eor v5.16b,v5.16b,v5.16b + eor v6.16b,v6.16b,v6.16b + eor v7.16b,v7.16b,v7.16b +.if \N >= 2 + eor v8.16b,v8.16b,v8.16b + eor v9.16b,v9.16b,v9.16b + eor v10.16b,v10.16b,v10.16b + eor v11.16b,v11.16b,v11.16b + eor v12.16b,v12.16b,v12.16b + eor v13.16b,v13.16b,v13.16b + eor v14.16b,v14.16b,v14.16b + eor v15.16b,v15.16b,v15.16b +.endif +.if \N >= 3 + eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor 
v19.16b,v19.16b,v19.16b + eor v20.16b,v20.16b,v20.16b + eor v21.16b,v21.16b,v21.16b + eor v22.16b,v22.16b,v22.16b + eor v23.16b,v23.16b,v23.16b +.endif + .endm + + // Post process helpers for the two-output interior kernel. These are + // written out-of-line as they are sizable and used in several places. + .macro POSTPROCESS_STORE2_1 + add x24,x27,#64 + // accumulate + tbz w6,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v0.4s,v0.4s,v26.4s + fadd v1.4s,v1.4s,v27.4s + fadd v2.4s,v2.4s,v28.4s + fadd v3.4s,v3.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v4.4s,v4.4s,v26.4s + fadd v5.4s,v5.4s,v27.4s + fadd v6.4s,v6.4s,v28.4s + fadd v7.4s,v7.4s,v29.4s +1: + // bias + tbz w6,#1,2f + ldr q26,[x17] + ldr q27,[x17,#16] + ldr q28,[x17,#32] + ldr q29,[x17,#48] + fadd v0.4s,v0.4s,v26.4s + fadd v1.4s,v1.4s,v27.4s + fadd v2.4s,v2.4s,v28.4s + fadd v3.4s,v3.4s,v29.4s + fadd v4.4s,v4.4s,v26.4s + fadd v5.4s,v5.4s,v27.4s + fadd v6.4s,v6.4s,v28.4s + fadd v7.4s,v7.4s,v29.4s +2: + tbz w6,#2,3f + fmax v0.4s,v0.4s,v31.4s + fmax v1.4s,v1.4s,v31.4s + fmax v2.4s,v2.4s,v31.4s + fmax v3.4s,v3.4s,v31.4s + fmax v4.4s,v4.4s,v31.4s + fmax v5.4s,v5.4s,v31.4s + fmax v6.4s,v6.4s,v31.4s + fmax v7.4s,v7.4s,v31.4s +3: + str q0,[x27] + str q1,[x27,#16] + str q2,[x27,#32] + str q3,[x27,#48] + str q4,[x24] + str q5,[x24,#16] + str q6,[x24,#32] + str q7,[x24,#48] + .endm + + .macro POSTPROCESS_STORE2_2 + POSTPROCESS_STORE2_1 + add x24,x27,x8 + add x24,x24,#64 + add x27,x27,x8 + tbz w6,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +1: + tbz w6,#1,2f + add x23,x17,#64 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +2: + tbz w6,#2,3f + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s +3: + str q8,[x27] + str q9,[x27,#16] + str q10,[x27,#32] + str q11,[x27,#48] + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] + sub x27,x27,x8 + .endm + + .macro POSTPROCESS_STORE2_3 + POSTPROCESS_STORE2_1 + add x24,x27,x8 + add x24,x24,#64 + add x27,x27,x8 + tbz w6,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +1: + tbz w6,#1,2f + add x23,x17,#64 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +2: + tbz w6,#2,3f + fmax 
v8.4s,v8.4s,v31.4s
+        fmax    v9.4s,v9.4s,v31.4s
+        fmax    v10.4s,v10.4s,v31.4s
+        fmax    v11.4s,v11.4s,v31.4s
+        fmax    v12.4s,v12.4s,v31.4s
+        fmax    v13.4s,v13.4s,v31.4s
+        fmax    v14.4s,v14.4s,v31.4s
+        fmax    v15.4s,v15.4s,v31.4s
+3:
+        str     q8,[x27]
+        str     q9,[x27,#16]
+        str     q10,[x27,#32]
+        str     q11,[x27,#48]
+        str     q12,[x24]
+        str     q13,[x24,#16]
+        str     q14,[x24,#32]
+        str     q15,[x24,#48]
+        // x27 currently addresses the set1 output row; one more stride
+        // reaches set2 (the two trailing subs below undo both steps).
+        add     x27,x27,x8                       // set2 base
+        add     x24,x27,#64
+        tbz     w6,#0,4f
+        ldr     q26,[x27]
+        ldr     q27,[x27,#16]
+        ldr     q28,[x27,#32]
+        ldr     q29,[x27,#48]
+        fadd    v16.4s,v16.4s,v26.4s
+        fadd    v17.4s,v17.4s,v27.4s
+        fadd    v18.4s,v18.4s,v28.4s
+        fadd    v19.4s,v19.4s,v29.4s
+        ldr     q26,[x24]
+        ldr     q27,[x24,#16]
+        ldr     q28,[x24,#32]
+        ldr     q29,[x24,#48]
+        fadd    v20.4s,v20.4s,v26.4s
+        fadd    v21.4s,v21.4s,v27.4s
+        fadd    v22.4s,v22.4s,v28.4s
+        fadd    v23.4s,v23.4s,v29.4s
+4:
+        tbz     w6,#1,5f
+        add     x23,x17,#128
+        ldr     q26,[x23]
+        ldr     q27,[x23,#16]
+        ldr     q28,[x23,#32]
+        ldr     q29,[x23,#48]
+        fadd    v16.4s,v16.4s,v26.4s
+        fadd    v17.4s,v17.4s,v27.4s
+        fadd    v18.4s,v18.4s,v28.4s
+        fadd    v19.4s,v19.4s,v29.4s
+        fadd    v20.4s,v20.4s,v26.4s
+        fadd    v21.4s,v21.4s,v27.4s
+        fadd    v22.4s,v22.4s,v28.4s
+        fadd    v23.4s,v23.4s,v29.4s
+5:
+        tbz     w6,#2,6f
+        fmax    v16.4s,v16.4s,v31.4s
+        fmax    v17.4s,v17.4s,v31.4s
+        fmax    v18.4s,v18.4s,v31.4s
+        fmax    v19.4s,v19.4s,v31.4s
+        fmax    v20.4s,v20.4s,v31.4s
+        fmax    v21.4s,v21.4s,v31.4s
+        fmax    v22.4s,v22.4s,v31.4s
+        fmax    v23.4s,v23.4s,v31.4s
+6:
+        str     q16,[x27]
+        str     q17,[x27,#16]
+        str     q18,[x27,#32]
+        str     q19,[x27,#48]
+        str     q20,[x24]
+        str     q21,[x24,#16]
+        str     q22,[x24,#32]
+        str     q23,[x24,#48]
+        sub     x27,x27,x8
+        sub     x27,x27,x8
+        .endm
+
+        // 1-output compute with padding checks ---------------------------------
+        .macro CONV_PAD N
+        CLEAR_ACCUM \N
+        mov     x28,x1
+.if \N >= 2
+        add     x19,x1,x7
+.endif
+.if \N >= 3
+        add     x20,x19,x7
+.endif
+.if \N >= 4
+        add     x26,x20,x7
+.endif
+        mov     x21,x0
+        ldr     x9,[sp,#.KO_KernelHeight]
+        cbz     x9,5f
+        ldr     x10,[sp,#.KO_KernelWidth]
+        ldr     x22,[sp,#.KO_InputBase]
+        ldr     x12,[sp,#.KO_InputWidth]
+        add     x23,x22,x12
+1:      mov     x25,x21
+        mov     x11,x10
+2:      cmp     x25,x22
+        b.lo    3f
+        cmp     x25,x23
+        b.hs    3f
+        ld1r    {v24.4s},[x25]
+        b       4f
+3:      eor     v24.16b,v24.16b,v24.16b
+4:      // Prefetch a short distance ahead to hide the filter load latency
+        // without polluting the cache for small kernels.
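+        // (A 192-byte lookahead spans three 64-byte filter blocks, i.e. about
+        // three iterations of this inner loop; the exact distance is a tuning
+        // choice rather than a functional requirement.)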
+ prfm pldl1keep,[x28,#192] + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x28],#64 + fmla v0.4s,v26.4s,v24.4s + fmla v1.4s,v27.4s,v24.4s + fmla v2.4s,v28.4s,v24.4s + fmla v3.4s,v29.4s,v24.4s +.if \N >= 2 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x19],#64 + fmla v4.4s,v26.4s,v24.4s + fmla v5.4s,v27.4s,v24.4s + fmla v6.4s,v28.4s,v24.4s + fmla v7.4s,v29.4s,v24.4s +.endif +.if \N >= 3 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + fmla v8.4s,v26.4s,v24.4s + fmla v9.4s,v27.4s,v24.4s + fmla v10.4s,v28.4s,v24.4s + fmla v11.4s,v29.4s,v24.4s +.endif +.if \N >= 4 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x26],#64 + fmla v12.4s,v26.4s,v24.4s + fmla v13.4s,v27.4s,v24.4s + fmla v14.4s,v28.4s,v24.4s + fmla v15.4s,v29.4s,v24.4s +.endif + add x25,x25,x4 + subs x11,x11,#1 + b.ne 2b + ldr x13,[sp,#.KO_DilatedInputWidth] + add x21,x21,x13 + add x22,x22,x13 + add x23,x23,x13 + subs x9,x9,#1 + b.ne 1b +5: POSTPROCESS_STORE \N + .endm + + // 1-output compute without padding checks ------------------------------ + .macro CONV_NOPAD N + CLEAR_ACCUM \N + mov x28,x1 +.if \N >= 2 + add x19,x1,x7 +.endif +.if \N >= 3 + add x20,x19,x7 +.endif +.if \N >= 4 + add x26,x20,x7 +.endif + mov x21,x0 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,3f + ldr x10,[sp,#.KO_KernelWidth] +1: mov x25,x21 + mov x11,x10 +2: ld1r {v24.4s},[x25] + prfm pldl1keep,[x28,#192] + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x28],#64 + fmla v0.4s,v26.4s,v24.4s + fmla v1.4s,v27.4s,v24.4s + fmla v2.4s,v28.4s,v24.4s + fmla v3.4s,v29.4s,v24.4s +.if \N >= 2 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x19],#64 + fmla v4.4s,v26.4s,v24.4s + fmla v5.4s,v27.4s,v24.4s + fmla v6.4s,v28.4s,v24.4s + fmla v7.4s,v29.4s,v24.4s +.endif +.if \N >= 3 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + fmla v8.4s,v26.4s,v24.4s + fmla v9.4s,v27.4s,v24.4s + fmla v10.4s,v28.4s,v24.4s + fmla v11.4s,v29.4s,v24.4s +.endif +.if \N >= 4 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x26],#64 + fmla v12.4s,v26.4s,v24.4s + fmla v13.4s,v27.4s,v24.4s + fmla v14.4s,v28.4s,v24.4s + fmla v15.4s,v29.4s,v24.4s +.endif + add x25,x25,x4 + subs x11,x11,#1 + b.ne 2b + ldr x13,[sp,#.KO_DilatedInputWidth] + add x21,x21,x13 + subs x9,x9,#1 + b.ne 1b +3: POSTPROCESS_STORE \N + .endm + + // Two-output compute core when no padding is required. x25/x26 are + // the two input pointers, x27 the output base. x28, x19 and x20 are + // filter pointers for each set. v24/v25 hold broadcast inputs. 
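+        // Per tap, each filter block load is shared by both outputs. As a
+        // C-style sketch with NEON intrinsics (one 4-wide lane of set f
+        // shown; the names are illustrative only):
+        //
+        //   acc0[f] = vfmaq_f32(acc0[f], w[f], vdupq_n_f32(in0));   // v24
+        //   acc1[f] = vfmaq_f32(acc1[f], w[f], vdupq_n_f32(in1));   // v25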
+ .macro CONV2_NOPAD N + CLEAR_ACCUM2 \N + mov x28,x1 +.if \N >= 2 + add x19,x1,x7 +.endif +.if \N >= 3 + add x20,x19,x7 +.endif + ldr x11,[sp,#.KO_KernelWidth] + ldr x9,[sp,#.KO_KernelHeight] +1: mov x10,x11 +2: prfm pldl1keep,[x28,#192] + ld1r {v24.4s},[x25] + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x28],#64 + fmla v0.4s,v26.4s,v24.4s + fmla v1.4s,v27.4s,v24.4s + fmla v2.4s,v28.4s,v24.4s + fmla v3.4s,v29.4s,v24.4s + ld1r {v25.4s},[x26] + fmla v4.4s,v26.4s,v25.4s + fmla v5.4s,v27.4s,v25.4s + fmla v6.4s,v28.4s,v25.4s + fmla v7.4s,v29.4s,v25.4s +.if \N >= 2 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x19],#64 + fmla v8.4s,v26.4s,v24.4s + fmla v9.4s,v27.4s,v24.4s + fmla v10.4s,v28.4s,v24.4s + fmla v11.4s,v29.4s,v24.4s + fmla v12.4s,v26.4s,v25.4s + fmla v13.4s,v27.4s,v25.4s + fmla v14.4s,v28.4s,v25.4s + fmla v15.4s,v29.4s,v25.4s +.endif +.if \N >= 3 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + fmla v16.4s,v26.4s,v24.4s + fmla v17.4s,v27.4s,v24.4s + fmla v18.4s,v28.4s,v24.4s + fmla v19.4s,v29.4s,v24.4s + fmla v20.4s,v26.4s,v25.4s + fmla v21.4s,v27.4s,v25.4s + fmla v22.4s,v28.4s,v25.4s + fmla v23.4s,v29.4s,v25.4s +.endif + add x25,x25,x4 + add x26,x26,x4 + subs x10,x10,#1 + b.ne 2b + ldr x12,[sp,#.KO_DilatedInputWidth] + add x25,x25,x12 + add x26,x26,x12 + subs x9,x9,#1 + b.ne 1b + .endm + +// ----------------------------------------------------------------------------- +// void MlasConvNchwFloatKernelNeonAsm(...) +// Implements the convolution micro-kernel for the NCHW input format. +// ----------------------------------------------------------------------------- + + FUNCTION_ENTRY MlasConvNchwFloatKernelNeonAsm + + // prologue ------------------------------------------------------------- + stp x19,x20,[sp,#-.LFrame_SavedRegs]! + stp x21,x22,[sp,#.LFrame_x21_x22] + stp x23,x24,[sp,#.LFrame_x23_x24] + stp x25,x26,[sp,#.LFrame_x25_x26] + stp x27,x28,[sp,#.LFrame_x27_x28] + + // frequently used parameters + ldr x8,[sp,#.KO_OutputStride] + ldr x14,[sp,#.KO_OutputCountLeftPad] + ldr x15,[sp,#.KO_OutputCount] + ldr x16,[sp,#.KO_OutputCountRightPad] + ldr x17,[sp,#.KO_Bias] + ldr w6,[sp,#.KO_Flags] // kernel flags + + eor v31.16b,v31.16b,v31.16b + + // left padded region -------------------------------------------------- + cbz x14,.LInterior + mov x20,#0 +.LLeftLoop: + madd x21,x20,x3,x0 // input pointer for this output + lsl x22,x20,#6 // (index*64) + add x27,x2,x22 + mov x0,x21 + cmp x5,#1 + b.eq .LLeftFC1 + cmp x5,#2 + b.eq .LLeftFC2 + cmp x5,#3 + b.eq .LLeftFC3 +.LLeftFC4: + CONV_PAD 4 + b .LLeftDoneOne +.LLeftFC3: + CONV_PAD 3 + b .LLeftDoneOne +.LLeftFC2: + CONV_PAD 2 + b .LLeftDoneOne +.LLeftFC1: + CONV_PAD 1 +.LLeftDoneOne: + add x20,x20,#1 + cmp x20,x14 + blo .LLeftLoop + + // interior region ----------------------------------------------------- +.LInterior: + cbz x15,.LRight + + mul x24,x14,x3 // input byte offset after left pad + add x21,x0,x24 // base input for interior + lsl x23,x14,#6 // output offset (bytes) + + mov x20,#0 +.LIntLoop: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LIntProcessOne + // two-output path + madd x25,x20,x3,x21 // input for first output + add x26,x25,x3 // input for second output + add x22,x14,x20 // output index + lsl x22,x22,#6 // output offset (bytes) + add x27,x2,x22 + cmp x5,#1 + b.eq .LInt2FC1 + cmp x5,#2 + b.eq .LInt2FC2 + cmp x5,#3 + b.eq .LInt2FC3 + // FilterCount==4 falls back to single-output implementation due to + // register pressure. 
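+        // (Two outputs x four filter sets x four 128-bit accumulators would
+        // consume all 32 NEON registers, leaving none for filter and input
+        // operands.)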
+ b .LIntProcessOne +.LInt2FC3: + CONV2_NOPAD 3 + POSTPROCESS_STORE2_3 + b .LIntNext2 +.LInt2FC2: + CONV2_NOPAD 2 + POSTPROCESS_STORE2_2 + b .LIntNext2 +.LInt2FC1: + CONV2_NOPAD 1 + POSTPROCESS_STORE2_1 +.LIntNext2: + add x20,x20,#2 + cmp x20,x15 + blo .LIntLoop + b .LRight + +.LIntProcessOne: + madd x25,x20,x3,x21 + add x22,x14,x20 // output index + lsl x22,x22,#6 // output offset (bytes) + add x27,x2,x22 + mov x0,x25 + cmp x5,#1 + b.eq .LIntFC1 + cmp x5,#2 + b.eq .LIntFC2 + cmp x5,#3 + b.eq .LIntFC3 + CONV_NOPAD 4 + b .LIntNext1 +.LIntFC3: + CONV_NOPAD 3 + b .LIntNext1 +.LIntFC2: + CONV_NOPAD 2 + b .LIntNext1 +.LIntFC1: + CONV_NOPAD 1 +.LIntNext1: + add x20,x20,#1 + cmp x20,x15 + blo .LIntLoop + + // right padded region ------------------------------------------------- +.LRight: + cbz x16,.LDone + mov x20,#0 +.LRightLoop: + add x24,x20,x14 + add x24,x24,x15 + madd x21,x24,x3,x0 + lsl x22,x24,#6 + add x27,x2,x22 + mov x0,x21 + cmp x5,#1 + b.eq .LRFC1 + cmp x5,#2 + b.eq .LRFC2 + cmp x5,#3 + b.eq .LRFC3 + CONV_PAD 4 + b .LRightNext +.LRFC3: + CONV_PAD 3 + b .LRightNext +.LRFC2: + CONV_PAD 2 + b .LRightNext +.LRFC1: + CONV_PAD 1 +.LRightNext: + add x20,x20,#1 + cmp x20,x16 + blo .LRightLoop + +.LDone: + ldp x27,x28,[sp,#.LFrame_x27_x28] + ldp x25,x26,[sp,#.LFrame_x25_x26] + ldp x23,x24,[sp,#.LFrame_x23_x24] + ldp x21,x22,[sp,#.LFrame_x21_x22] + ldp x19,x20,[sp],#.LFrame_SavedRegs + ret + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S new file mode 100644 index 0000000000000..8caae4fc080ac --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S @@ -0,0 +1,441 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + + SconvPointwiseKernelNeon.S + +Abstract: + + A hand written AArch64 vectorised micro-kernel for pointwise (1x1) convolution + operating on tensors formatted in the NCHWc layout. The kernel computes + up to four output positions in parallel which allows the filter weights to + be re-used across several outputs, greatly reducing memory bandwidth. + +--*/ + +#include "asmmacro.h" + + .text + +// Stack layout for arguments passed on the stack. The first eight arguments +// are in x0-x7, the remaining four are placed on the stack by the caller. + + .equ .LPW_OutputStride, 0 + .equ .LPW_OutputCount, 8 + .equ .LPW_Bias, 16 + .equ .LPW_Flags, 24 + +// Kernel flag bits. Keep these in sync with sconv_nchwc_kernel_neon.h. + .equ .LPWFlag_Accumulate, 1 + .equ .LPWFlag_Bias, 2 + .equ .LPWFlag_Relu, 4 + +// Size in bytes of one NCHWc block (16 FP32 values). + .equ .LPW_BlockBytes, 64 + +//------------------------------------------------------------------------- +// Helper macros +//------------------------------------------------------------------------- + +// Compute four outputs for a single input channel block. The accumulators +// for the four outputs are held in v16-v31. 
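+// In each step v0-v3 carry one 16-wide filter block and v4-v7 the four
+// broadcast input scalars; v8-v15 are deliberately unused so the callee-saved
+// floating point registers never have to be spilled in the prologue.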
+ .macro CPK4_FmlaStep + ldp q0,q1,[x24],#32 + ldp q2,q3,[x24],#32 + ld1r {v4.4s},[x25],#4 + ld1r {v5.4s},[x26],#4 + ld1r {v6.4s},[x27],#4 + ld1r {v7.4s},[x28],#4 + fmla v16.4s,v0.4s,v4.4s + fmla v17.4s,v1.4s,v4.4s + fmla v18.4s,v2.4s,v4.4s + fmla v19.4s,v3.4s,v4.4s + fmla v20.4s,v0.4s,v5.4s + fmla v21.4s,v1.4s,v5.4s + fmla v22.4s,v2.4s,v5.4s + fmla v23.4s,v3.4s,v5.4s + fmla v24.4s,v0.4s,v6.4s + fmla v25.4s,v1.4s,v6.4s + fmla v26.4s,v2.4s,v6.4s + fmla v27.4s,v3.4s,v6.4s + fmla v28.4s,v0.4s,v7.4s + fmla v29.4s,v1.4s,v7.4s + fmla v30.4s,v2.4s,v7.4s + fmla v31.4s,v3.4s,v7.4s + .endm + +// Accumulate helper for the single output path. The input values for the +// output are loaded to v0-v3 and each lane is multiplied with a block of +// filter coefficients. + .macro CPK1_FmlaWithLane lane, AReg + ldp q4,q5,[x24],#32 + ldp q6,q7,[x24],#32 + fmla v16.4s,v4.4s,\AReg\().s[\lane] + fmla v17.4s,v5.4s,\AReg\().s[\lane] + fmla v18.4s,v6.4s,\AReg\().s[\lane] + fmla v19.4s,v7.4s,\AReg\().s[\lane] + .endm + +// Compute a single output position. Results are returned in v16-v19. + .macro CPK_ComputeOneOutput + mov x22,#0 + eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor v19.16b,v19.16b,v19.16b +.Lpw_ic_loop1: + madd x23,x22,x6,x15 + ldp q0,q1,[x23] + ldp q2,q3,[x23,#32] + add x24,x17,x22,lsl #10 + CPK1_FmlaWithLane 0, v0 + CPK1_FmlaWithLane 1, v0 + CPK1_FmlaWithLane 2, v0 + CPK1_FmlaWithLane 3, v0 + CPK1_FmlaWithLane 0, v1 + CPK1_FmlaWithLane 1, v1 + CPK1_FmlaWithLane 2, v1 + CPK1_FmlaWithLane 3, v1 + CPK1_FmlaWithLane 0, v2 + CPK1_FmlaWithLane 1, v2 + CPK1_FmlaWithLane 2, v2 + CPK1_FmlaWithLane 3, v2 + CPK1_FmlaWithLane 0, v3 + CPK1_FmlaWithLane 1, v3 + CPK1_FmlaWithLane 2, v3 + CPK1_FmlaWithLane 3, v3 + add x22,x22,#1 + cmp x22,x4 + blt .Lpw_ic_loop1 + .endm + +//------------------------------------------------------------------------- +// Entry point +//------------------------------------------------------------------------- + + FUNCTION_ENTRY MlasConvPointwiseFloatKernelNeonAsm + + // Load the arguments passed on the stack. + ldr x8,[sp,#.LPW_OutputStride] + ldr x9,[sp,#.LPW_OutputCount] + ldr x10,[sp,#.LPW_Bias] + ldr w11,[sp,#.LPW_Flags] + + // Preserve callee saved registers that are used by this routine. + stp x22,x23,[sp,#-16]! + stp x24,x25,[sp,#-16]! + stp x26,x27,[sp,#-16]! + stp x28,x19,[sp,#-16]! + + mov x14,#0 // current filter set + cbz x5,.Lpw_exit // nothing to do + +.Lpw_filter_loop: + // Compute the base pointers for this filter block. + madd x16,x14,x8,x2 // output pointer for this filter + madd x17,x14,x7,x1 // filter pointer for this filter + add x18,x10,x14,lsl #6 // bias pointer (if used) + + mov x15,x0 // input base for this iteration + mov x20,x16 // running output pointer + mov x12,x9 // output count this iteration + + lsr x13,x12,#2 // number of groups of four outputs + cbz x13,.Lpw_process_remainder + +// ------------------------------------------------------------------ +// Main loop processing 4 outputs at a time. +// ------------------------------------------------------------------ +.Lpw_groups: + // Clear accumulators for 4 outputs (16 vectors total). 
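+        // Within each group the filter pointer advances by x22 << 10 per input
+        // channel block: 16 input channels x 16 output channels x 4 bytes =
+        // 1024 bytes.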
+ eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor v19.16b,v19.16b,v19.16b + eor v20.16b,v20.16b,v20.16b + eor v21.16b,v21.16b,v21.16b + eor v22.16b,v22.16b,v22.16b + eor v23.16b,v23.16b,v23.16b + eor v24.16b,v24.16b,v24.16b + eor v25.16b,v25.16b,v25.16b + eor v26.16b,v26.16b,v26.16b + eor v27.16b,v27.16b,v27.16b + eor v28.16b,v28.16b,v28.16b + eor v29.16b,v29.16b,v29.16b + eor v30.16b,v30.16b,v30.16b + eor v31.16b,v31.16b,v31.16b + + mov x22,#0 // current input channel block +.Lpw_ic_loop4: + madd x23,x22,x6,x15 // input for this block + mov x25,x23 // four rows starting positions + add x26,x23,x3 + add x27,x26,x3 + add x28,x27,x3 + add x24,x17,x22,lsl #10 // filter for this block + + // The block size is 16 so unroll 16 steps. + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + + add x22,x22,#1 + cmp x22,x4 + blt .Lpw_ic_loop4 + + // ----------------------------------------------------------------- + // Store the four outputs computed above. There are several cases to + // handle based on accumulation, bias and ReLU flags. + // ----------------------------------------------------------------- + + // Test if the kernel should accumulate into the existing output. + tbz w11,#0,.Lpw_store_nacc + + // Accumulation path. Load bias once as it is re-used for all four + // stores when present. + tbz w11,#1,1f + ldp q4,q5,[x18] + ldp q6,q7,[x18,#32] +1: + // ---- output 0 ---- + ldp q0,q1,[x20] + ldp q2,q3,[x20,#32] + tbz w11,#1,2f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +2: + fadd v16.4s,v16.4s,v0.4s + fadd v17.4s,v17.4s,v1.4s + fadd v18.4s,v18.4s,v2.4s + fadd v19.4s,v19.4s,v3.4s + tbz w11,#2,3f + eor v0.16b,v0.16b,v0.16b + fmax v16.4s,v16.4s,v0.4s + fmax v17.4s,v17.4s,v0.4s + fmax v18.4s,v18.4s,v0.4s + fmax v19.4s,v19.4s,v0.4s +3: + stp q16,q17,[x20] + stp q18,q19,[x20,#32] + + // ---- output 1 ---- + add x22,x20,#.LPW_BlockBytes + ldp q0,q1,[x22] + ldp q2,q3,[x22,#32] + tbz w11,#1,4f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +4: + fadd v20.4s,v20.4s,v0.4s + fadd v21.4s,v21.4s,v1.4s + fadd v22.4s,v22.4s,v2.4s + fadd v23.4s,v23.4s,v3.4s + tbz w11,#2,5f + eor v0.16b,v0.16b,v0.16b + fmax v20.4s,v20.4s,v0.4s + fmax v21.4s,v21.4s,v0.4s + fmax v22.4s,v22.4s,v0.4s + fmax v23.4s,v23.4s,v0.4s +5: + stp q20,q21,[x22] + stp q22,q23,[x22,#32] + + // ---- output 2 ---- + add x22,x22,#.LPW_BlockBytes + ldp q0,q1,[x22] + ldp q2,q3,[x22,#32] + tbz w11,#1,6f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +6: + fadd v24.4s,v24.4s,v0.4s + fadd v25.4s,v25.4s,v1.4s + fadd v26.4s,v26.4s,v2.4s + fadd v27.4s,v27.4s,v3.4s + tbz w11,#2,7f + eor v0.16b,v0.16b,v0.16b + fmax v24.4s,v24.4s,v0.4s + fmax v25.4s,v25.4s,v0.4s + fmax v26.4s,v26.4s,v0.4s + fmax v27.4s,v27.4s,v0.4s +7: + stp q24,q25,[x22] + stp q26,q27,[x22,#32] + + // ---- output 3 ---- + add x22,x22,#.LPW_BlockBytes + ldp q0,q1,[x22] + ldp q2,q3,[x22,#32] + tbz w11,#1,8f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +8: + fadd v28.4s,v28.4s,v0.4s + fadd v29.4s,v29.4s,v1.4s + fadd v30.4s,v30.4s,v2.4s + fadd v31.4s,v31.4s,v3.4s + tbz w11,#2,9f + eor v0.16b,v0.16b,v0.16b + fmax v28.4s,v28.4s,v0.4s + 
fmax v29.4s,v29.4s,v0.4s + fmax v30.4s,v30.4s,v0.4s + fmax v31.4s,v31.4s,v0.4s +9: + stp q28,q29,[x22] + stp q30,q31,[x22,#32] + b .Lpw_advance_group + +// Non-accumulating path: add bias directly to the results if requested +.Lpw_store_nacc: + tbz w11,#1,10f + ldp q4,q5,[x18] + ldp q6,q7,[x18,#32] + fadd v16.4s,v16.4s,v4.4s + fadd v17.4s,v17.4s,v5.4s + fadd v18.4s,v18.4s,v6.4s + fadd v19.4s,v19.4s,v7.4s + fadd v20.4s,v20.4s,v4.4s + fadd v21.4s,v21.4s,v5.4s + fadd v22.4s,v22.4s,v6.4s + fadd v23.4s,v23.4s,v7.4s + fadd v24.4s,v24.4s,v4.4s + fadd v25.4s,v25.4s,v5.4s + fadd v26.4s,v26.4s,v6.4s + fadd v27.4s,v27.4s,v7.4s + fadd v28.4s,v28.4s,v4.4s + fadd v29.4s,v29.4s,v5.4s + fadd v30.4s,v30.4s,v6.4s + fadd v31.4s,v31.4s,v7.4s +10: + tbz w11,#2,11f + eor v0.16b,v0.16b,v0.16b + fmax v16.4s,v16.4s,v0.4s + fmax v17.4s,v17.4s,v0.4s + fmax v18.4s,v18.4s,v0.4s + fmax v19.4s,v19.4s,v0.4s + fmax v20.4s,v20.4s,v0.4s + fmax v21.4s,v21.4s,v0.4s + fmax v22.4s,v22.4s,v0.4s + fmax v23.4s,v23.4s,v0.4s + fmax v24.4s,v24.4s,v0.4s + fmax v25.4s,v25.4s,v0.4s + fmax v26.4s,v26.4s,v0.4s + fmax v27.4s,v27.4s,v0.4s + fmax v28.4s,v28.4s,v0.4s + fmax v29.4s,v29.4s,v0.4s + fmax v30.4s,v30.4s,v0.4s + fmax v31.4s,v31.4s,v0.4s +11: + stp q16,q17,[x20] + stp q18,q19,[x20,#32] + add x22,x20,#.LPW_BlockBytes + stp q20,q21,[x22] + stp q22,q23,[x22,#32] + add x22,x22,#.LPW_BlockBytes + stp q24,q25,[x22] + stp q26,q27,[x22,#32] + add x22,x22,#.LPW_BlockBytes + stp q28,q29,[x22] + stp q30,q31,[x22,#32] + +.Lpw_advance_group: + add x15,x15,x3,lsl #2 + add x20,x20,#(.LPW_BlockBytes*4) + subs x13,x13,#1 + b.ne .Lpw_groups + +// ------------------------------------------------------------------ +// Handle the leftover (0..3) output positions. +// ------------------------------------------------------------------ +.Lpw_process_remainder: + ands x12,x12,#3 + beq .Lpw_after_filter +.Lpw_left_loop: + CPK_ComputeOneOutput + + // Accumulate? 
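+        // C-style sketch of the remainder store paths below (per 4-wide lane):
+        //
+        //   r = acc + (accumulate ? prev : 0) + (bias ? b : 0);
+        //   if (relu) r = fmax(r, 0.0f);
+        //   store(out, r);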
+        tbz     w11,#0,.Lpw_left_noacc
+        ldp     q0,q1,[x20]
+        ldp     q2,q3,[x20,#32]
+        tbz     w11,#1,12f
+        ldp     q4,q5,[x18]
+        ldp     q6,q7,[x18,#32]
+        fadd    v0.4s,v0.4s,v4.4s
+        fadd    v1.4s,v1.4s,v5.4s
+        fadd    v2.4s,v2.4s,v6.4s
+        fadd    v3.4s,v3.4s,v7.4s
+12:
+        fadd    v16.4s,v16.4s,v0.4s
+        fadd    v17.4s,v17.4s,v1.4s
+        fadd    v18.4s,v18.4s,v2.4s
+        fadd    v19.4s,v19.4s,v3.4s
+        tbz     w11,#2,13f
+        eor     v0.16b,v0.16b,v0.16b
+        fmax    v16.4s,v16.4s,v0.4s
+        fmax    v17.4s,v17.4s,v0.4s
+        fmax    v18.4s,v18.4s,v0.4s
+        fmax    v19.4s,v19.4s,v0.4s
+13:
+        stp     q16,q17,[x20]
+        stp     q18,q19,[x20,#32]
+        b       14f
+
+.Lpw_left_noacc:
+        tbz     w11,#1,15f
+        ldp     q4,q5,[x18]
+        ldp     q6,q7,[x18,#32]
+        fadd    v16.4s,v16.4s,v4.4s
+        fadd    v17.4s,v17.4s,v5.4s
+        fadd    v18.4s,v18.4s,v6.4s
+        fadd    v19.4s,v19.4s,v7.4s
+15:
+        tbz     w11,#2,16f
+        eor     v0.16b,v0.16b,v0.16b
+        fmax    v16.4s,v16.4s,v0.4s
+        fmax    v17.4s,v17.4s,v0.4s
+        fmax    v18.4s,v18.4s,v0.4s
+        fmax    v19.4s,v19.4s,v0.4s
+16:
+        stp     q16,q17,[x20]
+        stp     q18,q19,[x20,#32]
+14:
+        add     x15,x15,x3
+        add     x20,x20,#.LPW_BlockBytes
+        subs    x12,x12,#1
+        b.ne    .Lpw_left_loop
+
+.Lpw_after_filter:
+        add     x14,x14,#1
+        cmp     x14,x5
+        b.lt    .Lpw_filter_loop
+.Lpw_exit:
+        ldp     x28,x19,[sp],#16
+        ldp     x26,x27,[sp],#16
+        ldp     x24,x25,[sp],#16
+        ldp     x22,x23,[sp],#16
+        ret
+
+        .end
diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h
index e75ca3dc90e60..242cf88fa2d71 100644
--- a/onnxruntime/core/mlas/lib/mlasi.h
+++ b/onnxruntime/core/mlas/lib/mlasi.h
@@ -963,9 +963,14 @@ extern "C" {
     MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd;
 #endif
 #if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
-    MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon;
+    // AArch64 assembly micro-kernel for direct NCHW convolution
+    MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeonAsm;
     MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon;
+    // AArch64 assembly micro-kernel for depthwise NCHWc convolution
+    MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeonAsm;
     MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon;
+    // AArch64 assembly micro-kernel for pointwise NCHWc convolution
+    MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeonAsm;
     MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon;
 #if defined(__aarch64__) && defined(__linux__)
     MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon;
diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp
index b913b1c3b8c26..88fb59ac47093 100644
--- a/onnxruntime/core/mlas/lib/platform.cpp
+++ b/onnxruntime/core/mlas/lib/platform.cpp
@@ -571,7 +571,15 @@ Return Value:
         this->EltwiseDispatch = &MlasEltwiseDispatchNeon;
 
 #if defined(MLAS_USE_ARM_NEON_NCHWC)
-        this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon;
+        // Prefer the hand-written micro-kernel for the NCHW convolution path:
+        // a tighter schedule and a specialised two-output inner loop reduce
+        // pressure on the memory system compared with the intrinsics kernel.
+        this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeonAsm;
+        // The pointwise and depthwise assembly micro-kernels also compute
+        // multiple output positions at once and reduce memory traffic, but
+        // they are selected by dispatch heuristics in snchwc.cpp to avoid
+        // regressions on small convolutions.
+        // The intrinsics kernels therefore remain the defaults here.
         this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon;
         this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon;
         this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon;
diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
index 745258080810a..09ee155d36d3e 100644
--- a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
+++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp
@@ -182,52 +182,6 @@ void
     }
 }
 
-void
-    MLASCALL
-    MlasConvNchwFloatKernelNeon(
-        const float* Input,
-        const float* Filter,
-        float* Output,
-        size_t StrideWidth,
-        size_t DilationWidth,
-        size_t FilterCount,
-        size_t InputStride,
-        size_t FilterStride,
-        size_t OutputStride,
-        size_t KernelHeight,
-        size_t KernelWidth,
-        const float* InputBase,
-        size_t InputWidth,
-        size_t DilatedInputWidth,
-        size_t OutputCountLeftPad,
-        size_t OutputCount,
-        size_t OutputCountRightPad,
-        const float* Bias,
-        unsigned KernelFlags
-    )
-{
-    MlasConvFloatKernelNeonImpl(
-        Input,
-        Filter,
-        Output,
-        StrideWidth,
-        DilationWidth,
-        FilterCount,
-        InputStride,
-        FilterStride,
-        OutputStride,
-        KernelHeight,
-        KernelWidth,
-        InputBase,
-        InputWidth,
-        DilatedInputWidth,
-        OutputCountLeftPad,
-        OutputCount,
-        OutputCountRightPad,
-        Bias,
-        KernelFlags
-    );
-}
 
 //
 // Implementation of MlasConvNchwcFloatKernelNeon
diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp
index 505246841087c..0c4d95400a8d1 100644
--- a/onnxruntime/core/mlas/lib/snchwc.cpp
+++ b/onnxruntime/core/mlas/lib/snchwc.cpp
@@ -882,6 +882,9 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM
 #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC))
     MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel;
+#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)
+    MLAS_CONV_POINTWISE_FLOAT_KERNEL* const KernelFast = MlasConvPointwiseFloatKernelNeonAsm;
+#endif
 #if defined(__aarch64__) && defined(__linux__)
     if (WorkBlock->UseBf16) {
         Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel;
@@ -937,7 +940,14 @@
 
         // Invoke the convolution kernel.
// - Kernel(input, filter, output, StrideWidthBytes, InputChannelBatch / + MLAS_CONV_POINTWISE_FLOAT_KERNEL* KernelToUse = Kernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + if (!WorkBlock->UseBf16 && OutputThisIteration >= 4 && + StrideHeight == 1 && StrideWidth == 1) { + KernelToUse = KernelFast; + } +#endif + KernelToUse(input, filter, output, StrideWidthBytes, InputChannelBatch / BlockSize, FilterCount, InputStrideBytes, FilterStrideBytes, OutputStrideBytes, OutputThisIteration, Bias, KernelFlags); @@ -1024,6 +1034,9 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* const KernelFast = MlasConvDepthwiseFloatKernelNeonAsm; +#endif #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; #endif @@ -1047,7 +1060,13 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM // Invoke the convolution kernel. // - Kernel(Input + BlockSize * (ih * InputWidth - PaddingLeftX), filter, + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* KernelToUse = Kernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + if (OutputWidth >= 4) { + KernelToUse = KernelFast; + } +#endif + KernelToUse(Input + BlockSize * (ih * InputWidth - PaddingLeftX), filter, Output, StrideWidthBytes, DilationWidthBytes, InputStrideBytes, EffectiveKernelHeight, KernelWidth, Input + BlockSize * (ih * InputWidth), InputWidthBytes, DilatedInputWidthBytes, OutputCountLeftPadX,