From d216d024452e3e5ca8899a6b2e55d0048de24a3c Mon Sep 17 00:00:00 2001 From: milpuz01 Date: Wed, 21 Jan 2026 13:19:09 +0000 Subject: [PATCH 01/23] mlas/arm64: add NEON conv asm kernels and tune NCHWC kernel selection Signed-off-by: Milos Puzovic --- cmake/onnxruntime_mlas.cmake | 6 + .../lib/aarch64/SconvDepthwiseKernelNeon.S | 238 +++++ .../core/mlas/lib/aarch64/SconvKernelNeon.S | 835 ++++++++++++++++++ .../lib/aarch64/SconvPointwiseKernelNeon.S | 441 +++++++++ onnxruntime/core/mlas/lib/mlasi.h | 6 + onnxruntime/core/mlas/lib/platform.cpp | 14 +- .../core/mlas/lib/sbconv_kernel_neon.cpp | 2 +- onnxruntime/core/mlas/lib/snchwc.cpp | 23 +- 8 files changed, 1561 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S create mode 100644 onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index d7dcde945e6d7..be8b8eb90c031 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -327,6 +327,10 @@ function (setup_arm_neon_nchwc) ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp ${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp + # Hand written AArch64 micro-kernel for NCHW convolution. Using a + # separate assembly file allows tighter control over register allocation + # and avoids the overhead of C++/intrinsics based code generation. + ${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S ) list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) @@ -460,6 +464,8 @@ else() ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S new file mode 100644 index 0000000000000..91fc44decfb10 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S @@ -0,0 +1,238 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + + SconvDepthwiseFloatKernelNeon.S + +Abstract: + + Optimised AArch64 assembly implementation of the depthwise convolution + micro-kernel used by the NCHWc single precision path. + + This kernel performs the following optimisations: + * Produce a fast path for interior output positions where all input + accesses are guaranteed to be in-bounds and can be loaded with a pair + of 128-bit loads. + * When an output position touches padding, only the affected 4-wide + lanes are checked individually and loaded; others are zeroed. This + mirrors the behaviour of the C++ helper LoadInputVectorWithBounds. + * Keep the multiply/accumulate operations tightly scheduled to hide the + load latency. + + The kernel computes a single output position for a 16 channel block and is + repeatedly invoked by the high level dispatch code. + +--*/ + +#include "asmmacro.h" + + .text + +// Offsets for stack based parameters. AArch64 passes the first eight +// arguments in registers (x0-x7). 
The remaining parameters are read from the +// stack directly. The layout is defined by the C compiler so use constant +// offsets here. + + .equ .Ldw_InputBase, 0 + .equ .Ldw_InputWidth, 8 + .equ .Ldw_DilatedInputWidth, 16 + .equ .Ldw_OutputCountLeftPad, 24 + .equ .Ldw_OutputCount, 32 + .equ .Ldw_OutputCountRightPad, 40 + .equ .Ldw_Bias, 48 + .equ .Ldw_Flags, 56 + +// Prototype +// +// void +// MlasConvDepthwiseFloatKernelNeonAsm( +// const float* Input, // x0 +// const float* Filter, // x1 +// float* Output, // x2 +// size_t StrideWidth, // x3 (bytes) +// size_t DilationWidth, // x4 (bytes) +// size_t InputStride, // x5 (unused) +// size_t KernelHeight, // x6 +// size_t KernelWidth, // x7 +// const float* InputBase, // [sp + 0] +// size_t InputWidth, // [sp + 8] (bytes) +// size_t DilatedInputWidth, // [sp + 16] (bytes) +// size_t OutputCountLeftPad, // [sp + 24] +// size_t OutputCount, // [sp + 32] +// size_t OutputCountRightPad, // [sp + 40] +// const float* Bias, // [sp + 48] +// unsigned KernelFlags); // [sp + 56] +// + + FUNCTION_ENTRY MlasConvDepthwiseFloatKernelNeonAsm + + // Load the stack parameters used in the hot loops. + ldr x8, [sp,#.Ldw_InputBase] // base of valid input row + ldr x9, [sp,#.Ldw_InputWidth] // row width in bytes + ldr x10,[sp,#.Ldw_DilatedInputWidth] // stride between rows + ldr x11,[sp,#.Ldw_OutputCountLeftPad] + ldr x12,[sp,#.Ldw_OutputCount] + ldr x13,[sp,#.Ldw_OutputCountRightPad] + ldr x14,[sp,#.Ldw_Bias] + ldr w15,[sp,#.Ldw_Flags] + + // Preserve callee-saved registers used by this routine. + stp x29,x30,[sp,#-16]! + stp x27,x28,[sp,#-16]! + stp x25,x26,[sp,#-16]! + stp x23,x24,[sp,#-16]! + stp x21,x22,[sp,#-16]! + stp x19,x20,[sp,#-16]! + stp d12,d13,[sp,#-16]! + stp d14,d15,[sp,#-16]! + + // Compute total number of output elements to produce. + add x16,x11,x12 + add x16,x16,x13 + + // Load bias vectors when required; otherwise all zeros are used to + // initialise the accumulators. + eor v20.16b, v20.16b, v20.16b + eor v21.16b, v21.16b, v21.16b + eor v22.16b, v22.16b, v22.16b + eor v23.16b, v23.16b, v23.16b + tbz w15,#1,1f // no bias addition + ldp q20,q21,[x14],#32 + ldp q22,q23,[x14] +1: + // Constant zero used by ReLU handling. + eor v24.16b, v24.16b, v24.16b + + mov x17,#0 // output index + +// --------------------------------------------------------------------------- +// Loop over output elements. Each iteration computes one output position for +// 16 channels. +// --------------------------------------------------------------------------- +.Ldw_OutputLoop: + // Start accumulators from bias or zeros. + mov v0.16b, v20.16b + mov v1.16b, v21.16b + mov v2.16b, v22.16b + mov v3.16b, v23.16b + + mov x20,x1 // reset filter pointer + mov x21,#0 // kh = 0 + + // Base pointer for this output index across the input. + madd x19,x17,x3,x0 // Input + out_idx*StrideWidth + +.Ldw_HeightLoop: + // Compute [row_start,row_end) for the current kernel row. + madd x22,x21,x10,x8 // row_start + add x27,x22,x9 // row_end + sub x29,x27,#64 // row_end - 64 (fast path) + add x28,x27,#-16 // row_end - 16 + + // Base address for the first kw element on this row. + madd x26,x21,x10,x19 // input for this row + + mov x25,x7 // kw remaining + +.Ldw_WidthLoop: + // Fast path: the 16-lane load fits completely within the row. + cmp x26,x22 + b.lo .Ldw_SlowPath + cmp x26,x29 + bhi .Ldw_SlowPath + // Load 16 input values for the current position. 
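+    // Four 4-lane vectors make up the 16-float NCHWc channel block.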
+ ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x26] + b .Ldw_DoFma + +.Ldw_SlowPath: + // Zero registers and conditionally load each 4-wide vector when it is + // entirely within bounds. This matches the behaviour of the C++ + // helper LoadInputVectorWithBounds. + eor v16.16b, v16.16b, v16.16b + eor v17.16b, v17.16b, v17.16b + eor v18.16b, v18.16b, v18.16b + eor v19.16b, v19.16b, v19.16b + + mov x23,x26 + cmp x23,x22 + b.lt 2f + cmp x23,x28 + b.hi 2f + ldr q16,[x23] +2: + add x23,x26,#16 + cmp x23,x22 + b.lt 3f + cmp x23,x28 + b.hi 3f + ldr q17,[x23] +3: + add x23,x26,#32 + cmp x23,x22 + b.lt 4f + cmp x23,x28 + b.hi 4f + ldr q18,[x23] +4: + add x23,x26,#48 + cmp x23,x22 + b.lt 5f + cmp x23,x28 + b.hi 5f + ldr q19,[x23] +5: + +.Ldw_DoFma: + // Load filter block and update accumulators. + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x20], #64 + fmla v0.4s, v16.4s, v12.4s + fmla v1.4s, v17.4s, v13.4s + fmla v2.4s, v18.4s, v14.4s + fmla v3.4s, v19.4s, v15.4s + + add x26,x26,x4 // advance to next kw + subs x25,x25,#1 + b.ne .Ldw_WidthLoop + + add x21,x21,#1 + cmp x21,x6 + blt .Ldw_HeightLoop + + // Compute destination pointer for this output element. + add x23,x2,x17,lsl #6 // 16 floats per output + + // Accumulate existing output when requested. + tbz w15,#0,6f + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x23] + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v17.4s + fadd v2.4s, v2.4s, v18.4s + fadd v3.4s, v3.4s, v19.4s +6: + // Optional ReLU activation. + tbz w15,#2,8f + fmax v0.4s, v0.4s, v24.4s + fmax v1.4s, v1.4s, v24.4s + fmax v2.4s, v2.4s, v24.4s + fmax v3.4s, v3.4s, v24.4s +8: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x23] + + add x17,x17,#1 + cmp x17,x16 + blt .Ldw_OutputLoop + + ldp d14,d15,[sp],#16 + ldp d12,d13,[sp],#16 + ldp x19,x20,[sp],#16 + ldp x21,x22,[sp],#16 + ldp x23,x24,[sp],#16 + ldp x25,x26,[sp],#16 + ldp x27,x28,[sp],#16 + ldp x29,x30,[sp],#16 + + ret + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S new file mode 100644 index 0000000000000..63a0e430a1705 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S @@ -0,0 +1,835 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + SconvKernelNeon.S + +Abstract: + Hand written AArch64 vectorised kernel used by the NCHW convolution path. The + kernel computes one output row and processes up to four 16-wide filter + blocks (FilterCount <= 4). The hot interior loop avoids all bounds checks + and the two-output path amortises filter loads when possible. +--*/ + +#include "asmmacro.h" + + // Stack layout for parameters passed on the stack. The first eight + // integer parameters follow the AArch64 calling convention and are in + // x0-x7. Remaining parameters are spilled by the C++ caller. 
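+    // All .KO_ offsets below include the 80 bytes (5*16) of callee-saved
+    // registers pushed by the prologue, so they are valid after the saves.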
+ + .equ .LFrame_SavedRegs, (5*16) + .equ .LFrame_x19_x20, 0 + .equ .LFrame_x21_x22, 16 + .equ .LFrame_x23_x24, 32 + .equ .LFrame_x25_x26, 48 + .equ .LFrame_x27_x28, 64 + + .equ .KO_OutputStride, (0 + .LFrame_SavedRegs) + .equ .KO_KernelHeight, (8 + .LFrame_SavedRegs) + .equ .KO_KernelWidth, (16 + .LFrame_SavedRegs) + .equ .KO_InputBase, (24 + .LFrame_SavedRegs) + .equ .KO_InputWidth, (32 + .LFrame_SavedRegs) + .equ .KO_DilatedInputWidth, (40 + .LFrame_SavedRegs) + .equ .KO_OutputCountLeftPad, (48 + .LFrame_SavedRegs) + .equ .KO_OutputCount, (56 + .LFrame_SavedRegs) + .equ .KO_OutputCountRightPad, (64 + .LFrame_SavedRegs) + .equ .KO_Bias, (72 + .LFrame_SavedRegs) + .equ .KO_Flags, (80 + .LFrame_SavedRegs) + + .text + + // --------------------------------------------------------------------- + // Helper macros. + // --------------------------------------------------------------------- + + .macro CLEAR_ACCUM N + eor v0.16b,v0.16b,v0.16b + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b +.if \N >= 2 + eor v4.16b,v4.16b,v4.16b + eor v5.16b,v5.16b,v5.16b + eor v6.16b,v6.16b,v6.16b + eor v7.16b,v7.16b,v7.16b +.endif +.if \N >= 3 + eor v8.16b,v8.16b,v8.16b + eor v9.16b,v9.16b,v9.16b + eor v10.16b,v10.16b,v10.16b + eor v11.16b,v11.16b,v11.16b +.endif +.if \N >= 4 + eor v12.16b,v12.16b,v12.16b + eor v13.16b,v13.16b,v13.16b + eor v14.16b,v14.16b,v14.16b + eor v15.16b,v15.16b,v15.16b +.endif + .endm + + // Post processing for a single output element working on N filter + // blocks. x27 points at output base for set0, x8 holds OutputStride + // and x17 contains the Bias pointer for set0. The zero vector is in + // v31 and KernelFlags in w6. + .macro POSTPROCESS_STORE N + // accumulate from existing output if requested + tbz w6,#0,1f + ldr q16,[x27] + ldr q17,[x27,#16] + ldr q18,[x27,#32] + ldr q19,[x27,#48] + fadd v0.4s,v0.4s,v16.4s + fadd v1.4s,v1.4s,v17.4s + fadd v2.4s,v2.4s,v18.4s + fadd v3.4s,v3.4s,v19.4s +.if \N >= 2 + add x24,x27,x8 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v4.4s,v4.4s,v16.4s + fadd v5.4s,v5.4s,v17.4s + fadd v6.4s,v6.4s,v18.4s + fadd v7.4s,v7.4s,v19.4s +.endif +.if \N >= 3 + add x24,x27,x8,lsl #1 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v8.4s,v8.4s,v16.4s + fadd v9.4s,v9.4s,v17.4s + fadd v10.4s,v10.4s,v18.4s + fadd v11.4s,v11.4s,v19.4s +.endif +.if \N >= 4 + add x24,x27,x8 + add x24,x24,x8,lsl #1 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v12.4s,v12.4s,v16.4s + fadd v13.4s,v13.4s,v17.4s + fadd v14.4s,v14.4s,v18.4s + fadd v15.4s,v15.4s,v19.4s +.endif +1: + // add bias + tbz w6,#1,2f + ldr q16,[x17] + ldr q17,[x17,#16] + ldr q18,[x17,#32] + ldr q19,[x17,#48] + fadd v0.4s,v0.4s,v16.4s + fadd v1.4s,v1.4s,v17.4s + fadd v2.4s,v2.4s,v18.4s + fadd v3.4s,v3.4s,v19.4s +.if \N >= 2 + add x24,x17,#64 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v4.4s,v4.4s,v16.4s + fadd v5.4s,v5.4s,v17.4s + fadd v6.4s,v6.4s,v18.4s + fadd v7.4s,v7.4s,v19.4s +.endif +.if \N >= 3 + add x24,x17,#128 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v8.4s,v8.4s,v16.4s + fadd v9.4s,v9.4s,v17.4s + fadd v10.4s,v10.4s,v18.4s + fadd v11.4s,v11.4s,v19.4s +.endif +.if \N >= 4 + add x24,x17,#192 + ldr q16,[x24] + ldr q17,[x24,#16] + ldr q18,[x24,#32] + ldr q19,[x24,#48] + fadd v12.4s,v12.4s,v16.4s + fadd v13.4s,v13.4s,v17.4s + fadd v14.4s,v14.4s,v18.4s + fadd v15.4s,v15.4s,v19.4s +.endif +2: + // 
optional ReLU + tbz w6,#2,3f + fmax v0.4s,v0.4s,v31.4s + fmax v1.4s,v1.4s,v31.4s + fmax v2.4s,v2.4s,v31.4s + fmax v3.4s,v3.4s,v31.4s +.if \N >= 2 + fmax v4.4s,v4.4s,v31.4s + fmax v5.4s,v5.4s,v31.4s + fmax v6.4s,v6.4s,v31.4s + fmax v7.4s,v7.4s,v31.4s +.endif +.if \N >= 3 + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s +.endif +.if \N >= 4 + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s +.endif +3: + // store results + str q0,[x27] + str q1,[x27,#16] + str q2,[x27,#32] + str q3,[x27,#48] +.if \N >= 2 + add x24,x27,x8 + str q4,[x24] + str q5,[x24,#16] + str q6,[x24,#32] + str q7,[x24,#48] +.endif +.if \N >= 3 + add x24,x27,x8,lsl #1 + str q8,[x24] + str q9,[x24,#16] + str q10,[x24,#32] + str q11,[x24,#48] +.endif +.if \N >= 4 + add x24,x27,x8 + add x24,x24,x8,lsl #1 + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] +.endif + .endm + + // Two-output helpers -------------------------------------------------- + .macro CLEAR_ACCUM2 N + eor v0.16b,v0.16b,v0.16b + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + eor v4.16b,v4.16b,v4.16b + eor v5.16b,v5.16b,v5.16b + eor v6.16b,v6.16b,v6.16b + eor v7.16b,v7.16b,v7.16b +.if \N >= 2 + eor v8.16b,v8.16b,v8.16b + eor v9.16b,v9.16b,v9.16b + eor v10.16b,v10.16b,v10.16b + eor v11.16b,v11.16b,v11.16b + eor v12.16b,v12.16b,v12.16b + eor v13.16b,v13.16b,v13.16b + eor v14.16b,v14.16b,v14.16b + eor v15.16b,v15.16b,v15.16b +.endif +.if \N >= 3 + eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor v19.16b,v19.16b,v19.16b + eor v20.16b,v20.16b,v20.16b + eor v21.16b,v21.16b,v21.16b + eor v22.16b,v22.16b,v22.16b + eor v23.16b,v23.16b,v23.16b +.endif + .endm + + // Post process helpers for the two-output interior kernel. These are + // written out-of-line as they are sizable and used in several places. 
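+    // POSTPROCESS_STORE2_1 handles filter set 0 only; the _2 and _3 variants
+    // invoke it first and then post-process the remaining sets in turn.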
+ .macro POSTPROCESS_STORE2_1 + add x24,x27,#64 + // accumulate + tbz w6,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v0.4s,v0.4s,v26.4s + fadd v1.4s,v1.4s,v27.4s + fadd v2.4s,v2.4s,v28.4s + fadd v3.4s,v3.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v4.4s,v4.4s,v26.4s + fadd v5.4s,v5.4s,v27.4s + fadd v6.4s,v6.4s,v28.4s + fadd v7.4s,v7.4s,v29.4s +1: + // bias + tbz w6,#1,2f + ldr q26,[x17] + ldr q27,[x17,#16] + ldr q28,[x17,#32] + ldr q29,[x17,#48] + fadd v0.4s,v0.4s,v26.4s + fadd v1.4s,v1.4s,v27.4s + fadd v2.4s,v2.4s,v28.4s + fadd v3.4s,v3.4s,v29.4s + fadd v4.4s,v4.4s,v26.4s + fadd v5.4s,v5.4s,v27.4s + fadd v6.4s,v6.4s,v28.4s + fadd v7.4s,v7.4s,v29.4s +2: + tbz w6,#2,3f + fmax v0.4s,v0.4s,v31.4s + fmax v1.4s,v1.4s,v31.4s + fmax v2.4s,v2.4s,v31.4s + fmax v3.4s,v3.4s,v31.4s + fmax v4.4s,v4.4s,v31.4s + fmax v5.4s,v5.4s,v31.4s + fmax v6.4s,v6.4s,v31.4s + fmax v7.4s,v7.4s,v31.4s +3: + str q0,[x27] + str q1,[x27,#16] + str q2,[x27,#32] + str q3,[x27,#48] + str q4,[x24] + str q5,[x24,#16] + str q6,[x24,#32] + str q7,[x24,#48] + .endm + + .macro POSTPROCESS_STORE2_2 + POSTPROCESS_STORE2_1 + add x24,x27,x8 + add x24,x24,#64 + add x27,x27,x8 + tbz w6,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +1: + tbz w6,#1,2f + add x23,x17,#64 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +2: + tbz w6,#2,3f + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s +3: + str q8,[x27] + str q9,[x27,#16] + str q10,[x27,#32] + str q11,[x27,#48] + str q12,[x24] + str q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] + sub x27,x27,x8 + .endm + + .macro POSTPROCESS_STORE2_3 + POSTPROCESS_STORE2_1 + add x24,x27,x8 + add x24,x24,#64 + add x27,x27,x8 + tbz w6,#0,1f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +1: + tbz w6,#1,2f + add x23,x17,#64 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v8.4s,v8.4s,v26.4s + fadd v9.4s,v9.4s,v27.4s + fadd v10.4s,v10.4s,v28.4s + fadd v11.4s,v11.4s,v29.4s + fadd v12.4s,v12.4s,v26.4s + fadd v13.4s,v13.4s,v27.4s + fadd v14.4s,v14.4s,v28.4s + fadd v15.4s,v15.4s,v29.4s +2: + tbz w6,#2,3f + fmax v8.4s,v8.4s,v31.4s + fmax v9.4s,v9.4s,v31.4s + fmax v10.4s,v10.4s,v31.4s + fmax v11.4s,v11.4s,v31.4s + fmax v12.4s,v12.4s,v31.4s + fmax v13.4s,v13.4s,v31.4s + fmax v14.4s,v14.4s,v31.4s + fmax v15.4s,v15.4s,v31.4s +3: + str q8,[x27] + str q9,[x27,#16] + str q10,[x27,#32] + str q11,[x27,#48] + str q12,[x24] + str 
q13,[x24,#16] + str q14,[x24,#32] + str q15,[x24,#48] + add x27,x27,x8 + add x27,x27,x8 // set2 base + add x24,x27,#64 + tbz w6,#0,4f + ldr q26,[x27] + ldr q27,[x27,#16] + ldr q28,[x27,#32] + ldr q29,[x27,#48] + fadd v16.4s,v16.4s,v26.4s + fadd v17.4s,v17.4s,v27.4s + fadd v18.4s,v18.4s,v28.4s + fadd v19.4s,v19.4s,v29.4s + ldr q26,[x24] + ldr q27,[x24,#16] + ldr q28,[x24,#32] + ldr q29,[x24,#48] + fadd v20.4s,v20.4s,v26.4s + fadd v21.4s,v21.4s,v27.4s + fadd v22.4s,v22.4s,v28.4s + fadd v23.4s,v23.4s,v29.4s +4: + tbz w6,#1,5f + add x23,x17,#128 + ldr q26,[x23] + ldr q27,[x23,#16] + ldr q28,[x23,#32] + ldr q29,[x23,#48] + fadd v16.4s,v16.4s,v26.4s + fadd v17.4s,v17.4s,v27.4s + fadd v18.4s,v18.4s,v28.4s + fadd v19.4s,v19.4s,v29.4s + fadd v20.4s,v20.4s,v26.4s + fadd v21.4s,v21.4s,v27.4s + fadd v22.4s,v22.4s,v28.4s + fadd v23.4s,v23.4s,v29.4s +5: + tbz w6,#2,6f + fmax v16.4s,v16.4s,v31.4s + fmax v17.4s,v17.4s,v31.4s + fmax v18.4s,v18.4s,v31.4s + fmax v19.4s,v19.4s,v31.4s + fmax v20.4s,v20.4s,v31.4s + fmax v21.4s,v21.4s,v31.4s + fmax v22.4s,v22.4s,v31.4s + fmax v23.4s,v23.4s,v31.4s +6: + str q16,[x27] + str q17,[x27,#16] + str q18,[x27,#32] + str q19,[x27,#48] + str q20,[x24] + str q21,[x24,#16] + str q22,[x24,#32] + str q23,[x24,#48] + sub x27,x27,x8 + sub x27,x27,x8 + .endm + + // 1-output compute with padding checks --------------------------------- + .macro CONV_PAD N + CLEAR_ACCUM \N + mov x28,x1 +.if \N >= 2 + add x19,x1,x7 +.endif +.if \N >= 3 + add x20,x19,x7 +.endif +.if \N >= 4 + add x26,x20,x7 +.endif + mov x21,x0 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,5f + ldr x10,[sp,#.KO_KernelWidth] + ldr x22,[sp,#.KO_InputBase] + ldr x12,[sp,#.KO_InputWidth] + add x23,x22,x12 +1: mov x25,x21 + mov x11,x10 +2: cmp x25,x22 + blo 3f + cmp x25,x23 + bhs 3f + ld1r {v24.4s},[x25] + b 4f +3: eor v24.16b,v24.16b,v24.16b +4: // Small lookahead to keep the miss machinery busy without polluting + // cache for tiny kernels. 
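+    // 192 bytes ahead is three of the 64-byte filter blocks consumed per step.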
+ prfm pldl1keep,[x28,#192] + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x28],#64 + fmla v0.4s,v26.4s,v24.4s + fmla v1.4s,v27.4s,v24.4s + fmla v2.4s,v28.4s,v24.4s + fmla v3.4s,v29.4s,v24.4s +.if \N >= 2 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x19],#64 + fmla v4.4s,v26.4s,v24.4s + fmla v5.4s,v27.4s,v24.4s + fmla v6.4s,v28.4s,v24.4s + fmla v7.4s,v29.4s,v24.4s +.endif +.if \N >= 3 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + fmla v8.4s,v26.4s,v24.4s + fmla v9.4s,v27.4s,v24.4s + fmla v10.4s,v28.4s,v24.4s + fmla v11.4s,v29.4s,v24.4s +.endif +.if \N >= 4 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x26],#64 + fmla v12.4s,v26.4s,v24.4s + fmla v13.4s,v27.4s,v24.4s + fmla v14.4s,v28.4s,v24.4s + fmla v15.4s,v29.4s,v24.4s +.endif + add x25,x25,x4 + subs x11,x11,#1 + b.ne 2b + ldr x13,[sp,#.KO_DilatedInputWidth] + add x21,x21,x13 + add x22,x22,x13 + add x23,x23,x13 + subs x9,x9,#1 + b.ne 1b +5: POSTPROCESS_STORE \N + .endm + + // 1-output compute without padding checks ------------------------------ + .macro CONV_NOPAD N + CLEAR_ACCUM \N + mov x28,x1 +.if \N >= 2 + add x19,x1,x7 +.endif +.if \N >= 3 + add x20,x19,x7 +.endif +.if \N >= 4 + add x26,x20,x7 +.endif + mov x21,x0 + ldr x9,[sp,#.KO_KernelHeight] + cbz x9,3f + ldr x10,[sp,#.KO_KernelWidth] +1: mov x25,x21 + mov x11,x10 +2: ld1r {v24.4s},[x25] + prfm pldl1keep,[x28,#192] + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x28],#64 + fmla v0.4s,v26.4s,v24.4s + fmla v1.4s,v27.4s,v24.4s + fmla v2.4s,v28.4s,v24.4s + fmla v3.4s,v29.4s,v24.4s +.if \N >= 2 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x19],#64 + fmla v4.4s,v26.4s,v24.4s + fmla v5.4s,v27.4s,v24.4s + fmla v6.4s,v28.4s,v24.4s + fmla v7.4s,v29.4s,v24.4s +.endif +.if \N >= 3 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + fmla v8.4s,v26.4s,v24.4s + fmla v9.4s,v27.4s,v24.4s + fmla v10.4s,v28.4s,v24.4s + fmla v11.4s,v29.4s,v24.4s +.endif +.if \N >= 4 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x26],#64 + fmla v12.4s,v26.4s,v24.4s + fmla v13.4s,v27.4s,v24.4s + fmla v14.4s,v28.4s,v24.4s + fmla v15.4s,v29.4s,v24.4s +.endif + add x25,x25,x4 + subs x11,x11,#1 + b.ne 2b + ldr x13,[sp,#.KO_DilatedInputWidth] + add x21,x21,x13 + subs x9,x9,#1 + b.ne 1b +3: POSTPROCESS_STORE \N + .endm + + // Two-output compute core when no padding is required. x25/x26 are + // the two input pointers, x27 the output base. x28, x19 and x20 are + // filter pointers for each set. v24/v25 hold broadcast inputs. 
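+    // Accumulators: v0-v7 hold both outputs for set 0, v8-v15 for set 1
+    // and v16-v23 for set 2.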
+ .macro CONV2_NOPAD N + CLEAR_ACCUM2 \N + mov x28,x1 +.if \N >= 2 + add x19,x1,x7 +.endif +.if \N >= 3 + add x20,x19,x7 +.endif + ldr x11,[sp,#.KO_KernelWidth] + ldr x9,[sp,#.KO_KernelHeight] +1: mov x10,x11 +2: prfm pldl1keep,[x28,#192] + ld1r {v24.4s},[x25] + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x28],#64 + fmla v0.4s,v26.4s,v24.4s + fmla v1.4s,v27.4s,v24.4s + fmla v2.4s,v28.4s,v24.4s + fmla v3.4s,v29.4s,v24.4s + ld1r {v25.4s},[x26] + fmla v4.4s,v26.4s,v25.4s + fmla v5.4s,v27.4s,v25.4s + fmla v6.4s,v28.4s,v25.4s + fmla v7.4s,v29.4s,v25.4s +.if \N >= 2 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x19],#64 + fmla v8.4s,v26.4s,v24.4s + fmla v9.4s,v27.4s,v24.4s + fmla v10.4s,v28.4s,v24.4s + fmla v11.4s,v29.4s,v24.4s + fmla v12.4s,v26.4s,v25.4s + fmla v13.4s,v27.4s,v25.4s + fmla v14.4s,v28.4s,v25.4s + fmla v15.4s,v29.4s,v25.4s +.endif +.if \N >= 3 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + fmla v16.4s,v26.4s,v24.4s + fmla v17.4s,v27.4s,v24.4s + fmla v18.4s,v28.4s,v24.4s + fmla v19.4s,v29.4s,v24.4s + fmla v20.4s,v26.4s,v25.4s + fmla v21.4s,v27.4s,v25.4s + fmla v22.4s,v28.4s,v25.4s + fmla v23.4s,v29.4s,v25.4s +.endif + add x25,x25,x4 + add x26,x26,x4 + subs x10,x10,#1 + b.ne 2b + ldr x12,[sp,#.KO_DilatedInputWidth] + add x25,x25,x12 + add x26,x26,x12 + subs x9,x9,#1 + b.ne 1b + .endm + +// ----------------------------------------------------------------------------- +// void MlasConvNchwFloatKernelNeonAsm(...) +// Implements the convolution micro-kernel for the NCHW input format. +// ----------------------------------------------------------------------------- + + FUNCTION_ENTRY MlasConvNchwFloatKernelNeonAsm + + // prologue ------------------------------------------------------------- + stp x19,x20,[sp,#-.LFrame_SavedRegs]! + stp x21,x22,[sp,#.LFrame_x21_x22] + stp x23,x24,[sp,#.LFrame_x23_x24] + stp x25,x26,[sp,#.LFrame_x25_x26] + stp x27,x28,[sp,#.LFrame_x27_x28] + + // frequently used parameters + ldr x8,[sp,#.KO_OutputStride] + ldr x14,[sp,#.KO_OutputCountLeftPad] + ldr x15,[sp,#.KO_OutputCount] + ldr x16,[sp,#.KO_OutputCountRightPad] + ldr x17,[sp,#.KO_Bias] + ldr w6,[sp,#.KO_Flags] // kernel flags + + eor v31.16b,v31.16b,v31.16b + + // left padded region -------------------------------------------------- + cbz x14,.LInterior + mov x20,#0 +.LLeftLoop: + madd x21,x20,x3,x0 // input pointer for this output + lsl x22,x20,#6 // (index*64) + add x27,x2,x22 + mov x0,x21 + cmp x5,#1 + b.eq .LLeftFC1 + cmp x5,#2 + b.eq .LLeftFC2 + cmp x5,#3 + b.eq .LLeftFC3 +.LLeftFC4: + CONV_PAD 4 + b .LLeftDoneOne +.LLeftFC3: + CONV_PAD 3 + b .LLeftDoneOne +.LLeftFC2: + CONV_PAD 2 + b .LLeftDoneOne +.LLeftFC1: + CONV_PAD 1 +.LLeftDoneOne: + add x20,x20,#1 + cmp x20,x14 + blo .LLeftLoop + + // interior region ----------------------------------------------------- +.LInterior: + cbz x15,.LRight + + mul x24,x14,x3 // input byte offset after left pad + add x21,x0,x24 // base input for interior + lsl x23,x14,#6 // output offset (bytes) + + mov x20,#0 +.LIntLoop: + sub x12,x15,x20 + cmp x12,#2 + b.lt .LIntProcessOne + // two-output path + madd x25,x20,x3,x21 // input for first output + add x26,x25,x3 // input for second output + add x22,x14,x20 // output index + lsl x22,x22,#6 // output offset (bytes) + add x27,x2,x22 + cmp x5,#1 + b.eq .LInt2FC1 + cmp x5,#2 + b.eq .LInt2FC2 + cmp x5,#3 + b.eq .LInt2FC3 + // FilterCount==4 falls back to single-output implementation due to + // register pressure. 
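+    // (4 sets x 2 outputs x 4 accumulators would consume all 32 NEON
+    // registers, leaving none for the filter and broadcast loads.)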
+ b .LIntProcessOne +.LInt2FC3: + CONV2_NOPAD 3 + POSTPROCESS_STORE2_3 + b .LIntNext2 +.LInt2FC2: + CONV2_NOPAD 2 + POSTPROCESS_STORE2_2 + b .LIntNext2 +.LInt2FC1: + CONV2_NOPAD 1 + POSTPROCESS_STORE2_1 +.LIntNext2: + add x20,x20,#2 + cmp x20,x15 + blo .LIntLoop + b .LRight + +.LIntProcessOne: + madd x25,x20,x3,x21 + add x22,x14,x20 // output index + lsl x22,x22,#6 // output offset (bytes) + add x27,x2,x22 + mov x0,x25 + cmp x5,#1 + b.eq .LIntFC1 + cmp x5,#2 + b.eq .LIntFC2 + cmp x5,#3 + b.eq .LIntFC3 + CONV_NOPAD 4 + b .LIntNext1 +.LIntFC3: + CONV_NOPAD 3 + b .LIntNext1 +.LIntFC2: + CONV_NOPAD 2 + b .LIntNext1 +.LIntFC1: + CONV_NOPAD 1 +.LIntNext1: + add x20,x20,#1 + cmp x20,x15 + blo .LIntLoop + + // right padded region ------------------------------------------------- +.LRight: + cbz x16,.LDone + mov x20,#0 +.LRightLoop: + add x24,x20,x14 + add x24,x24,x15 + madd x21,x24,x3,x0 + lsl x22,x24,#6 + add x27,x2,x22 + mov x0,x21 + cmp x5,#1 + b.eq .LRFC1 + cmp x5,#2 + b.eq .LRFC2 + cmp x5,#3 + b.eq .LRFC3 + CONV_PAD 4 + b .LRightNext +.LRFC3: + CONV_PAD 3 + b .LRightNext +.LRFC2: + CONV_PAD 2 + b .LRightNext +.LRFC1: + CONV_PAD 1 +.LRightNext: + add x20,x20,#1 + cmp x20,x16 + blo .LRightLoop + +.LDone: + ldp x27,x28,[sp,#.LFrame_x27_x28] + ldp x25,x26,[sp,#.LFrame_x25_x26] + ldp x23,x24,[sp,#.LFrame_x23_x24] + ldp x21,x22,[sp,#.LFrame_x21_x22] + ldp x19,x20,[sp],#.LFrame_SavedRegs + ret + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S new file mode 100644 index 0000000000000..e3fd621b4422b --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S @@ -0,0 +1,441 @@ +/*++ +SPDX-FileCopyrightText: Copyright 2026 Arm Limited and/or its affiliates +SPDX-License-Identifier: MIT + +Module Name: + + SconvPointwiseKernelNeon.S + +Abstract: + + A hand written AArch64 vectorised micro-kernel for pointwise (1x1) convolution + operating on tensors formatted in the NCHWc layout. The kernel computes + up to four output positions in parallel which allows the filter weights to + be re-used across several outputs, greatly reducing memory bandwidth. + +--*/ + +#include "asmmacro.h" + + .text + +// Stack layout for arguments passed on the stack. The first eight arguments +// are in x0-x7, the remaining four are placed on the stack by the caller. + + .equ .LPW_OutputStride, 0 + .equ .LPW_OutputCount, 8 + .equ .LPW_Bias, 16 + .equ .LPW_Flags, 24 + +// Kernel flag bits. Keep these in sync with sconv.h. + .equ .LPWFlag_Accumulate, 1 + .equ .LPWFlag_Bias, 2 + .equ .LPWFlag_Relu, 4 + +// Size in bytes of one NCHWc block (16 FP32 values). + .equ .LPW_BlockBytes, 64 + +//------------------------------------------------------------------------- +// Helper macros +//------------------------------------------------------------------------- + +// Compute four outputs for a single input channel block. The accumulators +// for the four outputs are held in v16-v31. 
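+    // Each step consumes one 64-byte filter block (v0-v3) and one broadcast
+    // scalar per output (v4-v7), then issues sixteen fused multiply-adds.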
+ .macro CPK4_FmlaStep + ldp q0,q1,[x24],#32 + ldp q2,q3,[x24],#32 + ld1r {v4.4s},[x25],#4 + ld1r {v5.4s},[x26],#4 + ld1r {v6.4s},[x27],#4 + ld1r {v7.4s},[x28],#4 + fmla v16.4s,v0.4s,v4.4s + fmla v17.4s,v1.4s,v4.4s + fmla v18.4s,v2.4s,v4.4s + fmla v19.4s,v3.4s,v4.4s + fmla v20.4s,v0.4s,v5.4s + fmla v21.4s,v1.4s,v5.4s + fmla v22.4s,v2.4s,v5.4s + fmla v23.4s,v3.4s,v5.4s + fmla v24.4s,v0.4s,v6.4s + fmla v25.4s,v1.4s,v6.4s + fmla v26.4s,v2.4s,v6.4s + fmla v27.4s,v3.4s,v6.4s + fmla v28.4s,v0.4s,v7.4s + fmla v29.4s,v1.4s,v7.4s + fmla v30.4s,v2.4s,v7.4s + fmla v31.4s,v3.4s,v7.4s + .endm + +// Accumulate helper for the single output path. The input values for the +// output are loaded to v0-v3 and each lane is multiplied with a block of +// filter coefficients. + .macro CPK1_FmlaWithLane lane, AReg + ldp q4,q5,[x24],#32 + ldp q6,q7,[x24],#32 + fmla v16.4s,v4.4s,\AReg\().s[\lane] + fmla v17.4s,v5.4s,\AReg\().s[\lane] + fmla v18.4s,v6.4s,\AReg\().s[\lane] + fmla v19.4s,v7.4s,\AReg\().s[\lane] + .endm + +// Compute a single output position. Results are returned in v16-v19. + .macro CPK_ComputeOneOutput + mov x22,#0 + eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor v19.16b,v19.16b,v19.16b +.Lpw_ic_loop1: + madd x23,x22,x6,x15 + ldp q0,q1,[x23] + ldp q2,q3,[x23,#32] + add x24,x17,x22,lsl #10 + CPK1_FmlaWithLane 0, v0 + CPK1_FmlaWithLane 1, v0 + CPK1_FmlaWithLane 2, v0 + CPK1_FmlaWithLane 3, v0 + CPK1_FmlaWithLane 0, v1 + CPK1_FmlaWithLane 1, v1 + CPK1_FmlaWithLane 2, v1 + CPK1_FmlaWithLane 3, v1 + CPK1_FmlaWithLane 0, v2 + CPK1_FmlaWithLane 1, v2 + CPK1_FmlaWithLane 2, v2 + CPK1_FmlaWithLane 3, v2 + CPK1_FmlaWithLane 0, v3 + CPK1_FmlaWithLane 1, v3 + CPK1_FmlaWithLane 2, v3 + CPK1_FmlaWithLane 3, v3 + add x22,x22,#1 + cmp x22,x4 + blt .Lpw_ic_loop1 + .endm + +//------------------------------------------------------------------------- +// Entry point +//------------------------------------------------------------------------- + + FUNCTION_ENTRY MlasConvPointwiseFloatKernelNeonAsm + + // Load the arguments passed on the stack. + ldr x8,[sp,#.LPW_OutputStride] + ldr x9,[sp,#.LPW_OutputCount] + ldr x10,[sp,#.LPW_Bias] + ldr w11,[sp,#.LPW_Flags] + + // Preserve callee saved registers that are used by this routine. + stp x22,x23,[sp,#-16]! + stp x24,x25,[sp,#-16]! + stp x26,x27,[sp,#-16]! + stp x28,x19,[sp,#-16]! + + mov x14,#0 // current filter set + cbz x5,.Lpw_exit // nothing to do + +.Lpw_filter_loop: + // Compute the base pointers for this filter block. + madd x16,x14,x8,x2 // output pointer for this filter + madd x17,x14,x7,x1 // filter pointer for this filter + add x18,x10,x14,lsl #6 // bias pointer (if used) + + mov x15,x0 // input base for this iteration + mov x20,x16 // running output pointer + mov x12,x9 // output count this iteration + + lsr x13,x12,#2 // number of groups of four outputs + cbz x13,.Lpw_process_remainder + +// ------------------------------------------------------------------ +// Main loop processing 4 outputs at a time. +// ------------------------------------------------------------------ +.Lpw_groups: + // Clear accumulators for 4 outputs (16 vectors total). 
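+    // (v16-v31; the temporaries v0-v7 are reloaded inside CPK4_FmlaStep.)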
+ eor v16.16b,v16.16b,v16.16b + eor v17.16b,v17.16b,v17.16b + eor v18.16b,v18.16b,v18.16b + eor v19.16b,v19.16b,v19.16b + eor v20.16b,v20.16b,v20.16b + eor v21.16b,v21.16b,v21.16b + eor v22.16b,v22.16b,v22.16b + eor v23.16b,v23.16b,v23.16b + eor v24.16b,v24.16b,v24.16b + eor v25.16b,v25.16b,v25.16b + eor v26.16b,v26.16b,v26.16b + eor v27.16b,v27.16b,v27.16b + eor v28.16b,v28.16b,v28.16b + eor v29.16b,v29.16b,v29.16b + eor v30.16b,v30.16b,v30.16b + eor v31.16b,v31.16b,v31.16b + + mov x22,#0 // current input channel block +.Lpw_ic_loop4: + madd x23,x22,x6,x15 // input for this block + mov x25,x23 // four rows starting positions + add x26,x23,x3 + add x27,x26,x3 + add x28,x27,x3 + add x24,x17,x22,lsl #10 // filter for this block + + // The block size is 16 so unroll 16 steps. + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + CPK4_FmlaStep + + add x22,x22,#1 + cmp x22,x4 + blt .Lpw_ic_loop4 + + // ----------------------------------------------------------------- + // Store the four outputs computed above. There are several cases to + // handle based on accumulation, bias and ReLU flags. + // ----------------------------------------------------------------- + + // Test if the kernel should accumulate into the existing output. + tbz w11,#0,.Lpw_store_nacc + + // Accumulation path. Load bias once as it is re-used for all four + // stores when present. + tbz w11,#1,1f + ldp q4,q5,[x18] + ldp q6,q7,[x18,#32] +1: + // ---- output 0 ---- + ldp q0,q1,[x20] + ldp q2,q3,[x20,#32] + tbz w11,#1,2f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +2: + fadd v16.4s,v16.4s,v0.4s + fadd v17.4s,v17.4s,v1.4s + fadd v18.4s,v18.4s,v2.4s + fadd v19.4s,v19.4s,v3.4s + tbz w11,#2,3f + eor v0.16b,v0.16b,v0.16b + fmax v16.4s,v16.4s,v0.4s + fmax v17.4s,v17.4s,v0.4s + fmax v18.4s,v18.4s,v0.4s + fmax v19.4s,v19.4s,v0.4s +3: + stp q16,q17,[x20] + stp q18,q19,[x20,#32] + + // ---- output 1 ---- + add x22,x20,#.LPW_BlockBytes + ldp q0,q1,[x22] + ldp q2,q3,[x22,#32] + tbz w11,#1,4f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +4: + fadd v20.4s,v20.4s,v0.4s + fadd v21.4s,v21.4s,v1.4s + fadd v22.4s,v22.4s,v2.4s + fadd v23.4s,v23.4s,v3.4s + tbz w11,#2,5f + eor v0.16b,v0.16b,v0.16b + fmax v20.4s,v20.4s,v0.4s + fmax v21.4s,v21.4s,v0.4s + fmax v22.4s,v22.4s,v0.4s + fmax v23.4s,v23.4s,v0.4s +5: + stp q20,q21,[x22] + stp q22,q23,[x22,#32] + + // ---- output 2 ---- + add x22,x22,#.LPW_BlockBytes + ldp q0,q1,[x22] + ldp q2,q3,[x22,#32] + tbz w11,#1,6f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +6: + fadd v24.4s,v24.4s,v0.4s + fadd v25.4s,v25.4s,v1.4s + fadd v26.4s,v26.4s,v2.4s + fadd v27.4s,v27.4s,v3.4s + tbz w11,#2,7f + eor v0.16b,v0.16b,v0.16b + fmax v24.4s,v24.4s,v0.4s + fmax v25.4s,v25.4s,v0.4s + fmax v26.4s,v26.4s,v0.4s + fmax v27.4s,v27.4s,v0.4s +7: + stp q24,q25,[x22] + stp q26,q27,[x22,#32] + + // ---- output 3 ---- + add x22,x22,#.LPW_BlockBytes + ldp q0,q1,[x22] + ldp q2,q3,[x22,#32] + tbz w11,#1,8f + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +8: + fadd v28.4s,v28.4s,v0.4s + fadd v29.4s,v29.4s,v1.4s + fadd v30.4s,v30.4s,v2.4s + fadd v31.4s,v31.4s,v3.4s + tbz w11,#2,9f + eor v0.16b,v0.16b,v0.16b + fmax v28.4s,v28.4s,v0.4s + 
fmax v29.4s,v29.4s,v0.4s + fmax v30.4s,v30.4s,v0.4s + fmax v31.4s,v31.4s,v0.4s +9: + stp q28,q29,[x22] + stp q30,q31,[x22,#32] + b .Lpw_advance_group + +// Non-accumulating path: add bias directly to the results if requested +.Lpw_store_nacc: + tbz w11,#1,10f + ldp q4,q5,[x18] + ldp q6,q7,[x18,#32] + fadd v16.4s,v16.4s,v4.4s + fadd v17.4s,v17.4s,v5.4s + fadd v18.4s,v18.4s,v6.4s + fadd v19.4s,v19.4s,v7.4s + fadd v20.4s,v20.4s,v4.4s + fadd v21.4s,v21.4s,v5.4s + fadd v22.4s,v22.4s,v6.4s + fadd v23.4s,v23.4s,v7.4s + fadd v24.4s,v24.4s,v4.4s + fadd v25.4s,v25.4s,v5.4s + fadd v26.4s,v26.4s,v6.4s + fadd v27.4s,v27.4s,v7.4s + fadd v28.4s,v28.4s,v4.4s + fadd v29.4s,v29.4s,v5.4s + fadd v30.4s,v30.4s,v6.4s + fadd v31.4s,v31.4s,v7.4s +10: + tbz w11,#2,11f + eor v0.16b,v0.16b,v0.16b + fmax v16.4s,v16.4s,v0.4s + fmax v17.4s,v17.4s,v0.4s + fmax v18.4s,v18.4s,v0.4s + fmax v19.4s,v19.4s,v0.4s + fmax v20.4s,v20.4s,v0.4s + fmax v21.4s,v21.4s,v0.4s + fmax v22.4s,v22.4s,v0.4s + fmax v23.4s,v23.4s,v0.4s + fmax v24.4s,v24.4s,v0.4s + fmax v25.4s,v25.4s,v0.4s + fmax v26.4s,v26.4s,v0.4s + fmax v27.4s,v27.4s,v0.4s + fmax v28.4s,v28.4s,v0.4s + fmax v29.4s,v29.4s,v0.4s + fmax v30.4s,v30.4s,v0.4s + fmax v31.4s,v31.4s,v0.4s +11: + stp q16,q17,[x20] + stp q18,q19,[x20,#32] + add x22,x20,#.LPW_BlockBytes + stp q20,q21,[x22] + stp q22,q23,[x22,#32] + add x22,x22,#.LPW_BlockBytes + stp q24,q25,[x22] + stp q26,q27,[x22,#32] + add x22,x22,#.LPW_BlockBytes + stp q28,q29,[x22] + stp q30,q31,[x22,#32] + +.Lpw_advance_group: + add x15,x15,x3,lsl #2 + add x20,x20,#(.LPW_BlockBytes*4) + subs x13,x13,#1 + b.ne .Lpw_groups + +// ------------------------------------------------------------------ +// Handle the leftover (0..3) output positions. +// ------------------------------------------------------------------ +.Lpw_process_remainder: + ands x12,x12,#3 + beq .Lpw_after_filter +.Lpw_left_loop: + CPK_ComputeOneOutput + + // Accumulate? 
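+    // Flag bit 0 selects accumulate-into-output, bit 1 adds bias and
+    // bit 2 applies ReLU (see the .LPWFlag_* constants above).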
+ tbz w11,#0,.Lpw_left_noacc + ldp q0,q1,[x20] + ldp q2,q3,[x20,#32] + tbz w11,#1,12f + ldp q4,q5,[x18] + ldp q6,q7,[x18,#32] + fadd v0.4s,v0.4s,v4.4s + fadd v1.4s,v1.4s,v5.4s + fadd v2.4s,v2.4s,v6.4s + fadd v3.4s,v3.4s,v7.4s +12: + fadd v16.4s,v16.4s,v0.4s + fadd v17.4s,v17.4s,v1.4s + fadd v18.4s,v18.4s,v2.4s + fadd v19.4s,v19.4s,v3.4s + tbz w11,#2,13f + eor v0.16b,v0.16b,v0.16b + fmax v16.4s,v16.4s,v0.4s + fmax v17.4s,v17.4s,v0.4s + fmax v18.4s,v18.4s,v0.4s + fmax v19.4s,v19.4s,v0.4s +13: + stp q16,q17,[x20] + stp q18,q19,[x20,#32] + b 14f + +.Lpw_left_noacc: + tbz w11,#1,15f + ldp q4,q5,[x18] + ldp q6,q7,[x18,#32] + fadd v16.4s,v16.4s,v4.4s + fadd v17.4s,v17.4s,v5.4s + fadd v18.4s,v18.4s,v6.4s + fadd v19.4s,v19.4s,v7.4s +15: + tbz w11,#2,16f + eor v0.16b,v0.16b,v0.16b + fmax v16.4s,v16.4s,v0.4s + fmax v17.4s,v17.4s,v0.4s + fmax v18.4s,v18.4s,v0.4s + fmax v19.4s,v19.4s,v0.4s +16: + stp q16,q17,[x20] + stp q18,q19,[x20,#32] +14: + add x15,x15,x3 + add x20,x20,#.LPW_BlockBytes + subs x12,x12,#1 + b.ne .Lpw_left_loop + +.Lpw_after_filter: + add x14,x14,#1 + cmp x14,x5 + blt .Lpw_filter_loop +.Lpw_exit: + ldp x28,x19,[sp],#16 + ldp x26,x27,[sp],#16 + ldp x24,x25,[sp],#16 + ldp x22,x23,[sp],#16 + ret + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index e75ca3dc90e60..16e830f892863 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -963,9 +963,15 @@ extern "C" { MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; #endif #if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + // AArch64 assembly micro-kernel for direct NCHW convolution + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeonAsm; MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; + // AArch64 assembly micro-kernel for depthwise NCHW convolution + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeonAsm; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; + // AArch64 assembly micro-kernel for pointwise NCHW convolution + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeonAsm; MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; #if defined(__aarch64__) && defined(__linux__) MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index b913b1c3b8c26..dcb5c45817961 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -571,7 +571,19 @@ Return Value: this->EltwiseDispatch = &MlasEltwiseDispatchNeon; #if defined(MLAS_USE_ARM_NEON_NCHWC) - this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon; + // Prefer the hand written micro-kernel for the NCHW convolution path. 
It + // offers a tighter schedule and a specialised two-output inner loop that + // reduces pressure on the memory system compared + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeonAsm; +#if defined(_WIN32) + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; +#else + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeonAsm; +#endif + // Prefer the hand written AArch64 micro-kernel for pointwise convolution + // as it computes multiple output positions at once and significantly + // reduces memory traffic + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeonAsm; this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp index 1a9949983c3ee..f41b380b2a071 100644 --- a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -17,7 +17,7 @@ Module Name: #if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__) #include "mlasi.h" -#include "sconv.h" +#include "sconv_nchwc_kernel_neon.h" constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 505246841087c..0c4d95400a8d1 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -882,6 +882,9 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_POINTWISE_FLOAT_KERNEL* const KernelFast = MlasConvPointwiseFloatKernelNeonAsm; +#endif #if defined(__aarch64__) && defined(__linux__) if (WorkBlock->UseBf16) { Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel; @@ -937,7 +940,14 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM // Invoke the convolution kernel. // - Kernel(input, filter, output, StrideWidthBytes, InputChannelBatch / + MLAS_CONV_POINTWISE_FLOAT_KERNEL* KernelToUse = Kernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + if (!WorkBlock->UseBf16 && OutputThisIteration >= 4 && + StrideHeight == 1 && StrideWidth == 1) { + KernelToUse = KernelFast; + } +#endif + KernelToUse(input, filter, output, StrideWidthBytes, InputChannelBatch / BlockSize, FilterCount, InputStrideBytes, FilterStrideBytes, OutputStrideBytes, OutputThisIteration, Bias, KernelFlags); @@ -1024,6 +1034,9 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* const KernelFast = MlasConvDepthwiseFloatKernelNeonAsm; +#endif #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; #endif @@ -1047,7 +1060,13 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM // Invoke the convolution kernel. 
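             // The assembly variant is substituted below for wider output rows,
             // presumably so its bounds-check-free interior path dominates.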
// - Kernel(Input + BlockSize * (ih * InputWidth - PaddingLeftX), filter, + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* KernelToUse = Kernel; +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + if (OutputWidth >= 4) { + KernelToUse = KernelFast; + } +#endif + KernelToUse(Input + BlockSize * (ih * InputWidth - PaddingLeftX), filter, Output, StrideWidthBytes, DilationWidthBytes, InputStrideBytes, EffectiveKernelHeight, KernelWidth, Input + BlockSize * (ih * InputWidth), InputWidthBytes, DilatedInputWidthBytes, OutputCountLeftPadX, From 6c980e30d58d7f60ef394f5e56ec1b15d4aa414d Mon Sep 17 00:00:00 2001 From: Milos Puzovic Date: Mon, 26 Jan 2026 16:27:20 +0000 Subject: [PATCH 02/23] Address comments from the reviewers Signed-off-by: Milos Puzovic --- cmake/onnxruntime_mlas.cmake | 4 +- .../lib/aarch64/SconvDepthwiseKernelNeon.S | 8 ++-- .../core/mlas/lib/aarch64/SconvKernelNeon.S | 9 ++-- .../lib/aarch64/SconvPointwiseKernelNeon.S | 2 +- onnxruntime/core/mlas/lib/mlasi.h | 5 +- onnxruntime/core/mlas/lib/platform.cpp | 10 ++-- .../core/mlas/lib/sconv_nchwc_kernel_neon.cpp | 46 ------------------- 7 files changed, 17 insertions(+), 67 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index be8b8eb90c031..59abea26e4f60 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -331,6 +331,8 @@ function (setup_arm_neon_nchwc) # separate assembly file allows tighter control over register allocation # and avoids the overhead of C++/intrinsics based code generation. ${MLAS_SRC_DIR}/aarch64/SconvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S ) list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) @@ -464,8 +466,6 @@ else() ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S index 91fc44decfb10..6521c50eb40a2 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S @@ -8,7 +8,7 @@ Module Name: Abstract: - Optimised AArch64 assembly implementation of the depthwise convolution + Optimized AArch64 assembly implementation of the depthwise convolution micro-kernel used by the NCHWc single precision path. This kernel performs the following optimisations: @@ -17,7 +17,7 @@ Abstract: of 128-bit loads. * When an output position touches padding, only the affected 4-wide lanes are checked individually and loaded; others are zeroed. This - mirrors the behaviour of the C++ helper LoadInputVectorWithBounds. + mirrors the behavior of the C++ helper LoadInputVectorWithBounds. * Keep the multiply/accumulate operations tightly scheduled to hide the load latency. @@ -93,7 +93,7 @@ Abstract: add x16,x16,x13 // Load bias vectors when required; otherwise all zeros are used to - // initialise the accumulators. + // initialize the accumulators. 
eor v20.16b, v20.16b, v20.16b eor v21.16b, v21.16b, v21.16b eor v22.16b, v22.16b, v22.16b @@ -148,7 +148,7 @@ Abstract: .Ldw_SlowPath: // Zero registers and conditionally load each 4-wide vector when it is - // entirely within bounds. This matches the behaviour of the C++ + // entirely within bounds. This matches the behavior of the C++ // helper LoadInputVectorWithBounds. eor v16.16b, v16.16b, v16.16b eor v17.16b, v17.16b, v17.16b diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S index 63a0e430a1705..643f537834663 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S @@ -6,10 +6,11 @@ Module Name: SconvKernelNeon.S Abstract: - Hand written AArch64 vectorised kernel used by the NCHW convolution path. The - kernel computes one output row and processes up to four 16-wide filter - blocks (FilterCount <= 4). The hot interior loop avoids all bounds checks - and the two-output path amortises filter loads when possible. + Hand written AArch64 vectorised kernel used by the convolution path + (NCHW activations and NCHWc weights). The kernel computes one output + row and processes up to four 16-wide filter blocks (FilterCount <= 4). + The hot interior loop avoids all bounds checks and the two-output path + amortizes filter loads when possible. --*/ #include "asmmacro.h" diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S index e3fd621b4422b..8caae4fc080ac 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S @@ -27,7 +27,7 @@ Abstract: .equ .LPW_Bias, 16 .equ .LPW_Flags, 24 -// Kernel flag bits. Keep these in sync with sconv.h. +// Kernel flag bits. Keep these in sync with sconv_nchwc_kernel_neon.h. 
.equ .LPWFlag_Accumulate, 1 .equ .LPWFlag_Bias, 2 .equ .LPWFlag_Relu, 4 diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 16e830f892863..242cf88fa2d71 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -965,12 +965,11 @@ extern "C" { #if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) // AArch64 assembly micro-kernel for direct NCHW convolution MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeonAsm; - MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; - // AArch64 assembly micro-kernel for depthwise NCHW convolution + // AArch64 assembly micro-kernel for depthwise NCHWc convolution MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeonAsm; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; - // AArch64 assembly micro-kernel for pointwise NCHW convolution + // AArch64 assembly micro-kernel for pointwise NCHWc convolution MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeonAsm; MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; #if defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index dcb5c45817961..88fb59ac47093 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -575,15 +575,11 @@ Return Value: // offers a tighter schedule and a specialised two-output inner loop that // reduces pressure on the memory system compared this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeonAsm; -#if defined(_WIN32) - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; -#else - this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeonAsm; -#endif // Prefer the hand written AArch64 micro-kernel for pointwise convolution // as it computes multiple output positions at once and significantly - // reduces memory traffic - this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeonAsm; + // reduces memory traffic. The AArch64 assembly kernel is picked up by + // heuristics in platform.cpp to avoid regressions on small convolutions. 
+ // So here we set the default to the intrinsics version this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; diff --git a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp index 745258080810a..09ee155d36d3e 100644 --- a/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp @@ -182,52 +182,6 @@ void } } -void - MLASCALL - MlasConvNchwFloatKernelNeon( - const float* Input, - const float* Filter, - float* Output, - size_t StrideWidth, - size_t DilationWidth, - size_t FilterCount, - size_t InputStride, - size_t FilterStride, - size_t OutputStride, - size_t KernelHeight, - size_t KernelWidth, - const float* InputBase, - size_t InputWidth, - size_t DilatedInputWidth, - size_t OutputCountLeftPad, - size_t OutputCount, - size_t OutputCountRightPad, - const float* Bias, - unsigned KernelFlags - ) -{ - MlasConvFloatKernelNeonImpl( - Input, - Filter, - Output, - StrideWidth, - DilationWidth, - FilterCount, - InputStride, - FilterStride, - OutputStride, - KernelHeight, - KernelWidth, - InputBase, - InputWidth, - DilatedInputWidth, - OutputCountLeftPad, - OutputCount, - OutputCountRightPad, - Bias, - KernelFlags - ); -} // // Implementation of MlasConvNchwcFloatKernelNeon From fe780769679efd069bff3cb248ea22dca12e8284 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Mon, 26 Jan 2026 09:41:47 -0800 Subject: [PATCH 03/23] Apply absl cuda warning patch to othe OS (#27126) Fix #27125 It does fix the build issue on Linux, but I am not entirely sure whether this is the optimal fix. --- cmake/external/abseil-cpp.cmake | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3f7ff2c26ff81..6c5464851937c 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -20,9 +20,13 @@ else() endif() endif() -if(Patch_FOUND AND WIN32) - set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch && - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch) +if(Patch_FOUND) + if (WIN32) + set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch) + else() + set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch) + endif() else() set(ABSL_PATCH_COMMAND "") endif() From c9a8a38b2ad46eaf78586661bc53cd43c8af4291 Mon Sep 17 00:00:00 2001 From: Jambay Kinley Date: Tue, 27 Jan 2026 02:03:57 +0600 Subject: [PATCH 04/23] Deprecate transformers model examples (#27156) ### Description Models with corresponding Olive recipes are deprecated. ### Motivation and Context Olive and Olive-recipes is the entry point for model optimization. We want onnxruntime to be only for runtime. So, deprecating examples that are already present in olive recipes. 
--- .../tools/transformers/models/gpt2/convert_to_onnx.py | 8 ++++++++ .../python/tools/transformers/models/llama/README.md | 2 ++ .../tools/transformers/models/llama/convert_to_onnx.py | 7 +++++++ .../python/tools/transformers/models/phi2/README.md | 2 ++ .../tools/transformers/models/phi2/convert_to_onnx.py | 7 +++++++ .../tools/transformers/models/stable_diffusion/README.md | 2 ++ .../models/stable_diffusion/optimize_pipeline.py | 7 +++++++ .../python/tools/transformers/models/whisper/README.md | 2 ++ .../tools/transformers/models/whisper/convert_to_onnx.py | 7 +++++++ 9 files changed, 44 insertions(+) diff --git a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py index a4015f50fdc13..841421a353b07 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py @@ -21,6 +21,7 @@ import os import shutil import sys +import warnings from pathlib import Path import numpy @@ -243,6 +244,13 @@ def get_latency_name(batch_size, sequence_length, past_sequence_length): def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) + result = {} if version.parse(transformers_version) < version.parse( "3.1.0" diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index cd8a8756d681e..eccfb46582fbc 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Contents - [LLaMA-2](#llama-2) - [Prerequisites](#prerequisites) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 6411dca00b5de..2cb6a733c5bc7 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -12,6 +12,7 @@ import subprocess import sys import tempfile +import warnings from itertools import chain import onnx @@ -801,6 +802,12 @@ def get_args(): def main(): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) if version.parse(torch.__version__) < version.parse("2.2.0"): logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.") return diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md index da62bba0f02fb..eab31680e64c7 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/README.md +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. 
Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Phi2 Optimizations ## Prerequisites A Linux machine for [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter)\ diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index dd0accc5dd9e8..ebdb5e32b7184 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -7,6 +7,7 @@ import argparse import logging import os +import warnings from pathlib import Path import onnx @@ -375,6 +376,12 @@ def parse_arguments(): def main(): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) args = parse_arguments() device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 12e6df53de577..4afede881fb93 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Stable Diffusion GPU Optimization ONNX Runtime uses the following optimizations to speed up Stable Diffusion in CUDA: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index eb4d7242f72fc..33397cf75e127 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -20,6 +20,7 @@ import os import shutil import tempfile +import warnings from pathlib import Path import coloredlogs @@ -569,6 +570,12 @@ def parse_arguments(argv: list[str] | None = None): def main(argv: list[str] | None = None): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) args = parse_arguments(argv) logger.info("Arguments: %s", str(args)) diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 9056ac07cc286..44a041d789b5d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. 
Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Whisper ## Prerequisites diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 79b508047da55..93b509eec6982 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -7,6 +7,7 @@ import argparse import logging import os +import warnings import onnx import torch @@ -493,6 +494,12 @@ def export_onnx_models( def main(argv=None): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) args = parse_arguments(argv) setup_logger(args.verbose) From e24d130943ac59bdc5cdf810326f63d4c7fcb9c2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Jan 2026 13:20:01 -0800 Subject: [PATCH 05/23] Bump lodash from 4.17.21 to 4.17.23 in /js/react_native/e2e (#27134) Bumps [lodash](https://github.com/lodash/lodash) from 4.17.21 to 4.17.23.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/package-lock.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index 931b86b3b08b1..8e3f605deef7d 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -9266,7 +9266,9 @@ } }, "node_modules/lodash": { - "version": "4.17.21", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true, "license": "MIT" }, From 66454136abeac6e6134361e8d79d1e1667837a7f Mon Sep 17 00:00:00 2001 From: xhcao Date: Tue, 27 Jan 2026 10:13:16 +0800 Subject: [PATCH 06/23] webgpu: optimize Gemm and MatMul using subgroup feature (#26433) ### Description ### Motivation and Context --- .../core/providers/webgpu/math/gemm.cc | 5 + .../core/providers/webgpu/math/gemm_packed.cc | 4 +- .../core/providers/webgpu/math/gemm_utils.cc | 65 ++++--- .../core/providers/webgpu/math/gemm_utils.h | 4 +- .../core/providers/webgpu/math/matmul.cc | 5 + .../providers/webgpu/math/matmul_packed.cc | 6 +- .../webgpu/vendor/intel/math/gemm.cc | 121 ++++++++++++ .../providers/webgpu/vendor/intel/math/gemm.h | 64 ++++++ .../webgpu/vendor/intel/math/gemm_subgroup.cc | 183 ++++++++++++++++++ .../webgpu/vendor/intel/math/gemm_subgroup.h | 31 +++ .../webgpu/vendor/intel/math/matmul.cc | 135 +++++++++++++ .../webgpu/vendor/intel/math/matmul.h | 48 +++++ 12 files changed, 631 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc create mode 100644 onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h create mode 100644 onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc create mode 100644 onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h create mode 100644 onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc create mode 100644 onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h diff --git a/onnxruntime/core/providers/webgpu/math/gemm.cc b/onnxruntime/core/providers/webgpu/math/gemm.cc index 4fb512001381a..b722430049877 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm.cc @@ -3,6 +3,7 @@ #include "core/providers/webgpu/math/gemm.h" #include "core/providers/webgpu/math/gemm_packed.h" +#include "core/providers/webgpu/vendor/intel/math/gemm.h" #include @@ -147,6 +148,10 @@ Status Gemm::ComputeInternal(ComputeContext& context) const { return context.RunProgram(program); } + if (intel::CanApplyGemmIntel(context, M, N, K, transA_, transB_)) { + return intel::ApplyGemmIntel(A, B, C, transA_, transB_, alpha_, beta_, context); + } + return ApplyGemmPacked(A, B, C, transA_, transB_, alpha_, beta_, context); } diff --git a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc index 1a0ad7a843ec4..023a671420d89 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_packed.cc @@ -31,7 +31,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = 
shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_, is_vec4_); + MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_); } if (is_vec4_) { ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread, WorkgroupSizeX(), WorkgroupSizeY(), data_type, nullptr, transA_, transB_, alpha_, need_handle_matmul_, output_components_, /*tile_inner*/ 32, need_split_k, split_dim_inner_)); @@ -45,7 +45,7 @@ Status GemmProgram::GenerateShaderCode(ShaderHelper& shader) const { } const ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; - MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, output_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type); + MatMulWriteFnSource(shader, output, c, /* is_gemm = */ true, c_components_, c_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ false, need_split_k, output_var_type); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index ba7e9290f8455..0228fb25d1d26 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -49,7 +49,7 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader, shader.AdditionalImplementation() << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset("colIn") : bias->GetByOffset("row")) << ");\n"; } shader.AdditionalImplementation() << " " << activation_snippet << "\n" - << output.SetByIndices("coords", "value") << "\n"; + << " " << output.SetByIndices("coords", "value") << "\n"; } void HandleMatMulWithSplitK( @@ -127,60 +127,61 @@ void MatMulReadFnSource(ShaderHelper& shader, const ShaderVariableHelper& b, const ShaderIndicesHelper* batch_dims, bool transA, - bool transB, - bool is_vec4) { - int components = is_vec4 ? 4 : 1; + bool transB) { + const int a_components = a.NumComponents(); const std::string data_type = "output_element_t"; - const std::string type_string = MakeScalarOrVectorType(components, data_type); + std::string type_string = MakeScalarOrVectorType(a_components, data_type); shader.AdditionalImplementation() << "fn mm_readA(batch: i32, row: i32, colIn: i32 " << (batch_dims ? ", batch_indices: batch_dims_indices_t" : "") - << ") -> " << type_string << " {\n " - << " var value = " << type_string << "(0);\n" - << " let col = colIn * " << components << ";\n"; + << ") -> " << type_string << " {\n" + << " var value = " << type_string << "(0);\n" + << " let col = colIn * " << a_components << ";\n"; if (transA) { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n"; + shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n"; } else { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n"; + shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n"; } - shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n"; + shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n"; if (batch_dims) { - shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? 
batch_dims->Rank() : 0, " batch_indices ") << "\n"; + shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices "); } - shader.AdditionalImplementation() << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n" - << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n" - << " value = " << a.GetByIndices("a_indices") << ";\n" - << " }\n" - << " return value;\n" + shader.AdditionalImplementation() << " " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n" + << " " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n" + << " value = " << a.GetByIndices("a_indices") << ";\n" + << " }\n" + << " return value;\n" << "}\n\n"; // Add the mm_readB function + const int b_components = b.NumComponents(); + type_string = MakeScalarOrVectorType(b_components, data_type); shader.AdditionalImplementation() << "fn mm_readB(batch: i32, row: i32, colIn: i32 " << (batch_dims ? ", batch_indices: batch_dims_indices_t" : "") - << ") -> " << type_string << " {\n " - << " var value = " << type_string << "(0);\n" - << " let col = colIn * " << components << ";\n"; + << ") -> " << type_string << " {\n" + << " var value = " << type_string << "(0);\n" + << " let col = colIn * " << b_components << ";\n"; if (transB) { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n"; + shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n"; } else { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n"; + shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n"; } - shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n" + shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n" << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? 
batch_dims->Rank() : 0, "batch_indices") - << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n" - << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n" - << " value = " << b.GetByIndices("b_indices") << ";\n" - << " }\n" - << " return value;\n" + << " " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n" + << " " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n" + << " value = " << b.GetByIndices("b_indices") << ";\n" + << " }\n" + << " return value;\n" << "}\n\n"; } @@ -189,19 +190,19 @@ void MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper* bias, bool is_gemm, int c_components, - int output_components, bool c_is_scalar, std::string activation_snippet, bool is_channels_last, bool use_split_k, ProgramVariableDataType output_variable_type) { + const int output_components = output.NumComponents(); shader.AdditionalImplementation() - << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n"; + << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) {\n"; shader.AdditionalImplementation() << " let col = colIn * " << output_components << ";\n"; - shader.AdditionalImplementation() << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n" - << " var value = valueIn; \n"; + shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n" + << " var value = valueIn;\n"; if (use_split_k) { // Set output when MatMul is performed with Split-K. diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index e001544f9e50d..49c3fbb8640a5 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -13,15 +13,13 @@ void MatMulReadFnSource(ShaderHelper& shader, const ShaderVariableHelper& b, const ShaderIndicesHelper* batch_dims, bool transA, - bool transB, - bool is_vec4); + bool transB); void MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, bool is_gemm, int c_components, - int output_components, bool c_is_scalar, std::string activation_snippet = "", bool is_channels_last = false, diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index ba32365bf9d88..b9afbc9bfecab 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -8,6 +8,7 @@ #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/nn/fuse_utils.h" #include "core/providers/webgpu/data_transfer.h" +#include "core/providers/webgpu/vendor/intel/math/matmul.h" #include "core/providers/webgpu/webgpu_utils.h" namespace onnxruntime { @@ -163,6 +164,10 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { inputs.push_back(bias); } + if (intel::CanApplyMatMulIntel(context, helper.M(), helper.N(), helper.K())) { + return intel::ApplyMatMulIntel(context, Activation(), inputs, output_tensor); + } + return ComputeMatMul(&context, Activation(), inputs, output_tensor, false); } diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index e97e0fd6f1058..fb137f4755ed9 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -33,8 +33,8 @@ Status 
MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); ProgramVariableDataType output_var_type = this->Outputs()[0].var_type; // declare the read and write functions - MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false, is_vec4_); - MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, is_vec4_ ? 4 : 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); + MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false); + MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, false, apply_activation, is_channels_last_, need_split_k, output_var_type); std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { @@ -65,7 +65,7 @@ Status MatMulFillBiasOrZeroBeforeSplitKProgram::GenerateShaderCode(ShaderHelper& // `use_split_k` is true only when we do the actual MatMul with Split-K. const uint32_t bias_components = output_components_; MatMulWriteFnSource( - shader, output, bias, is_gemm_, bias_components, output_components_, bias_is_scalar_, + shader, output, bias, is_gemm_, bias_components, bias_is_scalar_, /*activation_snippet*/ "", /*is_channels_last*/ true, /*use_split_k*/ false); shader.MainFunctionBody() << " let output_components = " << output_components_ << ";\n"; diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc new file mode 100644 index 0000000000000..699487b4c2270 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.cc @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/vendor/intel/math/gemm.h" +#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h" +#include "core/providers/webgpu/math/gemm_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +Status GemmSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | + ShaderUsage::UseValueTypeAlias | + ShaderUsage::UseElementTypeAlias); + + if (need_handle_matmul_) { + const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + + MatMulReadFnSource(shader, a, b, nullptr, transA_, transB_); + } + + ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, nullptr, is_vec4_, transA_, transB_, + alpha_, need_handle_matmul_)); + const ShaderVariableHelper* c = nullptr; + if (need_handle_bias_) { + c = &shader.AddInput("c", ShaderUsage::UseUniform); + } + MatMulWriteFnSource(shader, output, c, true, c_components_, c_is_scalar_); + + return Status::OK(); +} + +bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) { + return CanApplySubgroup(context, M, N, K, transA, transB); +} + +Status ApplyGemmIntel(const Tensor* a, + const Tensor* b, + const Tensor* c, + bool transA, + bool transB, + float alpha, + float beta, + ComputeContext& context) { + const auto& a_shape = a->Shape(); + const auto& b_shape = b->Shape(); + + uint32_t M = onnxruntime::narrow(transA ? a_shape[1] : a_shape[0]); + uint32_t K = onnxruntime::narrow(transA ? a_shape[0] : a_shape[1]); + uint32_t N = onnxruntime::narrow(transB ? b_shape[0] : b_shape[1]); + + std::vector output_dims{M, N}; + auto* y = context.Output(0, output_dims); + int64_t output_size = y->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + // WebGPU doesn't support binding a zero-sized buffer, so we need to check if A or B is empty. + bool need_handle_matmul = a_shape.Size() > 0 && b_shape.Size() > 0; + bool need_handle_bias = c && beta; + + const bool is_vec4 = b_shape[1] % 4 == 0; + // Components for A, B + int a_components = 1; + int b_components = is_vec4 ? 4 : 1; + // Components for Y + int output_components = (is_vec4 && N % 4 == 0) ? 4 : 1; + // Components for C. + int c_components = 1; + + bool c_is_scalar = false; + if (need_handle_bias) { + const auto& c_shape = c->Shape(); + int64_t c_last_dim = c_shape[c_shape.NumDimensions() - 1]; + // `C` in GEMM might be broadcast to the output, and broadcasting requires the components to be consistent. + // So we use vec4 for C when its last dimension is N, and the output is also a vec4. + c_components = (c_last_dim == N && output_components == 4) ? 
4 : 1; + c_is_scalar = c_shape.Size() == 1; + } + + InlinedVector elements_per_thread = InlinedVector({4, intel::ElementsPerThreadY(is_vec4, M), 1}); + const uint32_t dispatch_x = narrow((N + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) / + (kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0])); + const uint32_t dispatch_y = narrow((M + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) / + (kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1])); + + GemmSubgroupProgram program{transA, transB, alpha, need_handle_bias, need_handle_matmul, c_components, c_is_scalar, + is_vec4, elements_per_thread}; + + if (need_handle_matmul) { + program.AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_components}, + {b, ProgramTensorMetadataDependency::TypeAndRank, b_components}}); + } + + if (need_handle_bias) { + program.AddInput({c, ProgramTensorMetadataDependency::TypeAndRank, c_components}); + } + + program.CacheHint(alpha, transA, transB, c_is_scalar, absl::StrJoin(elements_per_thread, "-")) + .AddOutputs({{y, ProgramTensorMetadataDependency::TypeAndRank, output_components}}) + .SetDispatchGroupSize(dispatch_x, dispatch_y, 1) + .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1) + .AddUniformVariables({{alpha}, + {beta}, + {M}, /* dim_a_outer */ + {N}, /* dim_b_outer */ + {K}} /*dim_inner */ + ); + + return context.RunProgram(program); +} + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h new file mode 100644 index 0000000000000..1e6ac6a7e7514 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +class GemmSubgroupProgram final : public Program { + public: + GemmSubgroupProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul, + int c_components, bool c_is_scalar, bool is_vec4, + const gsl::span& elements_per_thread) + : Program{"GemmSubgroup"}, + transA_{transA}, + transB_{transB}, + alpha_{alpha}, + need_handle_bias_{need_handle_bias}, + need_handle_matmul_{need_handle_matmul}, + c_components_(c_components), + c_is_scalar_(c_is_scalar), + is_vec4_(is_vec4), + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"alpha", ProgramUniformVariableDataType::Float32}, + {"beta", ProgramUniformVariableDataType::Float32}, + {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}); + + private: + bool transA_; + bool transB_; + float alpha_; + bool need_handle_bias_; + bool need_handle_matmul_; + int c_components_; + bool c_is_scalar_ = false; + bool is_vec4_ = false; + const InlinedVector elements_per_thread_; +}; + +bool CanApplyGemmIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB); + +Status ApplyGemmIntel(const Tensor* a, + const Tensor* b, + const Tensor* c, + bool transA, + bool transB, + float alpha, + float beta, + ComputeContext& context); + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc new file mode 100644 index 0000000000000..a6baf8dfb0239 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.cc @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/string_macros.h" +#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +namespace { + +std::string LoadAStr(const ShaderIndicesHelper* batch_dims, int64_t elements_per_thread_y) { + SS(load_a_ss, 128); + for (int64_t i = 0; i < elements_per_thread_y; i++) { + load_a_ss << " a_val_" << i << " = " << std::string("mm_readA(batch, globalRowStart + ") + << i << std::string(", aCol") + (batch_dims ? ", batchIndices" : "") + ");\n"; + } + return SS_GET(load_a_ss); +} + +// Load one tile of B into local memory. +std::string LoadBStr(const ShaderIndicesHelper* batch_dims, int64_t tile_b_outer, bool is_vec4) { + SS(load_b_ss, 256); + load_b_ss << " let loadRowsPerThread = " << kSubgroupLogicalWorkGroupSizeX / kSubgroupLogicalWorkGroupSizeY << ";\n" + << " for (var innerRow = 0; innerRow < loadRowsPerThread; innerRow++) {\n" + << " let inputRow = loadRowsPerThread * localRow + innerRow;\n" + << " let inputCol = tileCol;\n"; + if (is_vec4) { + load_b_ss << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalColStart" + << (batch_dims ? 
", batchIndices" : "") << ");\n"; + } else { + for (int j = 0; j < tile_b_outer; j += kSubgroupLogicalWorkGroupSizeX) { + load_b_ss << " mm_Bsub[inputRow][inputCol + " << j << "] = mm_readB(batch, kStart + inputRow, globalColStart + " + << j << (batch_dims ? ", batchIndices" : "") << ");\n"; + } + } + load_b_ss << " }\n" + << " workgroupBarrier();\n"; + + return SS_GET(load_b_ss); +} + +std::string LoadBCacheStr(bool is_vec4, uint32_t offset) { + SS(b_cache_ss, 256); + if (is_vec4) { + b_cache_ss << "BCache = mm_Bsub[" << offset << "][tileCol];\n"; + } else { + b_cache_ss << "BCache = vec4(mm_Bsub[" << offset << "][tileCol], " + << "mm_Bsub[" << offset << "][tileCol + " << kSubgroupLogicalWorkGroupSizeX << "], " + << "mm_Bsub[" << offset << "][tileCol + " << 2 * kSubgroupLogicalWorkGroupSizeX << "], " + << "mm_Bsub[" << offset << "][tileCol + " << 3 * kSubgroupLogicalWorkGroupSizeX << "]);\n"; + } + return SS_GET(b_cache_ss); +} + +std::string CalculateAccStr(const ShaderIndicesHelper* batch_dims, int64_t elements_per_thread_y, bool is_vec4) { + SS(cal_acc_ss, 1024); + + // key: simd size; value: the offset row of mm_Bsub. + std::map> simd_map = { + {32, {0}}, + {16, {0, 16}}, + {8, {0, 8, 16, 24}}}; + for (const auto& [simd, offsets] : simd_map) { + cal_acc_ss << " if (sg_size == " << simd << ") {\n"; + for (uint32_t offset : offsets) { + cal_acc_ss << LoadAStr(batch_dims, elements_per_thread_y) + << " aCol += " << simd << ";\n"; + for (uint32_t sg_idx = 0; sg_idx < simd; sg_idx++) { + cal_acc_ss << " " << LoadBCacheStr(is_vec4, sg_idx + offset); + for (uint32_t i = 0; i < elements_per_thread_y; i++) { + cal_acc_ss << " acc_" << i << " += subgroupBroadcast(a_val_" << i << ", " << sg_idx << ") * BCache;\n"; + } + } + } + cal_acc_ss << " }\n"; + } + + return SS_GET(cal_acc_ss); +} + +} // namespace + +bool CanApplySubgroup(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA, bool transB) { + if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + bool use_subgroup = context.HasFeature(wgpu::FeatureName::Subgroups) && + M >= 64 && N >= 512 && K >= 32 && !transA && !transB; + return use_subgroup; + } + + return false; +} + +int64_t ElementsPerThreadY(bool is_vec4, uint32_t M) { + return is_vec4 ? (M <= 8 ? 1 : (M <= 16 ? 2 : (M <= 32 ? 
4 : 8))) : 4; +} + +Status MakeMatMulSubgroupSource(ShaderHelper& shader, + const InlinedVector& elements_per_thread, + const ShaderIndicesHelper* batch_dims, + bool is_vec4, + bool transpose_a, + bool transpose_b, + float alpha, + bool need_handle_matmul) { + ORT_UNUSED_PARAMETER(transpose_a); + ORT_UNUSED_PARAMETER(transpose_b); + + // elements per thread + const auto elements_per_thread_x = elements_per_thread[0]; + const auto elements_per_thread_y = elements_per_thread[1]; + + const auto tile_a_outer = kSubgroupLogicalWorkGroupSizeY * elements_per_thread_y; + const auto tile_b_outer = kSubgroupLogicalWorkGroupSizeX * elements_per_thread_x; + + shader.AdditionalImplementation() + << "var mm_Bsub: array, 32>;\n"; + + shader.MainFunctionBody() + << " let workgroupIdXStride = (uniforms.dim_b_outer - 1) / " << tile_b_outer << " + 1;\n" + << " let workgroupIdYStride = (uniforms.dim_a_outer - 1) / " << tile_a_outer << " + 1;\n" + << " let batch = i32(workgroup_idx / (workgroupIdXStride * workgroupIdYStride));\n" + << " let workgroupIdXY = workgroup_idx % (workgroupIdXStride * workgroupIdYStride);\n" + << " let workgroupIdX = workgroupIdXY % workgroupIdXStride;\n" + << " let workgroupIdY = workgroupIdXY / workgroupIdXStride;\n" + << " let tileRow = i32(local_id.x / " << kSubgroupLogicalWorkGroupSizeX << ") * " << elements_per_thread_y << ";\n" + << " let tileCol = i32(local_id.x % " << kSubgroupLogicalWorkGroupSizeX << ");\n" + << " let localRow = i32(local_id.x / " << kSubgroupLogicalWorkGroupSizeX << ");\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") + << " let globalRowStart = i32(workgroupIdY) * " << tile_a_outer << " + tileRow;\n" + << " let globalColStart = i32(workgroupIdX) * " << (is_vec4 ? 
tile_b_outer / elements_per_thread_x : tile_b_outer) << " + tileCol;\n" + << " let numTiles = (uniforms.dim_inner - 1) / 32 + 1;\n" + << " var kStart = 0;\n" + << " var aCol = 0;\n" + << " var BCache: vec4;\n"; + + for (uint32_t i = 0; i < elements_per_thread_y; i++) { + shader.MainFunctionBody() << " var acc_" << i << " = vec4(0);\n" + << " var a_val_" << i << " = a_value_t(0);\n"; + } + + if (need_handle_matmul) { + shader.MainFunctionBody() << " for (var t = 0; t < i32(numTiles); t++) {\n" + << LoadBStr(batch_dims, tile_b_outer, is_vec4) + << " aCol = kStart + tileCol % i32(sg_size);\n" + << CalculateAccStr(batch_dims, elements_per_thread_y, is_vec4) + << " kStart = kStart + 32;\n" + << " workgroupBarrier();\n" + << " }\n"; // main for loop + + // Calculate alpha * acc + if (alpha != 1.0f) { + for (uint32_t i = 0; i < elements_per_thread_y; i++) { + shader.MainFunctionBody() << " acc_" << i << " *= output_element_t(uniforms.alpha);\n"; + } + } + } + + // Write the results to the output buffer + if (is_vec4) { + for (uint32_t i = 0; i < elements_per_thread_y; i++) { + shader.MainFunctionBody() << " mm_write(batch, globalRowStart + " << i + << ", globalColStart, acc_" << i << ");\n"; + } + } else { + for (uint32_t i = 0; i < elements_per_thread_y; i++) { + for (uint32_t j = 0; j < elements_per_thread_x; j++) { + shader.MainFunctionBody() << " " + << "mm_write(batch, globalRowStart + " << i << ", globalColStart + " + << j * kSubgroupLogicalWorkGroupSizeX << ", acc_" << i << "[" << j << "]);\n"; + } + } + } + + return Status::OK(); +} + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h new file mode 100644 index 0000000000000..89dca023d3e1b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/gemm_subgroup.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +const uint32_t kSubgroupLogicalWorkGroupSizeX = 32; +const uint32_t kSubgroupLogicalWorkGroupSizeY = 8; +const uint32_t kSubgroupLogicalWorkGroupSizeZ = 1; + +bool CanApplySubgroup(const ComputeContext& context, int64_t M, int64_t N, int64_t K, bool transA = false, bool transB = false); + +int64_t ElementsPerThreadY(bool is_vec4, uint32_t M); + +Status MakeMatMulSubgroupSource(ShaderHelper& shader, + const InlinedVector& elements_per_thread, + const ShaderIndicesHelper* batch_dims, + bool is_vec4, + bool transpose_a = false, + bool transpose_b = false, + float alpha = 1.0f, + bool need_handle_matmul = true); + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc new file mode 100644 index 0000000000000..20874522daa20 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.cc @@ -0,0 +1,135 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/math/matmul_utils.h" +#include "core/providers/webgpu/vendor/intel/math/gemm_subgroup.h" +#include "core/providers/webgpu/math/gemm_utils.h" +#include "core/providers/webgpu/vendor/intel/math/matmul.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +Status MatMulSubgroupProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | + ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + + const ShaderVariableHelper* bias = nullptr; + if (has_bias_) { + bias = &shader.AddInput("bias", ShaderUsage::UseUniform); + } + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + // declare the read and write functions + MatMulReadFnSource(shader, a, b, &batch_dims, /*transA = */ false, /*transB = */ false); + MatMulWriteFnSource(shader, output, bias, /* is_gemm = */ false, 1, + false, apply_activation, /*is_channels_last = */ false); + // generate the main function + ORT_RETURN_IF_ERROR(MakeMatMulSubgroupSource(shader, elements_per_thread_, &batch_dims, is_vec4_)); + return Status::OK(); +} + +bool CanApplyMatMulIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K) { + return CanApplySubgroup(context, M, N, K); +} + +Status ApplyMatMulIntel(ComputeContext& context, + const Activation& activation, + std::vector& inputs, + Tensor* output) { + const auto* a = inputs[0]; + const auto* b = inputs[1]; + bool has_bias = inputs.size() > 2; + TensorShape a_shape = a->Shape(); + TensorShape b_shape = b->Shape(); + + MatMulComputeHelper helper; + ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape)); + int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2); + int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2); + + TensorShape output_shape = helper.OutputShape(); + + const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2]; + // check if A is batch of vector (bach is not 1, M is 1) and B is a matrix (batch is 1) + if (batchA != 1 && dim_output_outer == 1 && batchB == 1) { + // optimization for batched vector matrix multiplication + // dimensions of A: [1,`batchA`,K] + TensorShapeVector dims_a = {1, batchA, helper.K()}; + // dimensions of B: [1,K,N] + TensorShapeVector dims_b = {1, helper.K(), helper.N()}; + + a_shape = TensorShape(dims_a); + b_shape = TensorShape(dims_b); + output_shape = {1, batchA, helper.N()}; + } + + // helpful dimension variables + TensorShape outer_dims_a = a_shape.NumDimensions() > 2 + ? a_shape.Slice(0, a_shape.NumDimensions() - 2) + : TensorShape({}); + + TensorShape outer_dims_b = b_shape.NumDimensions() > 2 + ? b_shape.Slice(0, b_shape.NumDimensions() - 2) + : TensorShape({}); + + TensorShape outer_dims = output_shape.NumDimensions() > 2 + ? 
output_shape.Slice(0, output_shape.NumDimensions() - 2) + : TensorShape({}); + + const int64_t batch_size = outer_dims.Size(); + + // Get dimensions for matrix multiplication from TensorShape + const uint32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension + const uint32_t dim_inner = narrow(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension + const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension + + // Always access A with 1-component when using subgroup. + const bool is_vec4 = dim_b_outer % 4 == 0; + InlinedVector elements_per_thread = InlinedVector({4, intel::ElementsPerThreadY(is_vec4, dim_a_outer), 1}); + + const uint32_t dispatch_x = narrow((dim_b_outer + kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0] - 1) / + (kSubgroupLogicalWorkGroupSizeX * elements_per_thread[0])); + const uint32_t dispatch_y = narrow((dim_a_outer + kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1] - 1) / + (kSubgroupLogicalWorkGroupSizeY * elements_per_thread[1])); + const uint32_t dispatch_z = narrow((static_cast(batch_size) + + kSubgroupLogicalWorkGroupSizeZ * elements_per_thread[2] - 1) / + (kSubgroupLogicalWorkGroupSizeZ * elements_per_thread[2])); + + const int components = is_vec4 ? 4 : 1; + const int a_components = 1; + const int b_components = components; + const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, a_components); + const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, b_components); + const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); + + MatMulSubgroupProgram program{activation, has_bias, is_vec4, elements_per_thread}; + program + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-")) + .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, a_components}, + {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, b_components}}) + .AddOutputs({{output, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) + .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) + .AddIndices(outer_dims) + .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) + .SetWorkgroupSize(kSubgroupLogicalWorkGroupSizeX * kSubgroupLogicalWorkGroupSizeY, 1, 1); + + if (has_bias) { + auto bias_components = 1; + const auto* bias = inputs[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); + } + + return context.RunProgram(program); +} + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h new file mode 100644 index 0000000000000..2a8333e3e912b --- /dev/null +++ b/onnxruntime/core/providers/webgpu/vendor/intel/math/matmul.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace intel { + +class MatMulSubgroupProgram final : public Program { + public: + MatMulSubgroupProgram(const Activation& activation, + bool bias, + bool is_vec4, + const gsl::span& elements_per_thread) + : Program{"MatMulSubgroup"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation activation_; + const bool has_bias_; + const bool is_vec4_; + const InlinedVector elements_per_thread_; +}; + +bool CanApplyMatMulIntel(const ComputeContext& context, int64_t M, int64_t N, int64_t K); + +Status ApplyMatMulIntel(ComputeContext& context, + const Activation& activation, + std::vector& inputs, + Tensor* output); + +} // namespace intel +} // namespace webgpu +} // namespace onnxruntime From 21cfa8d392419afda55f9d63b5c717498328bf58 Mon Sep 17 00:00:00 2001 From: quic-calvnguy Date: Mon, 26 Jan 2026 18:29:39 -0800 Subject: [PATCH 07/23] [QNN-EP] Implement file mapped weights feature (#26952) ### Description Enables the file mapping of weights as well as the overall context bin. This feature is currently only enabled for ARM64 Windows devices. ### Motivation and Context Currently, when reading the context bin, ORT allocates a large buffer on the heap. Assuming the same model is used, each ORT session will allocate a buffer for the context bin. This is incredibly wasteful when large models are used. Instead, Windows file mapping can be leveraged to map the context bin; then, every time a context needs to be created from the context bin, the pointer to the mapped context bin can be retrieved and used instead of a pre-allocated buffer, thus making QNN EP more memory-efficient. In the case of multiple ORT sessions, the context bin will only be loaded once for all sessions, increasing memory efficiency and overall initialization performance. This is especially useful for LLMs going forward.
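For reference, the core mechanism is the standard Win32 file-mapping sequence sketched below. This is a minimal, illustrative sketch only: the helper name and the reduced error handling are assumptions for brevity and do not reflect the actual QnnWindowsFileMapper API added in this patch.

```cpp
#include <windows.h>

// Map a context binary read-only and return the view. Because every session
// maps the same file with PAGE_READONLY, the OS backs all views with the same
// physical pages instead of each session holding its own heap copy.
void* MapContextBinReadOnly(const wchar_t* path, size_t& size_out) {
  HANDLE file = CreateFileW(path, GENERIC_READ, FILE_SHARE_READ, nullptr,
                            OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, nullptr);
  if (file == INVALID_HANDLE_VALUE) return nullptr;

  LARGE_INTEGER size{};
  if (!GetFileSizeEx(file, &size)) {
    CloseHandle(file);
    return nullptr;
  }

  HANDLE mapping = CreateFileMappingW(file, nullptr, PAGE_READONLY, 0, 0, nullptr);
  CloseHandle(file);  // the mapping object holds its own reference to the file
  if (mapping == nullptr) return nullptr;

  void* view = MapViewOfFile(mapping, FILE_MAP_READ, 0, 0, 0);
  CloseHandle(mapping);  // the view keeps the mapping alive until unmapped
  if (view == nullptr) return nullptr;

  size_out = static_cast<size_t>(size.QuadPart);
  return view;  // caller releases with UnmapViewOfFile(view)
}
```

A pointer obtained this way is what the backend manager then hands to systemContextGetBinaryInfo and the DMA-buffer callbacks in place of the heap buffer previously read in with std::ifstream.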
--------- Co-authored-by: quic_calvnguy --- .../qnn/builder/onnx_ctx_model_helper.cc | 14 + .../qnn/builder/qnn_backend_manager.cc | 396 ++++++++++++++++-- .../qnn/builder/qnn_backend_manager.h | 59 +++ .../core/providers/qnn/builder/qnn_def.h | 6 + .../qnn/builder/qnn_file_mapping_interface.h | 25 ++ .../qnn/builder/qnn_windows_file_mapper.cc | 113 +++++ .../qnn/builder/qnn_windows_file_mapper.h | 42 ++ .../providers/qnn/qnn_execution_provider.cc | 37 +- .../providers/qnn/qnn_execution_provider.h | 4 +- .../core/providers/qnn/rpcmem_library.cc | 2 + .../core/providers/qnn/rpcmem_library.h | 15 + onnxruntime/test/perftest/ort_test_session.cc | 6 +- .../test/providers/qnn/qnn_ep_context_test.cc | 108 +++++ .../test/providers/qnn/qnn_test_utils.cc | 26 +- 14 files changed, 801 insertions(+), 52 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index d468894080b3d..0e49c0f897bea 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -94,6 +94,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), + "", main_context_node.Name(), qnn_models, max_spill_fill_size); @@ -127,6 +128,18 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible."); } + std::string context_binary_path_str = context_binary_path.string(); +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + if (qnn_backend_manager->FileMappingIsEnabled()) { + return qnn_backend_manager->LoadCachedQnnContextFromBuffer(nullptr, + 0, + context_binary_path_str, + main_context_node.Name(), + qnn_models, + max_spill_fill_size); + } +#endif + size_t buffer_size{0}; std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); @@ -144,6 +157,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, cache_file.close(); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), + context_binary_path_str, main_context_node.Name(), qnn_models, max_spill_fill_size); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index a39d0d71216b0..9fc1cd7f42939 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -29,6 +29,10 @@ #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "core/providers/qnn/builder/qnn_utils.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +#include "core/providers/qnn/builder/qnn_windows_file_mapper.h" +#endif + // Flag to determine if Backend should do node validation for each opNode added #define DO_GRAPH_NODE_VALIDATIONS 1 @@ -786,22 +790,148 @@ Status 
SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t return Status::OK(); } +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +// Callback required for allocating file mapping resources +static Qnn_ErrorHandle_t MapDmaDataCallback(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "MapDmaDataCallback: notify_param is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + auto callback_info = reinterpret_cast(notify_param); + + if (callback_info->backend_manager == nullptr) { + LOGS_DEFAULT(ERROR) << "MapDmaDataCallback: QnnBackendManager is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + return callback_info->backend_manager->MapDmaData(request, response, + callback_info->mapped_file_ptr, + callback_info->file_size); +} + +Qnn_ErrorHandle_t QnnBackendManager::MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* const mapped_base_ptr, + const size_t file_size) { + if (!file_mapped_weights_enabled_) { + LOGS(*logger_, WARNING) << "Attempting to map DMA data but file mapping has been disabled, " + << "possibly due to an error in a previous request."; + return QNN_CONTEXT_ERROR_ABORTED; + } + + if (mapped_base_ptr == nullptr) { + LOGS(*logger_, ERROR) << "Attempting to map DMA data for null memory mapped base pointer"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + LOGS(*logger_, INFO) << "Mapping DMA data for request: memory mapped base pointer(" + << mapped_base_ptr << "), offset(" << request.offset + << "), size(" << request.size << "), total file size(" + << file_size << ") isBackendMappingNeeded(" + << request.isBackendMappingNeeded << ")"; + + auto size = request.size; + if (size == 0 || !request.isBackendMappingNeeded) { + LOGS(*logger_, ERROR) << "Mapping request size must be > 0 with backend mapping required"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + // offset & size are type uint64_t + // Should never be an issue, but if this occurs then there is something inherently wrong with QNN + if ((UINT64_MAX - request.offset) < size) { + LOGS(*logger_, ERROR) << "Critical error in QNN: mapping request offset + size will overflow 64 bits"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + // file_size will be promoted to 64 bits on 32-bit systems + if ((request.offset + size) > file_size) { + LOGS(*logger_, ERROR) << "Requested offset and size includes memory outside of mapped file"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + void* unaligned_data_ptr = static_cast(mapped_base_ptr) + request.offset; + rpcmem_library_->Api().register_buf(unaligned_data_ptr, size, NULL, + rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY); + + auto fd = rpcmem_library_->Api().to_fd(unaligned_data_ptr); + if (fd == -1) { + LOGS(*logger_, ERROR) << "Failed to register DMA data mapping to RPCMEM"; + return QNN_COMMON_ERROR_SYSTEM; + } + + LOGS(*logger_, INFO) << "Created DMA data mapping with address: " << unaligned_data_ptr; + + response->dmaBuffer.fd = fd; + response->dmaBuffer.data = unaligned_data_ptr; + response->dataStartOffset = 0; + response->alignedSize = size; + + return QNN_SUCCESS; +} + +// Callback required for releasing file mapping resources +static Qnn_ErrorHandle_t ReleaseDmaDataCallback(Qnn_ContextBinaryDmaDataMem_t data_mem, void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "ReleaseDmaDataCallback: notify_param is null"; + 
return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + auto callback_info = reinterpret_cast(notify_param); + + if (callback_info->backend_manager == nullptr) { + LOGS_DEFAULT(ERROR) << "ReleaseDmaDataCallback: QnnBackendManager is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + return callback_info->backend_manager->ReleaseDmaData(data_mem, callback_info->mapped_file_ptr); +} + +// Use LOGS_DEFAULT here as this function will be called during destruction of QnnBackendManager +// At time of destruction, usage of logger_ will not be available and will result in a seg fault +Qnn_ErrorHandle_t QnnBackendManager::ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, + void* mapped_base_ptr) { + if (mapped_base_ptr == nullptr) { + LOGS_DEFAULT(ERROR) << "Attempting to release DMA data for null memory mapped pointer"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + LOGS_DEFAULT(INFO) << "Releasing DMA data mapping for memory mapped pointer(" + << mapped_base_ptr << "), address(" << data_mem.dmaBuffer.data + << "), size: (" << data_mem.memSize << ")"; + + if (data_mem.dmaBuffer.data == nullptr || data_mem.memSize == 0) { + LOGS_DEFAULT(ERROR) << "Mapping release request address must not be null and size must be > 0"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + // Deregister file mapped data from NPU regardless of file_mapped_weights_enabled_ + // as there may be file mapped data registered to the NPU prior to any mapping error + void* unaligned_data_ptr = data_mem.dmaBuffer.data; + rpcmem_library_->Api().register_buf(unaligned_data_ptr, data_mem.memSize, -1, + rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY); + + auto fd = rpcmem_library_->Api().to_fd(unaligned_data_ptr); + if (fd != -1) { + LOGS_DEFAULT(ERROR) << "Failed to deregister buffer from RPCMEM: " << unaligned_data_ptr; + return QNN_CONTEXT_ERROR_MEM_ALLOC; + } + return QNN_SUCCESS; +} +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + // callback required to add context handles to class list // when using contextCreateFromBinaryListAsync() -void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, - Qnn_GraphHandle_t graph, - const char* graphName, - QnnContext_createFromBinaryAsyncNotifyType_t notifyType, - void* notifyParam, - Qnn_ErrorHandle_t status) { +static void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, + Qnn_GraphHandle_t /* graph */, + const char* /* graph_name */, + QnnContext_createFromBinaryAsyncNotifyType_t /* notify_type */, + void* notify_param, + Qnn_ErrorHandle_t /* status */) { auto qnn_backend_manager = SharedContext::GetInstance().GetSharedQnnBackendManager(); if (context) { - qnn_backend_manager->ProcessContextFromBinListAsync(context, notifyParam); - } - - if (nullptr == graphName || graph || notifyType || status) { - // Avoid compilation unused var warning error + qnn_backend_manager->ProcessContextFromBinListAsync(context, notify_param); } } @@ -825,6 +955,41 @@ void QnnBackendManager::ProcessContextFromBinListAsync(Qnn_ContextHandle_t conte } } +Status QnnBackendManager::GetFileSizeIfValid(const std::string& filepath, + size_t& file_size) { + std::error_code ec; + ORT_RETURN_IF(!std::filesystem::exists(filepath, ec), "Context binary does not exist: ", filepath); + ORT_RETURN_IF(ec, "Failed to read file: ", filepath, + ", error: ", ec.message()); + + auto size = std::filesystem::file_size(filepath, ec); + ORT_RETURN_IF(ec, "Failed to retrieve size of file: ", filepath, + ", error: ", ec.message()); + + ORT_RETURN_IF(size == 0, "File is empty: ", filepath); + 
ORT_RETURN_IF(size > SIZE_MAX, "File (", filepath, ") file size (", size, + " bytes) exceeds maximum value of size_t for this platform (", SIZE_MAX, " bytes)."); + + file_size = static_cast(size); + return Status::OK(); +} + +Status QnnBackendManager::ReadContextBinIfValid(const std::string& context_bin_filepath, + std::vector& buffer) { + size_t buffer_size; + ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_size)); + + buffer.resize(buffer_size); + + std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary); + ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to read context binary from: ", context_bin_filepath); + + const auto& read_result = cache_file.read(buffer.data(), buffer_size); + ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + + return Status::OK(); +} + Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map) { #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) QnnContext_Config_t context_config_resource_sharing = QNN_CONTEXT_CONFIG_INIT; @@ -861,10 +1026,27 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord #endif nullptr}; +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + if (file_mapped_weights_enabled_ && file_mapper_) { + // Retry logic -- if context creation failed with file mapped weights, then retry with feature disabled + auto res = CreateContextFromListAsyncWithCallback(configs, context_bin_map); + if (!res.IsOK()) { + LOGS(*logger_, WARNING) << res.ErrorMessage() << ". Retrying with feature disabled."; + } else { + return Status::OK(); + } + } +#endif + return CreateContextFromListAsync(configs, context_bin_map); +} + +Status QnnBackendManager::CreateContextFromListAsync(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map) { std::vector context_params_list; std::vector context_paramsv1_list; std::vector context_params_ptr_list; - std::vector> buffer_list; + std::vector> buffer_list; context_params_list.reserve(context_bin_map.size()); context_params_ptr_list.reserve(context_bin_map.size() + 1); @@ -872,22 +1054,14 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord for (auto& it : context_bin_map) { auto context_bin_filepath = it.first; - std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary); - ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to retrieve context binary from: ", context_bin_filepath); - - cache_file.seekg(0, cache_file.end); - size_t buffer_size = static_cast(cache_file.tellg()); - ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); + std::vector buffer; + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, buffer)); - cache_file.seekg(0, cache_file.beg); - std::unique_ptr buffer = std::make_unique(buffer_size); - ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); - const auto& read_result = cache_file.read(buffer.get(), buffer_size); - ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + size_t buffer_size = buffer.size(); + buffer_list.push_back(std::move(buffer)); - cache_file.close(); QnnContext_ParamsV1_t context_params_v1 = {nullptr, - buffer.get(), + buffer_list.back().data(), buffer_size, nullptr, ContextCreateAsyncCallback, @@ -896,7 +1070,6 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord QnnContext_Params_t context_params = 
{QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_1, {context_params_v1}}; - buffer_list.push_back(std::move(buffer)); context_params_list.push_back(std::move(context_params)); context_paramsv1_list.push_back(std::move(context_params_v1)); context_params_ptr_list.push_back(&context_params_list.back()); @@ -908,15 +1081,76 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord configs, nullptr); - context_params_ptr_list.clear(); - context_paramsv1_list.clear(); - context_params_list.clear(); - buffer_list.clear(); - ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result); return Status::OK(); } +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +Status QnnBackendManager::CreateContextFromListAsyncWithCallback(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map) { + std::vector context_params_list; + std::vector context_paramsv2_list; + std::vector context_callbacks_list; + std::vector context_params_ptr_list; + + context_params_list.reserve(context_bin_map.size()); + context_paramsv2_list.reserve(context_bin_map.size()); + context_callbacks_list.reserve(context_bin_map.size()); + context_params_ptr_list.reserve(context_bin_map.size() + 1); + + for (auto& it : context_bin_map) { + auto context_bin_filepath = it.first; + + size_t buffer_size; + ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_size)); + + void* buffer; + ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &buffer)); + + auto notify_param_ptr = std::make_unique(buffer, buffer_size, this); + + Qnn_ContextBinaryCallback_t context_file_map_callbacks; + context_file_map_callbacks.type = QNN_CONTEXT_CALLBACK_DMA_BUFFER; + context_file_map_callbacks.dmaBufferCallback.version = QNN_CONTEXT_CALLBACK_DMA_BUFFER_VERSION_1; + context_file_map_callbacks.dmaBufferCallback.v1.dataProvide = MapDmaDataCallback; + context_file_map_callbacks.dmaBufferCallback.v1.dataRelease = ReleaseDmaDataCallback; + context_file_map_callbacks.dmaBufferCallback.v1.notifyParam = reinterpret_cast(notify_param_ptr.get()); + + file_mapping_notify_params_.push_back(std::move(notify_param_ptr)); + context_callbacks_list.push_back(std::move(context_file_map_callbacks)); + + // Callbacks require QnnContext_ParamsV2_t which is new to QNN API 2.32 + QnnContext_ParamsV2_t context_params_v2 = {nullptr, + buffer, + buffer_size, + nullptr, + ContextCreateAsyncCallback, + it.second.get(), + &context_callbacks_list.back()}; + + QnnContext_Params_t context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_2, + {}}; + + context_paramsv2_list.push_back(std::move(context_params_v2)); + + context_params.v2 = &context_paramsv2_list.back(); + context_params_list.push_back(std::move(context_params)); + context_params_ptr_list.push_back(&(context_params_list.back())); + } + context_params_ptr_list.push_back(nullptr); + auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_, + device_handle_, + context_params_ptr_list.data(), + configs, + nullptr); + + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context with file mapping enabled. 
Error: ",
+                QnnErrorHandleToString(result), ", Code:", result);
+  return Status::OK();
+}
+#endif  // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
+
 Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) {
   QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
   ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config));
@@ -1114,6 +1348,7 @@ Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer,
 }

 Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length,
+                                                         const std::string& context_bin_filepath,
                                                          std::string node_name,
                                                          QnnModelLookupTable& qnn_models,
                                                          int64_t max_spill_fill_size) {
@@ -1122,6 +1357,24 @@
                 nullptr == qnn_sys_interface_.systemContextFree;
   ORT_RETURN_IF(result, "Failed to get valid function pointer.");

+  void* bin_buffer = nullptr;
+#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
+  if (file_mapped_weights_enabled_) {
+    ORT_RETURN_IF(!file_mapper_, "Attempting to use File Mapping feature but file_mapper_ is uninitialized");
+
+    ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_length));
+
+    ORT_RETURN_IF(buffer_length == 0, "Context bin has a size of 0 bytes: ", context_bin_filepath);
+    ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &bin_buffer));
+
+  } else {
+    ORT_RETURN_IF(buffer == nullptr, "Attempting to load QNN context from buffer but buffer is null");
+    bin_buffer = static_cast<void*>(buffer);
+  }
+#else
+  bin_buffer = static_cast<void*>(buffer);
+#endif
+
   QnnSystemContext_Handle_t sys_ctx_handle = nullptr;
   auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle);
   ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle.");
@@ -1129,7 +1382,7 @@
   const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
   Qnn_ContextBinarySize_t binary_info_size{0};
   rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
-                                                     static_cast<void*>(buffer),
+                                                     bin_buffer,
                                                      buffer_length,
                                                      &binary_info,
                                                      &binary_info_size);
@@ -1204,6 +1457,26 @@
   ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
                 "Invalid function pointer for contextCreateFromBinary.");

+#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
+  Qnn_ContextBinaryCallback_t callbacks;
+  if (file_mapped_weights_enabled_ && file_mapper_) {
+    ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinaryWithCallback,
+                  "Invalid function pointer for contextCreateFromBinaryWithCallback.");
+
+    auto notify_param_ptr = std::make_unique<FileMappingCallbackInfo_t>(bin_buffer, buffer_length, this);
+
+    callbacks.type = QNN_CONTEXT_CALLBACK_DMA_BUFFER;
+    callbacks.dmaBufferCallback.version = QNN_CONTEXT_CALLBACK_DMA_BUFFER_VERSION_1;
+    callbacks.dmaBufferCallback.v1.dataProvide = MapDmaDataCallback;
+    callbacks.dmaBufferCallback.v1.dataRelease = ReleaseDmaDataCallback;
+    callbacks.dmaBufferCallback.v1.notifyParam = reinterpret_cast<void*>(notify_param_ptr.get());
+
+    file_mapping_notify_params_.push_back(std::move(notify_param_ptr));
+  }
+#else
+  ORT_UNUSED_PARAMETER(context_bin_filepath);
+#endif
+
   qnn::profile::ProfilingInfo profiling_info;
 #ifdef QNN_SYSTEM_PROFILE_API_ENABLED
   if (ProfilingEnabled()) {
@@ -1211,13 +1484,41 @@
   }
 #endif

-  rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
-
device_handle_, - context_configs, - static_cast(buffer), - buffer_length, - &context, - profile_backend_handle_); +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + std::vector backup_buffer; + if (file_mapped_weights_enabled_ && file_mapper_) { + rt = qnn_interface_.contextCreateFromBinaryWithCallback(backend_handle_, + device_handle_, + context_configs, + &callbacks, + bin_buffer, + buffer_length, + &context, + profile_backend_handle_, + NULL); + + if (rt != QNN_SUCCESS) { + LOGS(*logger_, WARNING) << "Failed to create context with file mapping enabled. Error: " + << QnnErrorHandleToString(rt) << ", Code : " << rt + << ". Retrying with feature disabled."; + + // Read context bin from file since file mapping has failed + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, backup_buffer)); + + bin_buffer = static_cast(backup_buffer.data()); + } + } +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + + if (!file_mapped_weights_enabled_ || rt != QNN_SUCCESS) { + rt = qnn_interface_.contextCreateFromBinary(backend_handle_, + device_handle_, + context_configs, + bin_buffer, + buffer_length, + &context, + profile_backend_handle_); + } #ifdef QNN_SYSTEM_PROFILE_API_ENABLED if (ProfilingEnabled()) { @@ -1265,6 +1566,8 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool need_load_system_lib, bool share_ep_contexts, bool enable_vtcm_backup_buffer_sharing, + bool enable_file_mapped_weights, + std::shared_ptr rpcmem_library, std::unordered_map>>& context_bin_map) { std::lock_guard lock(logger_recursive_mutex_); if (backend_setup_completed_) { @@ -1304,6 +1607,20 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, } else { status = LoadQnnSerializerBackend(); } + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + // Backend is determined after LoadBackend() or LoadQnnSerializerBackend() + if (enable_file_mapped_weights && !file_mapper_ && GetQnnBackendType() == QnnBackendType::HTP) { + ORT_RETURN_IF(!rpcmem_library, "RPCMem Library is required for file mapping but is uninitialized."); + rpcmem_library_ = rpcmem_library; + file_mapped_weights_enabled_ = true; + file_mapper_ = std::make_unique(logger); + } +#else + ORT_UNUSED_PARAMETER(enable_file_mapped_weights); + ORT_UNUSED_PARAMETER(rpcmem_library); +#endif + if (status.IsOK()) { LOGS(logger, VERBOSE) << "LoadBackend succeed."; } @@ -1545,7 +1862,6 @@ void QnnBackendManager::ReleaseResources() { } backend_setup_completed_ = false; - return; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index aff55321d18fd..9b573531f7c3d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -25,6 +25,7 @@ #include "System/QnnSystemInterface.h" #include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/rpcmem_library.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" @@ -32,6 +33,10 @@ #include "core/providers/qnn/builder/qnn_profile_serializer.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +#include "core/providers/qnn/builder/qnn_file_mapping_interface.h" +#endif + namespace onnxruntime { namespace qnn { @@ -154,6 +159,7 @@ class QnnBackendManager : public std::enable_shared_from_this std::unique_ptr GetContextBinaryBuffer(uint64_t& 
written_buffer_size); Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + const std::string& context_bin_filepath, std::string node_name, std::unordered_map>& qnn_models, int64_t max_spill_fill_size); @@ -163,6 +169,8 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib, bool share_ep_contexts, bool enable_vtcm_backup_buffer_sharing, + bool enable_file_mapped_weights, + std::shared_ptr rpcmem_library, std::unordered_map>>& context_bin_map); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -248,9 +256,34 @@ class QnnBackendManager : public std::enable_shared_from_this bool ProfilingEnabled() { return profiling_enabled_; } #endif + bool FileMappingIsEnabled() { + return file_mapped_weights_enabled_; + } + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_ErrorHandle_t MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* const mapped_base_ptr, + const size_t file_size); + + Qnn_ErrorHandle_t ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, void* mapped_base_ptr); +#endif + QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level); static logging::Severity MapQNNLogLevelToOrtSeverity(QnnLog_Level_t qnn_log_level); +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + typedef struct FileMappingCallbackInfo { + void* const mapped_file_ptr; + const size_t file_size; + QnnBackendManager* const backend_manager; + + FileMappingCallbackInfo(void* ptr, size_t size, QnnBackendManager* manager) + : mapped_file_ptr(ptr), file_size(size), backend_manager(manager) {} + + } FileMappingCallbackInfo_t; +#endif + private: Status LoadBackend(); @@ -268,9 +301,24 @@ class QnnBackendManager : public std::enable_shared_from_this Status CreateContext(bool enable_htp_weight_sharing); + Status GetFileSizeIfValid(const std::string& filepath, size_t& file_size); + + Status ReadContextBinIfValid(const std::string& context_bin_filepath, + std::vector& buffer); + Status CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map); + Status CreateContextFromListAsync(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map); + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Status CreateContextFromListAsyncWithCallback(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map); +#endif + Status ReleaseContext(); // Sets the ORT logger and creates a corresponding QNN logger with the same log level. 
@@ -455,6 +503,15 @@ class QnnBackendManager : public std::enable_shared_from_this bool context_created_ = false; bool backend_setup_completed_ = false; bool vtcm_backup_buffer_sharing_enabled_ = false; + bool file_mapped_weights_enabled_ = false; + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + std::unique_ptr file_mapper_ = nullptr; + // Notify params for file mapping must persist throughout lifetime of + // QnnBackendManager for release of DMA data callback on destruction + std::vector> file_mapping_notify_params_; +#endif + // NPU backend requires quantized model QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; @@ -473,6 +530,8 @@ class QnnBackendManager : public std::enable_shared_from_this // Mapping of thread id to on-run-start/end power configs std::mutex per_thread_power_configs_mutex_; std::unordered_map per_thread_power_configs_; + + std::shared_ptr rpcmem_library_ = nullptr; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 625166f62d166..847de084c49f6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -19,6 +19,12 @@ namespace qnn { #define QNN_SYSTEM_PROFILE_API_ENABLED #endif +#if defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64)) +#if QNN_API_VERSION_MAJOR > 2 || ((QNN_API_VERSION_MAJOR) == 2 && (QNN_API_VERSION_MINOR >= 32)) +#define QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +#endif +#endif + // QNN only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose // except the following flags for dlopen, others should be done only // when we really need them diff --git a/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h b/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h new file mode 100644 index 0000000000000..f99cc7b1ee5dd --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_def.h" + +namespace onnxruntime { +namespace qnn { + +class FileMappingInterface { + public: + virtual ~FileMappingInterface() = default; + + virtual Status GetContextBinMappedMemoryPtr(const std::string& bin_filepath, + void** mapped_data_ptr) = 0; +}; + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc new file mode 100644 index 0000000000000..71f562d59d847 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "core/providers/qnn/builder/qnn_windows_file_mapper.h"
+#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE
+
+#include <Windows.h>
+
+#include <system_error>
+
+#include <wil/resource.h>
+
+#include "core/providers/qnn/ort_api.h"
+
+namespace onnxruntime {
+namespace qnn {
+
+WindowsFileMapper::WindowsFileMapper(const logging::Logger& logger)
+    : logger_(&logger) {
+}
+
+WindowsFileMapper::~WindowsFileMapper() {
+}
+
+static void UnmapFile(void* addr) noexcept {
+  bool successful = UnmapViewOfFile(addr);
+  if (!successful) {
+    const auto error_code = GetLastError();
+    LOGS_DEFAULT(ERROR) << "Failed to unmap view of file with ptr: " << addr
+                        << ", Error code: " << error_code << ", \""
+                        << std::system_category().message(error_code) << "\"";
+  }
+}
+
+Status WindowsFileMapper::GetContextBinMappedMemoryPtr(const std::string& bin_filepath,
+                                                       void** mapped_data_ptr) {
+  LOGS(*logger_, INFO) << "Creating context bin file mapping for "
+                       << bin_filepath;
+
+  ORT_RETURN_IF(bin_filepath.empty(), "Context bin file path is empty");
+
+  std::lock_guard lock(map_mutex_);
+  auto map_it = mapped_memory_ptrs_.find(bin_filepath);
+  if (map_it != mapped_memory_ptrs_.end()) {
+    *mapped_data_ptr = map_it->second.get();
+    LOGS(*logger_, INFO) << "Found existing mapview memory pointer (" << mapped_data_ptr
+                         << ") for context bin file: " << bin_filepath;
+    return Status::OK();
+  }
+
+  // Note: byte-wise widening of the path; this assumes an ASCII file path.
+  std::wstring bin_filepath_wstr(bin_filepath.begin(), bin_filepath.end());
+  wil::unique_hfile file_handle{CreateFile2(bin_filepath_wstr.c_str(),
+                                            GENERIC_READ,
+                                            FILE_SHARE_READ,
+                                            OPEN_EXISTING,
+                                            NULL)};
+  if (file_handle.get() == INVALID_HANDLE_VALUE) {
+    const auto error_code = GetLastError();
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                           "Failed to create file handle for context bin", bin_filepath,
+                           ". Error code: ", error_code, ", \"",
+                           std::system_category().message(error_code), "\"");
+  }
+
+  LOGS(*logger_, VERBOSE) << "Created file handle (" << file_handle.get() << ") for context bin: "
+                          << bin_filepath;
+
+  wil::unique_hfile file_mapping_handle{CreateFileMappingW(file_handle.get(),
+                                                           nullptr,
+                                                           PAGE_READONLY,
+                                                           0x00,
+                                                           0x00,
+                                                           nullptr)};
+  // CreateFileMappingW returns NULL on failure (not INVALID_HANDLE_VALUE).
+  if (file_mapping_handle.get() == nullptr) {
+    const auto error_code = GetLastError();
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                           "Failed to create file mapping handle for context bin",
+                           bin_filepath, ". Error code: ", error_code, ", \"",
+                           std::system_category().message(error_code), "\"");
+  }
+
+  LOGS(*logger_, VERBOSE) << "Created file mapping with handle (" << file_mapping_handle.get()
+                          << ") for context bin:" << bin_filepath;
+
+  void* const mapped_base_ptr = MapViewOfFile(file_mapping_handle.get(),
+                                              FILE_MAP_READ,
+                                              0, 0, 0);
+
+  if (mapped_base_ptr == nullptr) {
+    const auto error_code = GetLastError();
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                           "Failed to retrieve mapview pointer for context bin",
+                           bin_filepath, ". 
Error code: ", error_code, ", \"", + std::system_category().message(error_code), "\""); + } + + LOGS(*logger_, INFO) << "Created mapview pointer with address " << mapped_base_ptr + << " for context bin " << bin_filepath; + + onnxruntime::Env::MappedMemoryPtr mapped_memory_ptr{reinterpret_cast(mapped_base_ptr), + [mapped_base_ptr](void*) { + UnmapFile(mapped_base_ptr); + }}; + + *mapped_data_ptr = mapped_memory_ptr.get(); + mapped_memory_ptrs_.emplace(bin_filepath, std::move(mapped_memory_ptr)); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime + +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE diff --git a/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h new file mode 100644 index 0000000000000..742255b26f07d --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/qnn/builder/qnn_file_mapping_interface.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + +#include +#include +#include +#include + +#include + +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class WindowsFileMapper : public FileMappingInterface { + public: + explicit WindowsFileMapper(const logging::Logger& logger); + ~WindowsFileMapper() override; + + // Creates a file mapping of the context binary and returns the + // mapview pointer of the file mapping + Status GetContextBinMappedMemoryPtr(const std::string& bin_filepath, + void** mapped_data_ptr) override; + + private: + // A container of smart pointers of mapview memory pointers to mapped context bins + // key: filepath to context bin, value: smart pointer of mapview memory pointers + std::mutex map_mutex_; + std::unordered_map mapped_memory_ptrs_; + const logging::Logger* logger_; +}; + +} // namespace qnn +} // namespace onnxruntime + +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index a912bf24ac32a..a6f1d1c1681cf 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -475,6 +475,21 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio #endif } + static const std::string DISABLE_FILE_MAPPED_WEIGHTS = "disable_file_mapped_weights"; + auto disable_file_mapped_weights_pos = provider_options_map.find(DISABLE_FILE_MAPPED_WEIGHTS); + if (disable_file_mapped_weights_pos != provider_options_map.end()) { + if ("1" == disable_file_mapped_weights_pos->second) { + enable_file_mapped_weights_ = false; + } + LOGS_DEFAULT(VERBOSE) << "User specified disable_file_mapped_weights: " << enable_file_mapped_weights_; + } + +#ifndef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + enable_file_mapped_weights_ = false; + LOGS_DEFAULT(WARNING) << "File mapped weights feature is only available on Windows arm64 devices for QNN API versions >= 2.32. 
" + << "Feature will be disabled by default"; +#endif + static const std::string QNN_DEVICE_ID = "device_id"; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { @@ -552,11 +567,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED = "enable_htp_shared_memory_allocator"; - if (ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map)) { + enable_htp_shared_mem_allocator_ = ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map); + if (enable_htp_shared_mem_allocator_) { // Initialize rpcmem_library_. // This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available. rpcmem_library_ = std::make_shared(); - model_settings_.htp_shared_memory = true; + model_settings_.htp_shared_memory = enable_htp_shared_mem_allocator_; + } + + if (enable_file_mapped_weights_ && !rpcmem_library_) { + // Attempt to init rpcmem_library_ if needed. If this fails, then + // disable file mapped weights and proceed with normal operation + try { + rpcmem_library_ = std::make_shared(); + } catch (const std::exception& e) { + LOGS_DEFAULT(WARNING) << "Unable to load RPCMem library: " << e.what() + << " - Disabling file mapped weights."; + enable_file_mapped_weights_ = false; + } } dump_json_qnn_graph_ = ParseBoolOption("dump_json_qnn_graph", false, provider_options_map); @@ -947,7 +975,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } std::unordered_map>> context_bin_map; - if (enable_vtcm_backup_buffer_sharing_) { + if (enable_vtcm_backup_buffer_sharing_ || enable_file_mapped_weights_) { std::unordered_set ep_ctx_nodes; GetMainEPCtxNodes(graph_viewer, ep_ctx_nodes, logger); @@ -960,7 +988,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer NodeAttrHelper node_helper(*ep_ctx_node); std::string context_bin_filepath(parent_path.string()); context_bin_filepath.append("/").append(node_helper.Get(qnn::EP_CACHE_CONTEXT, "")); - if (context_bin_map.find(context_bin_filepath) == context_bin_map.end()) { context_bin_map.emplace(context_bin_filepath, std::make_unique>()); // Push context bin filepath for lookup between sessions @@ -977,6 +1004,8 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer context_cache_enabled_ && enable_spill_fill_buffer_, share_ep_contexts_, enable_vtcm_backup_buffer_sharing_, + enable_file_mapped_weights_, + rpcmem_library_, context_bin_map); context_bin_map.clear(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index dd301d7915935..f7022229f6c7b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -80,7 +80,7 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); - bool IsHtpSharedMemoryAllocatorAvailable() const { return rpcmem_library_ != nullptr; } + bool IsHtpSharedMemoryAllocatorAvailable() const { return enable_htp_shared_mem_allocator_ && rpcmem_library_ != nullptr; } private: // Will return true if any power config options need to be updated @@ -119,6 +119,8 @@ class QNNExecutionProvider : public IExecutionProvider { bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; bool 
enable_spill_fill_buffer_ = false; + bool enable_file_mapped_weights_ = true; + bool enable_htp_shared_mem_allocator_ = false; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc index 20918f8bc6de1..f89a15157ddf4 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.cc +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -165,6 +165,8 @@ RpcMemApi CreateApi(void* library_handle) { ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_to_fd", (void**)&api.to_fd)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "remote_register_buf_attr2", (void**)&api.register_buf)); + return api; } diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h index 2746e147373bb..0f4b5b5391f59 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.h +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -24,6 +24,9 @@ constexpr uint32_t RPCMEM_DEFAULT_FLAGS = 1; constexpr int RPCMEM_HEAP_ID_SYSTEM = 25; +constexpr int RPCMEM_ATTR_IMPORT_BUFFER = 256; +constexpr int RPCMEM_ATTR_READ_ONLY = 512; + /** * Allocate a zero-copy buffer for size upto 2 GB with the FastRPC framework. * Buffers larger than 2 GB must be allocated with rpcmem_alloc2 @@ -46,6 +49,17 @@ using FreeFnPtr = void (*)(void* po); */ using ToFdFnPtr = int (*)(void* po); +/** + * Registers and maps a CPU buffer to RPC memory space + * @param[in] buff Data pointer for a CPU-allocated buffer + * @param[in] size Size of the buffer in bytes + * @param[in] fd File descriptor for a CPU-allocated buffer + * Note: Can be NULL if N/A or -1 to signal deregistration + * @param[in] attr Specified attributes for the buffer + * @return Data pointer for an RPCMEM-allocated buffer + */ +using RegisterBufFnPtr = void (*)(void* buff, size_t size, int fd, int attr); + } // namespace rpcmem // RPCMEM API function pointers. @@ -53,6 +67,7 @@ struct RpcMemApi { rpcmem::AllocFnPtr alloc; rpcmem::FreeFnPtr free; rpcmem::ToFdFnPtr to_fd; + rpcmem::RegisterBufFnPtr register_buf; }; // Loads and provides access to the RPCMEM API functions from a dynamically loaded library. 
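
As an illustration of the new `register_buf` entry point, here is a minimal usage sketch. It is not code from this patch: the two wrapper functions are hypothetical, the registration-time `fd` value follows the doc comment above (NULL/0 when not applicable), and the deregistration call mirrors the `QnnBackendManager::ReleaseDmaData` path shown earlier.

```cpp
#include <cstddef>

#include "core/providers/qnn/rpcmem_library.h"

namespace onnxruntime::qnn {

// Hypothetical helper: import a file-mapped, read-only weights buffer into
// the RPC memory space so the NPU can consume it directly.
inline void ImportMappedWeights(const RpcMemApi& api, void* mapped_ptr, size_t size) {
  // Per the doc comment, fd may be NULL/0 when not applicable at registration time.
  api.register_buf(mapped_ptr, size, /*fd=*/0,
                   rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY);
}

// Hypothetical helper: passing fd = -1 signals deregistration of the buffer,
// exactly as QnnBackendManager::ReleaseDmaData does above.
inline void ReleaseMappedWeights(const RpcMemApi& api, void* mapped_ptr, size_t size) {
  api.register_buf(mapped_ptr, size, /*fd=*/-1,
                   rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY);
}

}  // namespace onnxruntime::qnn
```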
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index aeabddec33e89..71f9050730c0b 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -258,7 +258,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
         "qnn_saver_path", "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch",
         "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer",
         "enable_htp_shared_memory_allocator", "dump_json_qnn_graph",
-        "json_qnn_graph_dir", "htp_bf16_enable"});
+        "json_qnn_graph_dir", "disable_file_mapped_weights", "htp_bf16_enable", "enable_vtcm_backup_buffer_sharing"});
     for (const auto& provider_option : provider_options) {
       const std::string& key = provider_option.first;
       const std::string& value = provider_option.second;
@@ -322,7 +322,9 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
                  key == "offload_graph_io_quantization" ||
                  key == "enable_htp_spill_fill_buffer" ||
                  key == "enable_htp_shared_memory_allocator" ||
-                 key == "dump_json_qnn_graph") {
+                 key == "dump_json_qnn_graph" ||
+                 key == "disable_file_mapped_weights" ||
+                 key == "enable_vtcm_backup_buffer_sharing") {
        std::set<std::string> supported_options = {"0", "1"};
        if (supported_options.find(value) == supported_options.end()) {
          std::ostringstream str_stream;
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 4cccccca97804..813abf74828a2 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -1901,6 +1901,114 @@ TEST_F(QnnHTPBackendTests, VTCMBackupBufferSharing) {
   std::remove(qnn_ctx_binary_file_name1.c_str());
 }

+TEST_F(QnnHTPBackendTests, FileMapping_Off) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+  provider_options["disable_file_mapped_weights"] = "1";
+
+  // Create QDQ models
+  std::vector<std::string> onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"};
+  // Clean up in case a previously failed test didn't remove them.
+  for (auto model_path : onnx_model_paths) {
+    std::remove(model_path.c_str());
+  }
+
+  std::vector<std::string> ctx_model_paths;
+  for (auto model_path : onnx_model_paths) {
+    CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger());
+    EXPECT_TRUE(std::filesystem::exists(model_path.c_str()));
+    auto pos = model_path.find_last_of(".");
+    if (pos != std::string::npos) {
+      model_path = model_path.substr(0, pos) + "_ctx.onnx";
+    } else {
+      model_path = model_path + "_ctx.onnx";
+    }
+    ctx_model_paths.push_back(model_path);
+  }
+  for (auto ctx_model_path : ctx_model_paths) {
+    std::remove(ctx_model_path.c_str());
+  }
+
+  DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]);
+
+  std::string qnn_ctx_binary_file_name1;
+  GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1,
+                           DefaultLoggingManager().DefaultLogger());
+  EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty());
+
+  std::string qnn_ctx_binary_file_name2;
+  GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2,
+                           DefaultLoggingManager().DefaultLogger());
+  EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty());
+  // Both *_ctx.onnx models point to the same .bin file.
+  EXPECT_TRUE(qnn_ctx_binary_file_name1 == qnn_ctx_binary_file_name2);
+  auto file_size_1 =
std::filesystem::file_size(qnn_ctx_binary_file_name1); + EXPECT_TRUE(file_size_1 > 0); + + // only load and run the session on real device +#if defined(__aarch64__) || defined(_M_ARM64) + Ort::SessionOptions so1; + so1.SetLogId("so1"); + so1.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so1.AppendExecutionProvider("QNN", provider_options); + Ort::SessionOptions so2; + + // Test CreateFromBinaryListAsync path + provider_options["enable_vtcm_backup_buffer_sharing"] = "1"; + so2.SetLogId("so2"); + so2.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so2.AppendExecutionProvider("QNN", provider_options); + + EXPECT_TRUE(2 == ctx_model_paths.size()); +#ifdef _WIN32 + std::wstring ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::wstring ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#else + std::string ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::string ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#endif + Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so1); + Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so2); + + std::vector input_names; + std::vector output_names; + GetModelInputNames(ctx_model_paths[1], input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{2, 3}; + std::vector input_value(2 * 3, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_outputs1 = session1.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + auto ort_outputs2 = session2.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); +#endif + + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + std::remove(qnn_ctx_binary_file_name1.c_str()); +} + // For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled // Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 9ad34788444db..a6d43a3d3a9d9 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -408,12 +408,28 @@ static BackendSupport GetHTPSupport(const onnxruntime::logging::Logger& logger) // Create QNN EP and call GetCapability(). 
MockKernelLookup kernel_lookup; onnxruntime::GraphViewer graph_viewer(graph); - std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( - {{"backend_type", "htp"}, {"offload_graph_io_quantization", "0"}}); - GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability - qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + std::vector> result; + std::unique_ptr qnn_ep; + try { + qnn_ep = QnnExecutionProviderWithOptions( + {{"backend_type", "htp"}, {"offload_graph_io_quantization", "0"}, {"enable_htp_shared_memory_allocator", "1"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability + + qnn_ep->SetLogger(&logger); + result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + } catch (const std::exception& e) { + // handle exception that indicates that the libcdsprpc.so / dll can't be loaded + std::string_view error_message = e.what(); + std::string_view expected_error_message = "Failed to initialize RPCMEM dynamic library handle"; + + if (error_message.find(expected_error_message) != std::string_view::npos) { + return BackendSupport::UNSUPPORTED; + } + + // propagate other exceptions + throw; + } return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } From 2fde8b93a3c624c88d356108c531518bd79a9ad9 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 27 Jan 2026 20:53:43 -0800 Subject: [PATCH 08/23] Add LpNormalization-22 and update the implementation to respect ONNX spec (#27164) I missed the operator since it didn't have the corresponding tests at the time. With https://github.com/onnx/onnx/pull/7618, the disabled test should be able to pass. --- This pull request updates the ONNX Runtime CPU execution provider to add support for the `LpNormalization` operator for opset version 22, in addition to clarifying and correcting the registration for earlier versions. It also updates the backend test filters to reflect this new support. **ONNX Operator Kernel Registration:** * Added new kernel registrations for `LpNormalization` with opset version 22 for both `float` and `double` data types in `cpu_execution_provider.cc`. [[1]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908R1328-R1329) [[2]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908R3389-R3392) * Updated the registration for `LpNormalization` for opset versions 1 through 21 to use the correct versioned kernel macro, ensuring correct kernel selection and compatibility. [[1]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908L197-R198) [[2]](diffhunk://#diff-054ffdd679ada14ebb4b1db27a60b2881e2db48f9dc3f0b948c784cdcdaf4908L1731-R1735) **Test Filters Update:** * Updated `onnx_backend_test_series_filters.jsonc` to remove the exclusion of `test_l1normalization*`, `test_lpnormalization*`, and `test_l2normalization*` now that `LpNormalization` opset 22 is implemented, and added a TODO comment referencing ONNX 1.21 for a known zero-norm issue. 
[[1]](diffhunk://#diff-abc0f78c2314f9e7648c8081125d0ce9f33b12399520d92d811d73e3c795ed59R32-R33) [[2]](diffhunk://#diff-abc0f78c2314f9e7648c8081125d0ce9f33b12399520d92d811d73e3c795ed59L42) [[3]](diffhunk://#diff-abc0f78c2314f9e7648c8081125d0ce9f33b12399520d92d811d73e3c795ed59L70-L71)
---
 docs/OperatorKernels.md                       |  3 ++-
 .../providers/cpu/cpu_execution_provider.cc   | 18 ++++++++++++------
 onnxruntime/core/providers/cpu/nn/lp_norm.cc  | 12 ++++++++++--
 .../onnx_backend_test_series_filters.jsonc    |  5 ++---
 4 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 7cc57a636362f..08840c623b709 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -240,7 +240,8 @@ Do not modify directly.*
 |||[13, 15]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[11, 12]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
 |||[1, 10]|**B** = tensor(bool)<br/> **I** = tensor(int64)<br/> **V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|LpNormalization|*in* input:**T**<br/> *out* output:**T**|1+|**T** = tensor(double), tensor(float)|
+|LpNormalization|*in* input:**T**<br/> *out* output:**T**|22+|**T** = tensor(double), tensor(float)|
+|||[1, 21]|**T** = tensor(double), tensor(float)|
 |LpPool|*in* X:**T**<br/> *out* Y:**T**|22+|**T** = tensor(float)|
 |||[18, 21]|**T** = tensor(float)|
 |||[11, 17]|**T** = tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 0ad8d1d4fef4d..f6484ab60f5da 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -194,8 +194,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 6, 21, InstanceNormalization);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LpNormalization);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LpNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, float, LpNormalization);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, double, LpNormalization);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 12, LRN);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 9, AveragePool);
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 7, MaxPool);
@@ -1325,6 +1325,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, Softsign);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, ThresholdedRelu);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, AveragePool);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, LpNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, LpNormalization);

 #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, MLFloat16, Conv);
@@ -1728,10 +1730,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, float, LpNormalization)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, LpNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, float, LpNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 21, double, LpNormalization)>,
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
@@ -3384,6 +3386,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, float, LpNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 22, double, LpNormalization)>,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
      BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/nn/lp_norm.cc b/onnxruntime/core/providers/cpu/nn/lp_norm.cc
index 2286800c9638b..03f85e8ea5705 100644
--- a/onnxruntime/core/providers/cpu/nn/lp_norm.cc
+++ b/onnxruntime/core/providers/cpu/nn/lp_norm.cc
@@ -7,14 +7,22 @@
 #include "core/providers/common.h"

 namespace onnxruntime {
+#define REGISTER_LPNORMALISATION_VERSIONED_KERNEL(type, sinceVersion, endVersion)  \
+  ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL(                                        \
+      LpNormalization, sinceVersion, endVersion, type,                             \
+      KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), \
+      LpNorm<type>);
+
 #define REGISTER_LPNORMALISATION_KERNEL(type, sinceVersion)                        \
   ONNX_CPU_OPERATOR_TYPED_KERNEL(                                                  \
       LpNormalization, sinceVersion, type,                                         \
       KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<type>()), \
       LpNorm<type>);

-REGISTER_LPNORMALISATION_KERNEL(float, 1)
-REGISTER_LPNORMALISATION_KERNEL(double, 1)
+REGISTER_LPNORMALISATION_VERSIONED_KERNEL(float, 1, 21)
+REGISTER_LPNORMALISATION_VERSIONED_KERNEL(double, 1, 21)
+REGISTER_LPNORMALISATION_KERNEL(float, 22)
+REGISTER_LPNORMALISATION_KERNEL(double, 22)

 using InnerStride = Eigen::InnerStride;
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index 54e24cd1e0a83..cd57bc82aabf4 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -29,6 +29,8 @@

   // Tests that are failing temporarily and should be fixed
   "current_failing_tests": [
+    // TODO(titaiwang): onnx 1.21 should fix lpnorm zero norm issue
+    "^test_l2normalization*", // LpNormalization(22) not implemented
     "^test_adagrad",
     "^test_adagrad_multiple",
     "^test_attention_4d_fp16*", // precision issue: 1 / 192 mismatched elements
@@ -39,7 +41,6 @@
     "^test_attention_4d_with_past_and_present_qk_matmul_bias_4d_mask_causal*", // location of infinities
     "^test_attention_4d_attn_mask_3d_causal_expanded*", // webgpu
     "^test_attention_4d_diff_heads_mask4d_padded_kv*", // Need nonpad_kv_seqlen
-    "^test_l2normalization*", // LpNormalization(22) not implemented
     // TODO: support the following tests in Attention-cuda
     "^test_attention_3d_gqa.*_cuda", // GQA not supported in Attention-cuda
     "^test_attention_4d_gqa.*_cuda", // GQA not supported in Attention-cuda
@@ -67,8 +68,6 @@
     "^test_attention_4d_attn_mask_4d_causal_cuda",
     "^test_attention_4d_causal_cuda",
     "^test_attention_4d_diff_heads_sizes_causal_cuda",
-    "^test_l1normalization*", // LpNormalization(22) not implemented
-    "^test_lpnormalization*", // LpNormalization(22) not implemented
     "^test_tensorscatter*", // TensorScatter(24) not implemented
     "^test_castlike_no_saturate_FLOAT_to_FLOAT8*", // ORT does not support ml_dtypes
     "^test_castlike_UINT4_to*", // ORT does not support ml_dtypes

From 621e7899875070e830633240ad063ec2f9d94982 Mon Sep 17 00:00:00 2001
From: Jiawei Shao
Date: Thu, 29 Jan 2026 03:29:56 +0800
Subject: [PATCH 09/23] [WebGPU EP] Reduce duplicated code in `MatMulReadFnSource()` (#27151)

### Description
Previously in `MatMulReadFnSource()` we used duplicated code to read data from two inputs `a` and `b`. This patch implements another overload of `MatMulReadFnSource()` that only reads data from one input, to reduce duplicated code and get ready for further use.
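To make the refactor concrete, here is a minimal standalone sketch of the pattern; it is illustrative only and does not use the real `ShaderHelper`/`ShaderVariableHelper` API. The uniform names are taken from the patch, everything else is assumed: one emitter parameterized by function name, bounds, and a transpose flag generates a WGSL read helper, so the two-input entry point collapses into two calls.

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <string_view>

// Emit one WGSL-style read helper guarded by the given row/column bounds.
std::string EmitReadFn(std::string_view fn, std::string_view rows,
                       std::string_view cols, bool transpose) {
  std::ostringstream os;
  os << "fn " << fn << "(row: i32, colIn: i32) -> f32 {\n";
  // A transposed input swaps which uniform bounds the rows vs. the columns.
  if (transpose) {
    os << "  if (row < i32(" << cols << ") && colIn < i32(" << rows << ")) {\n";
  } else {
    os << "  if (row < i32(" << rows << ") && colIn < i32(" << cols << ")) {\n";
  }
  os << "    // ...in-bounds load elided...\n"
        "  }\n"
        "  return 0.0;\n"
        "}\n";
  return os.str();
}

// The former mm_readA/mm_readB duplication reduces to two parameterized calls.
std::string EmitBothReadFns(bool trans_a, bool trans_b) {
  return EmitReadFn("mm_readA", "uniforms.dim_a_outer", "uniforms.dim_inner", trans_a) +
         EmitReadFn("mm_readB", "uniforms.dim_inner", "uniforms.dim_b_outer", trans_b);
}

int main() {
  std::cout << EmitBothReadFns(/*trans_a=*/false, /*trans_b=*/true);
}
```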
--- .../core/providers/webgpu/math/gemm_utils.cc | 80 ++++++++----------- 1 file changed, 34 insertions(+), 46 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 0228fb25d1d26..7aa5b6cea4ee7 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -123,68 +123,56 @@ void InitializeLogicalWorkgroupIDAndGlobalID(ShaderHelper& shader) { } // namespace void MatMulReadFnSource(ShaderHelper& shader, - const ShaderVariableHelper& a, - const ShaderVariableHelper& b, + std::string_view function_name, + const ShaderVariableHelper& input, + const std::string& input_name, const ShaderIndicesHelper* batch_dims, - bool transA, - bool transB) { - const int a_components = a.NumComponents(); + std::string_view rows, + std::string_view components_per_row, + bool transpose) { + const int components = input.NumComponents(); const std::string data_type = "output_element_t"; - std::string type_string = MakeScalarOrVectorType(a_components, data_type); + const std::string type_string = MakeScalarOrVectorType(components, data_type); shader.AdditionalImplementation() - << "fn mm_readA(batch: i32, row: i32, colIn: i32 " + << "fn " << function_name << "(batch: i32, row: i32, colIn: i32 " << (batch_dims ? ", batch_indices: batch_dims_indices_t" : "") - << ") -> " << type_string << " {\n" - << " var value = " << type_string << "(0);\n" - << " let col = colIn * " << a_components << ";\n"; - if (transA) { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n"; + << ") -> " << type_string << " {\n " + << " var value = " << type_string << "(0);\n" + << " let col = colIn * " << components << ";\n"; + if (transpose) { + shader.AdditionalImplementation() << " if(row < i32(" << components_per_row << ") && col < i32(" << rows << ")) {\n"; } else { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n"; + shader.AdditionalImplementation() << " if(row < i32(" << rows << ") && col < i32(" << components_per_row << ")) {\n"; } - shader.AdditionalImplementation() << " var a_indices: a_indices_t;\n"; - if (batch_dims) { - shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices "); - } - shader.AdditionalImplementation() << " " << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n" - << " " << a.IndicesSet("a_indices", a.Rank() - 1, "u32(colIn)") << "\n" - << " value = " << a.GetByIndices("a_indices") << ";\n" - << " }\n" - << " return value;\n" - << "}\n\n"; - - // Add the mm_readB function - const int b_components = b.NumComponents(); - type_string = MakeScalarOrVectorType(b_components, data_type); - shader.AdditionalImplementation() - << "fn mm_readB(batch: i32, row: i32, colIn: i32 " - << (batch_dims - ? 
", batch_indices: batch_dims_indices_t" - : "") - << ") -> " << type_string << " {\n" - << " var value = " << type_string << "(0);\n" - << " let col = colIn * " << b_components << ";\n"; + const std::string input_indices = input_name + "_indices"; + shader.AdditionalImplementation() << " var " << input_indices << ": " << input_name << "_indices_t" << ";\n"; - if (transB) { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n"; - } else { - shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n"; + if (batch_dims) { + shader.AdditionalImplementation() << ConvertOutputBatchIndicesToInputBatchIndices(input_name, input, input.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, " batch_indices ") << "\n"; } - shader.AdditionalImplementation() << " var b_indices: b_indices_t;\n" - << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims ? batch_dims->Rank() : 0, "batch_indices") - << " " << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n" - << " " << b.IndicesSet("b_indices", b.Rank() - 1, "u32(colIn)") << "\n" - << " value = " << b.GetByIndices("b_indices") << ";\n" - << " }\n" - << " return value;\n" + shader.AdditionalImplementation() << input.IndicesSet(input_indices, input.Rank() - 2, "u32(row)") << "\n" + << input.IndicesSet(input_indices, input.Rank() - 1, "u32(colIn)") << "\n" + << " value = " << input.GetByIndices(input_indices) << ";\n" + << " }\n" + << " return value;\n" << "}\n\n"; } +void MatMulReadFnSource(ShaderHelper& shader, + const ShaderVariableHelper& a, + const ShaderVariableHelper& b, + const ShaderIndicesHelper* batch_dims, + bool transA, + bool transB) { + MatMulReadFnSource(shader, "mm_readA", a, "a", batch_dims, "uniforms.dim_a_outer", "uniforms.dim_inner", transA); + MatMulReadFnSource(shader, "mm_readB", b, "b", batch_dims, "uniforms.dim_inner", "uniforms.dim_b_outer", transB); +} + void MatMulWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& output, const ShaderVariableHelper* bias, From 6c37921fd28460ff802e5e66dc8316e187603ae9 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 28 Jan 2026 12:08:23 -0800 Subject: [PATCH 10/23] [MLAS] Fix Data Race in MlasLutGemm by Serializing LUT Generation (#27179) ## Problem Description The `MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` test was exhibiting flaky behavior (failure rate ~2-20%) with numerical mismatches. Investigation revealed a **race condition** in the [GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326) step within [MlasLutGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/inc/mlas_qnbit.h#L328). When the batch size `M > 1`, [MlasLutGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/inc/mlas_qnbit.h#L328) attempted to parallelize the LUT generation over the batch dimension using `MlasTrySimpleParallel`. However, the underlying [GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326) implementation (specifically shared usage of `lut_scales`/`lut_biases` or internal buffers) is not thread-safe for concurrent execution on the same destination buffers or related state. 
This led to corruption of the Look-Up Tables or scales, causing random output errors. ## Solution This PR modifies [onnxruntime/core/mlas/lib/qlutgemm.cpp](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/qlutgemm.cpp) to **serialize the [GenerateLUT](file:///home/tlwu/onnxruntime/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#324-355) loop**. Instead of using `MlasTrySimpleParallel`, we now use a simple `for` loop to process each row of the batch sequentially. **Performance Impact:** The [GenerateLUT](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L326) step is computationally lightweight compared to the subsequent [TMACComputeGemm](https://github.com/microsoft/onnxruntime/blob/38dfc91f38fe53da9eaf7e9fb9b158904eb3cd5b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp#L505) matrix multiplication. Serializing this setup step has negligible impact on overall inference latency (micro-benchmarks showed no measurable regression), but effectively eliminates the race condition. ## Verification * **Reproduction:** The issue was reliably reproduced by running `MatMulNBitsLutGemm.Float32_2Bits_Asymmetric_Batch32_256x256` in a loop (failing ~1 in 5 times). * **Verification:** After applying the fix, the same test passed **50/50 iterations** consistently. * **Regression Testing:** Standard `MatMulNBitsLutGemm` tests (including `BlkLen64` and `M=1` cases) continue to pass. --- onnxruntime/core/mlas/lib/qlutgemm.cpp | 48 +++++++++++++------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp index f029e539f02a1..cb099c2409a44 100644 --- a/onnxruntime/core/mlas/lib/qlutgemm.cpp +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -548,32 +548,32 @@ MlasLutGemm( // const int num_groups = static_cast(K / BlkLen); - // Parallelize over M (batch dimension) - // Each iteration processes one row of the activation matrix + // Iterate over M (batch dimension) + // Each iteration processes one row of the activation matrix. + // NOTE: This loop is intentionally serialized. Previous attempts to parallelize + // using MlasTrySimpleParallel caused flaky test failures (race conditions) + // when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight, + // serial execution ensures correctness with negligible performance impact. 
// TODO(vraspar): Ideally we have to do block parallelism here - MlasTrySimpleParallel( - threadpool, - static_cast(M), - [&](ptrdiff_t ine11) { - const size_t row_offset = static_cast(ine11) * K; - const size_t lut_offset = static_cast(ine11) * K * 4; // 4 bytes per K element for 2-bit LUT - const size_t scale_bias_offset = static_cast(ine11) * lut_scales_size; - - // Call the dispatch function for this row - // ggml_tmac_mul_mat_task_init - Dispatch->GenerateLUT( - const_cast(a_float + row_offset), // Input activation for this row - qlut + lut_offset, // Output LUT for this row - lut_scales + scale_bias_offset, // Scales for this row - lut_biases + scale_bias_offset, // Biases for this row - M, - K, - N, - tmac_params.act_group_size - ); - } - ); + for (size_t ine11 = 0; ine11 < static_cast(M); ine11++) { + const size_t row_offset = ine11 * K; + const size_t lut_offset = ine11 * K * 4; // 4 bytes per K element for 2-bit LUT + const size_t scale_bias_offset = ine11 * lut_scales_size; + + // Call the dispatch function for this row + // ggml_tmac_mul_mat_task_init + Dispatch->GenerateLUT( + const_cast(a_float + row_offset), // Input activation for this row + qlut + lut_offset, // Output LUT for this row + lut_scales + scale_bias_offset, // Scales for this row + lut_biases + scale_bias_offset, // Biases for this row + M, + K, + N, + tmac_params.act_group_size + ); + } // all relevant LUT's have been generated // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line From 74443ff158f0ab2ca5c7fb1a12b020ba1d78aee4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 28 Jan 2026 12:43:39 -0800 Subject: [PATCH 11/23] remove coloredlogs (#27135) See related issues: https://github.com/microsoft/onnxruntime/issues/26889 --- dockerfiles/Dockerfile.source | 2 +- docs/python/requirements.txt | 1 - onnxruntime/python/tools/tensorrt/perf/benchmark.py | 10 +++++----- .../python/tools/tensorrt/perf/benchmark_wrapper.py | 1 - onnxruntime/python/tools/tensorrt/perf/perf_utils.py | 2 -- .../python/tools/tensorrt/perf/requirements.txt | 1 - .../python/tools/transformers/benchmark_helper.py | 9 ++++----- .../tools/transformers/convert_to_packing_mode.py | 9 ++++----- .../models/longformer/benchmark_longformer.py | 2 +- .../transformers/models/longformer/convert_to_onnx.py | 2 +- .../transformers/models/stable_diffusion/benchmark.py | 5 ++--- .../models/stable_diffusion/demo_txt2img.py | 5 +++-- .../models/stable_diffusion/demo_txt2img_xl.py | 5 +++-- .../models/stable_diffusion/optimize_pipeline.py | 3 +-- .../stable_diffusion/requirements/requirements.txt | 1 - .../Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb | 4 ++-- .../notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb | 10 +++++----- ...Tensorflow_Tf2onnx_Bert-Squad_OnnxRuntime_CPU.ipynb | 4 ++-- onnxruntime/python/tools/transformers/optimizer.py | 8 +++----- onnxruntime/python/tools/transformers/requirements.txt | 1 - onnxruntime/python/tools/transformers/run_benchmark.sh | 2 +- .../test/python/transformers/test_gpt2_benchmark.py | 3 +-- .../test/python/transformers/test_gpt2_to_onnx.py | 3 +-- requirements.txt | 1 - .../linux-gpu-tensorrt-daily-perf-pipeline.yml | 6 +++--- .../templates/py-package-smoking-test.yml | 2 +- tools/ci_build/github/windows/python/requirements.txt | 1 - .../requirements/transformers-test/requirements.txt | 1 - 28 files changed, 44 insertions(+), 60 deletions(-) diff --git a/dockerfiles/Dockerfile.source b/dockerfiles/Dockerfile.source index ea28e144ee95a..51291e59aa0d5 100644 --- 
a/dockerfiles/Dockerfile.source +++ b/dockerfiles/Dockerfile.source @@ -16,4 +16,4 @@ RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sy FROM mcr.microsoft.com/azurelinux/base/python:3 COPY --from=0 /code/build/Linux/Release/dist /root COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt -RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl +RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl diff --git a/docs/python/requirements.txt b/docs/python/requirements.txt index 0be11c8760892..0b0e5d464b26e 100644 --- a/docs/python/requirements.txt +++ b/docs/python/requirements.txt @@ -11,7 +11,6 @@ furo pyquickhelper pandas pydot -coloredlogs flatbuffers numpy<2.0.0 packaging diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index 66ab0c44f8814..2017cf154f21e 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -12,7 +12,6 @@ import timeit from datetime import datetime -import coloredlogs import numpy as np from perf_utils import ( acl, @@ -2259,12 +2258,13 @@ def parse_arguments(): def setup_logger(verbose): if verbose: - coloredlogs.install( - level="DEBUG", - fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + logging.basicConfig( + level=logging.DEBUG, + format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + force=True, ) else: - coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s", force=True) logging.getLogger("transformers").setLevel(logging.WARNING) diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py index 204fe61396663..7bfe25b1549cf 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py @@ -11,7 +11,6 @@ import pprint import re -import coloredlogs # noqa: F401 from benchmark import * # noqa: F403 from perf_utils import * # noqa: F403 diff --git a/onnxruntime/python/tools/tensorrt/perf/perf_utils.py b/onnxruntime/python/tools/tensorrt/perf/perf_utils.py index 8d2f4b07b7984..4b83e1a8fc41f 100644 --- a/onnxruntime/python/tools/tensorrt/perf/perf_utils.py +++ b/onnxruntime/python/tools/tensorrt/perf/perf_utils.py @@ -5,8 +5,6 @@ import subprocess import sys -import coloredlogs # noqa: F401 - debug = False debug_verbose = False diff --git a/onnxruntime/python/tools/tensorrt/perf/requirements.txt b/onnxruntime/python/tools/tensorrt/perf/requirements.txt index 0afbf47e88307..2a4b319cfc57e 100644 --- a/onnxruntime/python/tools/tensorrt/perf/requirements.txt +++ b/onnxruntime/python/tools/tensorrt/perf/requirements.txt @@ -1,4 +1,3 @@ onnxconverter-common onnxmltools pandas -coloredlogs \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 
8055e5e4ae876..56b670e8f2306 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -18,7 +18,6 @@ from time import sleep from typing import Any -import coloredlogs import numpy import torch import transformers @@ -147,12 +146,12 @@ def create_onnxruntime_session( def setup_logger(verbose=True): if verbose: - coloredlogs.install( - level="DEBUG", - fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + logging.basicConfig( + format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + level=logging.DEBUG, ) else: - coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s", level=logging.INFO) logging.getLogger("transformers").setLevel(logging.WARNING) diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index 9a6388b3f350d..d8177fcd3cb02 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -7,7 +7,6 @@ import logging import os -import coloredlogs from constants import ( AttentionInputIDs, AttentionOutputIDs, @@ -358,12 +357,12 @@ def _parse_arguments(): def _setup_logger(verbose): if verbose: - coloredlogs.install( - level="DEBUG", - fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + logging.basicConfig( + format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + level=logging.DEBUG, ) else: - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) def main(): diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index 21848deaf99fe..674dc831d70f9 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -11,7 +11,7 @@ # conda create -n gpu_env python=3.8 # conda activate gpu_env # pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 -# pip3 install onnx transformers onnxruntime-gpu numpy sympy coloredlogs psutil py3nvml +# pip3 install onnx transformers onnxruntime-gpu numpy sympy psutil py3nvml # python benchmark_longformer.py # # When there is no parameter, pre-defined tests will run on the longformer-base-4096 model. 
diff --git a/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py index b80feec892994..513a115352556 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py @@ -18,7 +18,7 @@ # conda create -n longformer python=3.8 # conda activate longformer # python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html -# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0 +# python3 -m pip install flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0 # python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu # cd ./torch_extensions # rm -rf build diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index ed2e346972a6c..e90af970032e5 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -5,14 +5,13 @@ import argparse import csv +import logging import os import statistics import sys import time from pathlib import Path -import coloredlogs - # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. import torch from benchmark_helper import measure_memory @@ -1332,7 +1331,7 @@ def main(): if version.parse(ort_version) < version.parse("1.16"): raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later") - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO, force=True) memory_monitor_type = "cuda" diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index a3caba138f44a..d851e785e8d84 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -20,7 +20,8 @@ # limitations under the License. # -------------------------------------------------------------------------- -import coloredlogs +import logging + from cuda import cudart from demo_utils import ( add_controlnet_arguments, @@ -86,7 +87,7 @@ def run_inference(warmup=False): if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) parser = arg_parser("Options for Stable Diffusion Demo") add_controlnet_arguments(parser) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index c3e91a405b53f..739f3cb5025e7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -20,7 +20,8 @@ # limitations under the License. 
# -------------------------------------------------------------------------- -import coloredlogs +import logging + from cuda import cudart from demo_utils import ( add_controlnet_arguments, @@ -252,7 +253,7 @@ def main(args): if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) parser = arg_parser("Options for Stable Diffusion XL Demo") add_controlnet_arguments(parser) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 33397cf75e127..25c034f7b70b5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -23,7 +23,6 @@ import warnings from pathlib import Path -import coloredlogs import onnx from fusion_options import FusionOptions from onnx_model_clip import ClipOnnxModel @@ -587,5 +586,5 @@ def main(argv: list[str] | None = None): if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) main() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt index 73929214b22ea..e7852f7478db8 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt @@ -4,7 +4,6 @@ transformers==4.50.0 numpy>=1.24.1 accelerate onnx==1.18.0 -coloredlogs packaging # Use newer version of protobuf might cause crash protobuf==4.25.8 diff --git a/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb b/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb index 5e81e754e1109..6603c9c387517 100644 --- a/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb +++ b/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb @@ -52,7 +52,7 @@ "else:\n", " !{sys.executable} -m pip install install torch --index-url https://download.pytorch.org/whl/cpu -q\n", "\n", - "!{sys.executable} -m pip install onnxruntime transformers==4.18 onnx psutil pandas py-cpuinfo py3nvml netron coloredlogs --no-warn-script-location -q" + "!{sys.executable} -m pip install onnxruntime transformers==4.18 onnx psutil pandas py-cpuinfo py3nvml netron --no-warn-script-location -q" ] }, { @@ -719,4 +719,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb b/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb index 7295ae1436c99..76458ca3220c9 100644 --- a/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb +++ b/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb @@ -59,7 +59,7 @@ "\n", "if sys.platform in ['linux', 'win32']: # Linux or Windows\n", " !{sys.executable} -m pip install torch --index-url https://download.pytorch.org/whl/cu118 -q\n", - " !{sys.executable} -m pip install onnxruntime-gpu onnx transformers 
psutil pandas py-cpuinfo py3nvml coloredlogs wget netron sympy protobuf==3.20.3 -q\n", + " !{sys.executable} -m pip install onnxruntime-gpu onnx transformers psutil pandas py-cpuinfo py3nvml wget netron sympy protobuf==3.20.3 -q\n", "else: # Mac\n", " print(\"CUDA is not available on MacOS\")" ] @@ -196,9 +196,9 @@ "Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", "- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [00:02<00:00, 16.27it/s]\n", - "convert squad examples to features: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 256.11it/s]\n", - "add example index and unique id: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00= 1.8 numpy >= 1.19.0 -coloredlogs psutil py-cpuinfo py3nvml diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index 25997f40d348f..c16d60d0d5046 100755 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -95,7 +95,7 @@ if [ "$run_install" = true ] ; then else pip install onnxruntime-gpu fi - pip install --upgrade onnx coloredlogs packaging psutil py3nvml numpy transformers sympy + pip install --upgrade onnx packaging psutil py3nvml numpy transformers sympy fi if [ "$use_package" = true ] ; then diff --git a/onnxruntime/test/python/transformers/test_gpt2_benchmark.py b/onnxruntime/test/python/transformers/test_gpt2_benchmark.py index 2d9bc035fe4fd..40be872250f1a 100644 --- a/onnxruntime/test/python/transformers/test_gpt2_benchmark.py +++ b/onnxruntime/test/python/transformers/test_gpt2_benchmark.py @@ -9,7 +9,6 @@ import os import unittest -import coloredlogs import pytest from parity_utilities import find_transformers_source @@ -50,6 +49,6 @@ def test_gpt2_int8(self): if __name__ == "__main__": - coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s") logging.getLogger("transformers").setLevel(logging.ERROR) unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py b/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py index e179d3d087120..bda99abbb7287 100644 --- a/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py +++ b/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py @@ -7,7 +7,6 @@ import logging import unittest -import coloredlogs import pytest from parity_utilities import find_transformers_source @@ -58,6 +57,6 @@ def test_auto_mixed_precision(self): if __name__ == "__main__": - coloredlogs.install(fmt="%(message)s") + 
logging.basicConfig(format="%(message)s") logging.getLogger("transformers").setLevel(logging.ERROR) unittest.main() diff --git a/requirements.txt b/requirements.txt index 2fd9362c949dd..ff8cc04d6f219 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -coloredlogs flatbuffers numpy >= 1.21.6 packaging diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index c00cbb06f26fd..4bfb9c630fede 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -142,7 +142,7 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/' condition: always() - - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas] coloredlogs' + - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas]' displayName: 'Install dashboard dependencies' - script: | @@ -165,7 +165,7 @@ jobs: - ${{ if eq(parameters.PostToDashboard, true) }}: - - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas] coloredlogs' + - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas]' displayName: 'Install dashboard dependencies' - script: | @@ -191,4 +191,4 @@ jobs: pathtoPublish: '$(Build.SourcesDirectory)/Artifact' artifactName: 'result-$(Build.BuildNumber)' - - template: templates/clean-agent-build-directory-step.yml \ No newline at end of file + - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index 9b15f389e5349..4ec074055fcc2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -54,7 +54,7 @@ jobs: FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy + python3 -m pip install flatbuffers numpy packaging protobuf sympy python3 -m pip install --no-index --find-links . 
$PYTHON_PACKAGE_NAME python3 -m pip show $PYTHON_PACKAGE_NAME python3 -c "import onnxruntime as ort; print(ort.__version__)" diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 4e24bf7cbfa97..a86eef170bc25 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -14,5 +14,4 @@ jinja2 markupsafe semver packaging -coloredlogs onnx==1.20.1; python_version < "3.14" diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index e95509c7ddec3..1523b420bfdbd 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -6,7 +6,6 @@ numpy==2.2.6; python_version < "3.14" numpy==2.3.2; python_version >= "3.14" torch==2.8.0 torchvision==0.23.0 -coloredlogs==15.0 transformers==4.52.1 parameterized>=0.8.1 sentencepiece From e1ba098dcb70a5e5ed48ce3d2961547c9cc43801 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 13:37:38 -0800 Subject: [PATCH 12/23] Bump tar and cmake-js in /js/node (#27193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [tar](https://github.com/isaacs/node-tar) to 7.5.7 and updates ancestor dependency [cmake-js](https://github.com/cmake-js/cmake-js). These dependencies need to be updated together. Updates `tar` from 6.2.1 to 7.5.7
Changelog

Sourced from tar's changelog.


7.5

  • Added zstd compression support.
  • Consistent TOCTOU behavior in sync t.list
  • Only read from ustar block if not specified in Pax
  • Fix sync tar.list when file size reduces while reading
  • Sanitize absolute linkpaths properly
  • Prevent writing hardlink entries to the archive ahead of their file target

7.4

  • Deprecate onentry in favor of onReadEntry for clarity.

7.3

  • Add onWriteEntry option

7.2

  • DRY the command definitions into a single makeCommand method, and update the type signatures to more appropriately infer the return type from the options and arguments provided.

7.1

  • Update minipass to v7.1.0
  • Update the type definitions of write() and end() methods on Unpack and Parser classes to be compatible with the NodeJS.WritableStream type in the latest versions of @types/node.

7.0

  • Drop support for node <18
  • Rewrite in TypeScript, provide ESM and CommonJS hybrid interface
  • Add tree-shake friendly exports, like import('tar/create') and import('tar/read-entry') to get individual functions or classes.
  • Add chmod option that defaults to false, and deprecate noChmod. That is, reverse the default option regarding explicitly setting file system modes to match tar entry settings.
  • Add processUmask option to avoid having to call process.umask() when chmod: true (or noChmod: false) is set.

... (truncated)

Maintainer changes

This version was pushed to npm by isaacs, a new releaser for tar since your current version.
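
For orientation, here is a minimal sketch (not part of this PR) of how the tar 7.x changes listed above surface in consumer code. The `onReadEntry`, `chmod`, and `processUmask` names are taken directly from the changelog entries; exact signatures should be checked against the tar documentation.

```typescript
// Sketch only: exercises tar 7.x options named in the changelog above.
import { list, extract } from 'tar';

async function inspectAndUnpack(archive: string): Promise<void> {
  // 7.4 deprecates `onentry` in favor of `onReadEntry`.
  await list({
    file: archive,
    onReadEntry: (entry) => console.log(entry.path),
  });

  // 7.0 flips the old noChmod default: entry modes are applied only when
  // `chmod: true` is set, and `processUmask` avoids a process.umask() call.
  await extract({
    file: archive,
    cwd: './out',
    chmod: true,
    processUmask: 0o22,
  });
}
```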


Updates `cmake-js` from 7.2.1 to 8.0.0
Release notes

Sourced from cmake-js's releases.

v8.0.0

This is a small but breaking change.

This now requires nodejs 20 or later, due to increased requirements of updated dependencies

With the increased minimum, this now uses the builtin fetch which further reduces the install size!

Full Changelog: https://github.com/cmake-js/cmake-js/compare/v7.4.0...v8.0.0

Changelog

Sourced from cmake-js's changelog.

v8.0.0 - 27/01/26

  • feat: require nodejs 20 or later
  • feat: update deprecated dependencies

v7.4.0 - 14/11/25

v7.3.1 - 17/04/25

  • fix(windows): support windows arm64 (Thanks to @jaycex)
  • fix(windows): support newer visual studio installations

v7.3.0 - 15/01/24

  • feat(windows): replace custom libnode.def generation with version from node-api-headers
  • fix: support for vs2015 with nodejs 18 and older (#317)
  • fix(windows): always remove Path if PATH is also defined (#319)
  • fix: Cmake arguments got converted to numbers (#314)
  • fix: update node-api-headers
  • chore: update dependencies
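
Since cmake-js 8 drops axios in favor of the fetch built into Node 20+ (per the v8.0.0 release notes above), its download path reduces to plain `fetch` calls. A minimal sketch under that assumption; the URL is illustrative only, not cmake-js's actual endpoint:

```typescript
// Sketch only: the builtin-fetch pattern that lets cmake-js 8 drop axios.
// Requires Node >= 20, where `fetch` is available as a global.
async function fetchDistIndex(): Promise<unknown> {
  const res = await fetch('https://nodejs.org/dist/index.json'); // example URL
  if (!res.ok) {
    throw new Error(`Download failed: HTTP ${res.status}`);
  }
  return res.json();
}
```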

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)

You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/node/package-lock.json | 1252 ++++++------------------------------- js/node/package.json | 2 +- 2 files changed, 208 insertions(+), 1046 deletions(-) diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 6cf2c2963b769..7fb63332d7079 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -21,7 +21,7 @@ }, "devDependencies": { "@types/minimist": "^1.2.2", - "cmake-js": "^7.2.1", + "cmake-js": "^8.0.0", "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", @@ -37,6 +37,18 @@ "typedoc": "^0.25.7" } }, + "node_modules/@isaacs/fs-minipass": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", + "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==", + "dev": true, + "dependencies": { + "minipass": "^7.0.4" + }, + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/@protobufjs/aspromise": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", @@ -146,60 +158,19 @@ "url": "https://github.com/chalk/ansi-styles?sponsor=1" } }, - "node_modules/aproba": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/aproba/-/aproba-2.0.0.tgz", - "integrity": "sha512-lYe4Gx7QT+MKGbDsA+Z+he/Wtef0BiwDOlK/XkBrdfsh9J/jPPXbX0tE9x9cl27Tmu5gg3QUbUrQYa/y+KOHPQ==", - "dev": true - }, - "node_modules/are-we-there-yet": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/are-we-there-yet/-/are-we-there-yet-3.0.1.tgz", - "integrity": "sha512-QZW4EDmGwlYur0Yyf/b2uGucHQMa8aFUP7eu9ddR73vvhFyt4V0Vl3QHPcTNJ8l6qYOBdxgXdnBXQrHilfRQBg==", - "dev": true, - "dependencies": { - "delegates": "^1.0.0", - "readable-stream": "^3.6.0" - }, - "engines": { - "node": "^12.13.0 || ^14.15.0 || >=16.0.0" - } - }, - "node_modules/asynckit": { - "version": "0.4.0", - "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", - "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", - "dev": true - }, - "node_modules/axios": { - "version": "1.12.2", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz", - "integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==", - "dev": true, - "license": "MIT", - "dependencies": { - "follow-redirects": "^1.15.6", - "form-data": "^4.0.4", - "proxy-from-env": "^1.1.0" - } - }, "node_modules/boolean": { "version": "3.2.0", "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.2.0.tgz", "integrity": "sha512-d0II/GO9uf9lfUHH2BQsjxzRJZBdsjgsBiW4BvhWk/3qoKwQFjIDVN19PfX8F2D/r9PCMTtLWjYVCFrpeYUzsw==", "deprecated": "Package no longer supported. Contact Support at https://www.npmjs.com/support for more info." 
}, - "node_modules/call-bind-apply-helpers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", - "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", + "node_modules/chownr": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz", + "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==", "dev": true, - "dependencies": { - "es-errors": "^1.3.0", - "function-bind": "^1.1.2" - }, "engines": { - "node": ">= 0.4" + "node": ">=18" } }, "node_modules/cliui": { @@ -217,93 +188,26 @@ } }, "node_modules/cmake-js": { - "version": "7.2.1", - "resolved": "https://registry.npmjs.org/cmake-js/-/cmake-js-7.2.1.tgz", - "integrity": "sha512-AdPSz9cSIJWdKvm0aJgVu3X8i0U3mNTswJkSHzZISqmYVjZk7Td4oDFg0mCBA383wO+9pG5Ix7pEP1CZH9x2BA==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/cmake-js/-/cmake-js-8.0.0.tgz", + "integrity": "sha512-YbUP88RDwCvoQkZhRtGURYm9RIpWdtvZuhT87fKNoLjk8kIFIFeARpKfuZQGdwfH99GZpUmqSfcDrK62X7lTgg==", "dev": true, "dependencies": { - "axios": "^1.3.2", - "debug": "^4", - "fs-extra": "^10.1.0", - "lodash.isplainobject": "^4.0.6", - "memory-stream": "^1.0.0", - "node-api-headers": "^0.0.2", - "npmlog": "^6.0.2", - "rc": "^1.2.7", - "semver": "^7.3.8", - "tar": "^6.1.11", + "debug": "^4.4.3", + "fs-extra": "^11.3.3", + "node-api-headers": "^1.8.0", + "rc": "1.2.8", + "semver": "^7.7.3", + "tar": "^7.5.6", "url-join": "^4.0.1", - "which": "^2.0.2", - "yargs": "^17.6.0" + "which": "^6.0.0", + "yargs": "^17.7.2" }, "bin": { "cmake-js": "bin/cmake-js" }, "engines": { - "node": ">= 14.15.0" - } - }, - "node_modules/cmake-js/node_modules/chownr": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", - "integrity": "sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", - "dev": true, - "engines": { - "node": ">=10" - } - }, - "node_modules/cmake-js/node_modules/minizlib": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", - "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0", - "yallist": "^4.0.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/cmake-js/node_modules/minizlib/node_modules/minipass": { - "version": "3.3.6", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", - "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", - "dev": true, - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/cmake-js/node_modules/mkdirp": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", - "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", - "dev": true, - "bin": { - "mkdirp": "bin/cmd.js" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/cmake-js/node_modules/tar": { - "version": "6.2.1", - "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.1.tgz", - "integrity": "sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==", - "dev": true, - "dependencies": { - "chownr": "^2.0.0", - "fs-minipass": "^2.0.0", - "minipass": "^5.0.0", 
- "minizlib": "^2.1.1", - "mkdirp": "^1.0.3", - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" + "node": "^20.17.0 || >=22.9.0" } }, "node_modules/color-convert": { @@ -324,40 +228,13 @@ "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", "dev": true }, - "node_modules/color-support": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/color-support/-/color-support-1.1.3.tgz", - "integrity": "sha512-qiBjkpbMLO/HL68y+lh4q0/O1MZFj2RX6X/KmMa3+gJD3z+WwI1ZzDHysvqHGS3mP6mznPckpXmw1nI9cJjyRg==", - "dev": true, - "bin": { - "color-support": "bin.js" - } - }, - "node_modules/combined-stream": { - "version": "1.0.8", - "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", - "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", - "dev": true, - "dependencies": { - "delayed-stream": "~1.0.0" - }, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/console-control-strings": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", - "integrity": "sha512-ty/fTekppD2fIwRvnZAVdeOiGd1c7YXEixbgJTNzqcxJWKQnjJ/V1bNEEE6hygpM3WjwHFUVK6HTjWSzV4a8sQ==", - "dev": true - }, "node_modules/debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", "dev": true, "dependencies": { - "ms": "2.1.2" + "ms": "^2.1.3" }, "engines": { "node": ">=6.0" @@ -409,40 +286,11 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/delayed-stream": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", - "integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", - "dev": true, - "engines": { - "node": ">=0.4.0" - } - }, - "node_modules/delegates": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/delegates/-/delegates-1.0.0.tgz", - "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", - "dev": true - }, "node_modules/detect-node": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz", "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==" }, - "node_modules/dunder-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", - "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", - "dev": true, - "dependencies": { - "call-bind-apply-helpers": "^1.0.1", - "es-errors": "^1.3.0", - "gopd": "^1.2.0" - }, - "engines": { - "node": ">= 0.4" - } - }, "node_modules/emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", @@ -474,42 +322,15 @@ "node": ">= 0.4" } }, - "node_modules/es-object-atoms": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", - "integrity": 
"sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", - "dev": true, - "dependencies": { - "es-errors": "^1.3.0" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/es-set-tostringtag": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", - "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", - "dev": true, - "dependencies": { - "es-errors": "^1.3.0", - "get-intrinsic": "^1.2.6", - "has-tostringtag": "^1.0.2", - "hasown": "^2.0.2" - }, - "engines": { - "node": ">= 0.4" - } - }, "node_modules/es6-error": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==" }, "node_modules/escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true, "engines": { "node": ">=6" @@ -532,46 +353,10 @@ "integrity": "sha512-W+KJc2dmILlPplD/H4K9l9LcAHAfPtP6BY84uVLXQ6Evcz9Lcg33Y2z1IVblT6xdY54PXYVHEv+0Wpq8Io6zkA==", "dev": true }, - "node_modules/follow-redirects": { - "version": "1.15.6", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", - "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", - "dev": true, - "funding": [ - { - "type": "individual", - "url": "https://github.com/sponsors/RubenVerborgh" - } - ], - "engines": { - "node": ">=4.0" - }, - "peerDependenciesMeta": { - "debug": { - "optional": true - } - } - }, - "node_modules/form-data": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", - "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", - "dev": true, - "dependencies": { - "asynckit": "^0.4.0", - "combined-stream": "^1.0.8", - "es-set-tostringtag": "^2.1.0", - "hasown": "^2.0.2", - "mime-types": "^2.1.12" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/fs-extra": { - "version": "10.1.0", - "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-10.1.0.tgz", - "integrity": "sha512-oRXApq54ETRj4eMiFzGnHWGy+zo5raudjuxN0b8H7s/RU2oW0Wvsx9O0ACRN/kRq9E8Vu/ReskGB5o3ji+FzHQ==", + "version": "11.3.3", + "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-11.3.3.tgz", + "integrity": "sha512-VWSRii4t0AFm6ixFFmLLx1t7wS1gh+ckoa84aOeapGum0h+EZd1EhEumSB+ZdDLnEPuucsVB9oB7cxJHap6Afg==", "dev": true, "dependencies": { "graceful-fs": "^4.2.0", @@ -579,59 +364,7 @@ "universalify": "^2.0.0" }, "engines": { - "node": ">=12" - } - }, - "node_modules/fs-minipass": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", - "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", - "dev": true, - "dependencies": { - "minipass": "^3.0.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/fs-minipass/node_modules/minipass": { - "version": "3.3.6", - "resolved": 
"https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", - "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", - "dev": true, - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/function-bind": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", - "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", - "dev": true, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/gauge": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/gauge/-/gauge-4.0.4.tgz", - "integrity": "sha512-f9m+BEN5jkg6a0fZjleidjN51VE1X+mPFQ2DJ0uv1V39oCLCbsGe6yjbBnp7eK7z/+GAon99a3nHuqbuuthyPg==", - "dev": true, - "dependencies": { - "aproba": "^1.0.3 || ^2.0.0", - "color-support": "^1.1.3", - "console-control-strings": "^1.1.0", - "has-unicode": "^2.0.1", - "signal-exit": "^3.0.7", - "string-width": "^4.2.3", - "strip-ansi": "^6.0.1", - "wide-align": "^1.1.5" - }, - "engines": { - "node": "^12.13.0 || ^14.15.0 || >=16.0.0" + "node": ">=14.14" } }, "node_modules/get-caller-file": { @@ -643,43 +376,6 @@ "node": "6.* || 8.* || >= 10.*" } }, - "node_modules/get-intrinsic": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", - "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", - "dev": true, - "dependencies": { - "call-bind-apply-helpers": "^1.0.2", - "es-define-property": "^1.0.1", - "es-errors": "^1.3.0", - "es-object-atoms": "^1.1.1", - "function-bind": "^1.1.2", - "get-proto": "^1.0.1", - "gopd": "^1.2.0", - "has-symbols": "^1.1.0", - "hasown": "^2.0.2", - "math-intrinsics": "^1.1.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/get-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", - "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", - "dev": true, - "dependencies": { - "dunder-proto": "^1.0.1", - "es-object-atoms": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - } - }, "node_modules/global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -739,57 +435,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/has-symbols": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", - "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", - "dev": true, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-tostringtag": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", - "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", - "dev": true, - "dependencies": { - "has-symbols": "^1.0.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-unicode": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", - "integrity": 
"sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", - "dev": true - }, - "node_modules/hasown": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", - "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", - "dev": true, - "dependencies": { - "function-bind": "^1.1.2" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/inherits": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", - "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", - "dev": true - }, "node_modules/ini": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz", @@ -812,10 +457,13 @@ } }, "node_modules/isexe": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", - "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", - "dev": true + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-3.1.1.tgz", + "integrity": "sha512-LpB/54B+/2J5hqQ7imZHfdU31OlgQqx7ZicVlkm9kzg9/w8GKLEcFfJl/t7DCEDueOyBAD6zCCwTO6Fzs0NoEQ==", + "dev": true, + "engines": { + "node": ">=16" + } }, "node_modules/json-parse-better-errors": { "version": "1.0.2", @@ -846,9 +494,9 @@ } }, "node_modules/jsonfile": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.1.0.tgz", - "integrity": "sha512-5dgndWOriYSm5cnYaJNhalLNDKOqFwyDB/rr1E9ZsGciGvKPs8R2xYGCacuf3z6K1YKDz182fd+fY3cn3pMqXQ==", + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.2.0.tgz", + "integrity": "sha512-FGuPw30AdOIUTRMC2OMRtQV+jkVj2cfPqSeWXv1NEAJ1qZ5zb1X6z1mFhbfOB/iy3ssJCD+3KuZ8r8C3uVFlAg==", "dev": true, "dependencies": { "universalify": "^2.0.0" @@ -857,29 +505,12 @@ "graceful-fs": "^4.1.6" } }, - "node_modules/lodash.isplainobject": { - "version": "4.0.6", - "resolved": "https://registry.npmjs.org/lodash.isplainobject/-/lodash.isplainobject-4.0.6.tgz", - "integrity": "sha512-oSXzaWypCMHkPC3NvBEaPHf0KsA5mvPrOPgQWDsbg8n7orZ290M0BmC/jgRZ4vcJ6DTAhjrsSYgdsW/F+MFOBA==", - "dev": true - }, "node_modules/long": { "version": "5.2.3", "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, - "node_modules/lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" - } - }, "node_modules/matcher": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz", @@ -891,45 +522,6 @@ "node": ">=10" } }, - "node_modules/math-intrinsics": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", - "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", - "dev": true, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/memory-stream": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/memory-stream/-/memory-stream-1.0.0.tgz", - "integrity": 
"sha512-Wm13VcsPIMdG96dzILfij09PvuS3APtcKNh7M28FsCA/w6+1mjR7hhPmfFNoilX9xU7wTdhsH5lJAm6XNzdtww==", - "dev": true, - "dependencies": { - "readable-stream": "^3.4.0" - } - }, - "node_modules/mime-db": { - "version": "1.52.0", - "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", - "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", - "dev": true, - "engines": { - "node": ">= 0.6" - } - }, - "node_modules/mime-types": { - "version": "2.1.35", - "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", - "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", - "dev": true, - "dependencies": { - "mime-db": "1.52.0" - }, - "engines": { - "node": ">= 0.6" - } - }, "node_modules/minimist": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", @@ -940,12 +532,24 @@ } }, "node_modules/minipass": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", - "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", "dev": true, "engines": { - "node": ">=8" + "node": ">=16 || 14 >=14.17" + } + }, + "node_modules/minizlib": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.1.0.tgz", + "integrity": "sha512-KZxYo1BUkWD2TVFLr0MQoM8vUUigWD3LlD83a/75BqC+4qE0Hb1Vo5v1FgcfaNXvfXzr+5EhQ6ing/CaBijTlw==", + "dev": true, + "dependencies": { + "minipass": "^7.1.2" + }, + "engines": { + "node": ">= 18" } }, "node_modules/mkdirp": { @@ -961,9 +565,9 @@ } }, "node_modules/ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node_modules/node-addon-api": { @@ -973,26 +577,11 @@ "dev": true }, "node_modules/node-api-headers": { - "version": "0.0.2", - "resolved": "https://registry.npmjs.org/node-api-headers/-/node-api-headers-0.0.2.tgz", - "integrity": "sha512-YsjmaKGPDkmhoNKIpkChtCsPVaRE0a274IdERKnuc/E8K1UJdBZ4/mvI006OijlQZHCfpRNOH3dfHQs92se8gg==", + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/node-api-headers/-/node-api-headers-1.8.0.tgz", + "integrity": "sha512-jfnmiKWjRAGbdD1yQS28bknFM1tbHC1oucyuMPjmkEs+kpiu76aRs40WlTmBmyEgzDM76ge1DQ7XJ3R5deiVjQ==", "dev": true }, - "node_modules/npmlog": { - "version": "6.0.2", - "resolved": "https://registry.npmjs.org/npmlog/-/npmlog-6.0.2.tgz", - "integrity": "sha512-/vBvz5Jfr9dT/aFWd0FIRf+T/Q2WBsLENygUaFUqstqsycmZAP/t5BvFJTK0viFmSUxiUKTUplWy5vt+rvKIxg==", - "dev": true, - "dependencies": { - "are-we-there-yet": "^3.0.0", - "console-control-strings": "^1.1.0", - "gauge": "^4.0.3", - "set-blocking": "^2.0.0" - }, - "engines": { - "node": "^12.13.0 || ^14.15.0 || >=16.0.0" - } - }, "node_modules/object-keys": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", @@ -1042,12 +631,6 @@ "node": ">=12.0.0" } }, - 
"node_modules/proxy-from-env": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", - "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", - "dev": true - }, "node_modules/rc": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", @@ -1072,20 +655,6 @@ "node": ">=0.10.0" } }, - "node_modules/readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "dependencies": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/require-directory": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", @@ -1111,33 +680,10 @@ "node": ">=8.0" } }, - "node_modules/safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ] - }, "node_modules/semver": { - "version": "7.5.4", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", - "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", - "dependencies": { - "lru-cache": "^6.0.0" - }, + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", "bin": { "semver": "bin/semver.js" }, @@ -1164,32 +710,11 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/set-blocking": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", - "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==", - "dev": true - }, - "node_modules/signal-exit": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", - "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", - "dev": true - }, "node_modules/sprintf-js": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz", "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==" }, - "node_modules/string_decoder": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", - "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", - "dev": true, - "dependencies": { - "safe-buffer": "~5.2.0" - } - }, "node_modules/string-width": { "version": "4.2.3", "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", @@ -1237,6 +762,22 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/tar": { + "version": "7.5.7", + "resolved": 
"https://registry.npmjs.org/tar/-/tar-7.5.7.tgz", + "integrity": "sha512-fov56fJiRuThVFXD6o6/Q354S7pnWMJIVlDBYijsTNx6jKSE4pvrDTs6lUnmGvNyfJwFQQwWy3owKz1ucIhveQ==", + "dev": true, + "dependencies": { + "@isaacs/fs-minipass": "^4.0.0", + "chownr": "^3.0.0", + "minipass": "^7.1.2", + "minizlib": "^3.1.0", + "yallist": "^5.0.0" + }, + "engines": { + "node": ">=18" + } + }, "node_modules/type-fest": { "version": "0.13.1", "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", @@ -1249,9 +790,9 @@ } }, "node_modules/universalify": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.0.tgz", - "integrity": "sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.1.tgz", + "integrity": "sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==", "dev": true, "engines": { "node": ">= 10.0.0" @@ -1263,34 +804,19 @@ "integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==", "dev": true }, - "node_modules/util-deprecate": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", - "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", - "dev": true - }, "node_modules/which": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", - "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/which/-/which-6.0.0.tgz", + "integrity": "sha512-f+gEpIKMR9faW/JgAgPK1D7mekkFoqbmiwvNzuhsHetni20QSgzg9Vhn0g2JSJkkfehQnqdUAx7/e15qS1lPxg==", "dev": true, "dependencies": { - "isexe": "^2.0.0" + "isexe": "^3.1.1" }, "bin": { - "node-which": "bin/node-which" + "node-which": "bin/which.js" }, "engines": { - "node": ">= 8" - } - }, - "node_modules/wide-align": { - "version": "1.1.5", - "resolved": "https://registry.npmjs.org/wide-align/-/wide-align-1.1.5.tgz", - "integrity": "sha512-eDMORYaPNZ4sQIuuYPDHdQvf4gyCF9rEEV/yPxGfwPkRodwEgiMUUXTx/dex+Me0wxx53S+NgUHaP7y3MGlDmg==", - "dev": true, - "dependencies": { - "string-width": "^1.0.2 || 2 || 3 || 4" + "node": "^20.17.0 || >=22.9.0" } }, "node_modules/wrap-ansi": { @@ -1320,14 +846,18 @@ } }, "node_modules/yallist": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", + "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==", + "dev": true, + "engines": { + "node": ">=18" + } }, "node_modules/yargs": { - "version": "17.7.1", - "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.1.tgz", - "integrity": "sha512-cwiTb08Xuv5fqF4AovYacTFNxk62th7LKJ6BL9IGUpTJrWoU7/7WdQGTP2SjKf1dUNBGzDd28p/Yfs/GI6JrLw==", + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", "dev": true, "dependencies": { "cliui": "^8.0.1", @@ -1342,7 +872,7 @@ "node": ">=12" } }, - 
"node_modules/yargs/node_modules/yargs-parser": { + "node_modules/yargs-parser": { "version": "21.1.1", "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", @@ -1353,6 +883,15 @@ } }, "dependencies": { + "@isaacs/fs-minipass": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", + "integrity": "sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==", + "dev": true, + "requires": { + "minipass": "^7.0.4" + } + }, "@protobufjs/aspromise": { "version": "1.1.2", "resolved": "https://registry.npmjs.org/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", @@ -1449,53 +988,16 @@ "color-convert": "^2.0.1" } }, - "aproba": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/aproba/-/aproba-2.0.0.tgz", - "integrity": "sha512-lYe4Gx7QT+MKGbDsA+Z+he/Wtef0BiwDOlK/XkBrdfsh9J/jPPXbX0tE9x9cl27Tmu5gg3QUbUrQYa/y+KOHPQ==", - "dev": true - }, - "are-we-there-yet": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/are-we-there-yet/-/are-we-there-yet-3.0.1.tgz", - "integrity": "sha512-QZW4EDmGwlYur0Yyf/b2uGucHQMa8aFUP7eu9ddR73vvhFyt4V0Vl3QHPcTNJ8l6qYOBdxgXdnBXQrHilfRQBg==", - "dev": true, - "requires": { - "delegates": "^1.0.0", - "readable-stream": "^3.6.0" - } - }, - "asynckit": { - "version": "0.4.0", - "resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz", - "integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==", - "dev": true - }, - "axios": { - "version": "1.12.2", - "resolved": "https://registry.npmjs.org/axios/-/axios-1.12.2.tgz", - "integrity": "sha512-vMJzPewAlRyOgxV2dU0Cuz2O8zzzx9VYtbJOaBgXFeLc4IV/Eg50n4LowmehOOR61S8ZMpc2K5Sa7g6A4jfkUw==", - "dev": true, - "requires": { - "follow-redirects": "^1.15.6", - "form-data": "^4.0.4", - "proxy-from-env": "^1.1.0" - } - }, "boolean": { "version": "3.2.0", "resolved": "https://registry.npmjs.org/boolean/-/boolean-3.2.0.tgz", "integrity": "sha512-d0II/GO9uf9lfUHH2BQsjxzRJZBdsjgsBiW4BvhWk/3qoKwQFjIDVN19PfX8F2D/r9PCMTtLWjYVCFrpeYUzsw==" }, - "call-bind-apply-helpers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", - "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", - "dev": true, - "requires": { - "es-errors": "^1.3.0", - "function-bind": "^1.1.2" - } + "chownr": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/chownr/-/chownr-3.0.0.tgz", + "integrity": "sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==", + "dev": true }, "cliui": { "version": "8.0.1", @@ -1509,73 +1011,20 @@ } }, "cmake-js": { - "version": "7.2.1", - "resolved": "https://registry.npmjs.org/cmake-js/-/cmake-js-7.2.1.tgz", - "integrity": "sha512-AdPSz9cSIJWdKvm0aJgVu3X8i0U3mNTswJkSHzZISqmYVjZk7Td4oDFg0mCBA383wO+9pG5Ix7pEP1CZH9x2BA==", + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/cmake-js/-/cmake-js-8.0.0.tgz", + "integrity": "sha512-YbUP88RDwCvoQkZhRtGURYm9RIpWdtvZuhT87fKNoLjk8kIFIFeARpKfuZQGdwfH99GZpUmqSfcDrK62X7lTgg==", "dev": true, "requires": { - "axios": "^1.3.2", - "debug": "^4", - "fs-extra": "^10.1.0", - "lodash.isplainobject": "^4.0.6", - "memory-stream": "^1.0.0", - "node-api-headers": "^0.0.2", - "npmlog": "^6.0.2", - "rc": "^1.2.7", - 
"semver": "^7.3.8", - "tar": "^6.1.11", + "debug": "^4.4.3", + "fs-extra": "^11.3.3", + "node-api-headers": "^1.8.0", + "rc": "1.2.8", + "semver": "^7.7.3", + "tar": "^7.5.6", "url-join": "^4.0.1", - "which": "^2.0.2", - "yargs": "^17.6.0" - }, - "dependencies": { - "chownr": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/chownr/-/chownr-2.0.0.tgz", - "integrity": "sha512-bIomtDF5KGpdogkLd9VspvFzk9KfpyyGlS8YFVZl7TGPBHL5snIOnxeshwVgPteQ9b4Eydl+pVbIyE1DcvCWgQ==", - "dev": true - }, - "minizlib": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz", - "integrity": "sha512-bAxsR8BVfj60DWXHE3u30oHzfl4G7khkSuPW+qvpd7jFRHm7dLxOjUk1EHACJ/hxLY8phGJ0YhYHZo7jil7Qdg==", - "dev": true, - "requires": { - "minipass": "^3.0.0", - "yallist": "^4.0.0" - }, - "dependencies": { - "minipass": { - "version": "3.3.6", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", - "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", - "dev": true, - "requires": { - "yallist": "^4.0.0" - } - } - } - }, - "mkdirp": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-1.0.4.tgz", - "integrity": "sha512-vVqVZQyf3WLx2Shd0qJ9xuvqgAyKPLAiqITEtqW0oIUjzo3PePDd6fW9iFz30ef7Ysp/oiWqbhszeGWW2T6Gzw==", - "dev": true - }, - "tar": { - "version": "6.2.1", - "resolved": "https://registry.npmjs.org/tar/-/tar-6.2.1.tgz", - "integrity": "sha512-DZ4yORTwrbTj/7MZYq2w+/ZFdI6OZ/f9SFHR+71gIVUZhOQPHzVCLpvRnPgyaMpfWxxk/4ONva3GQSyNIKRv6A==", - "dev": true, - "requires": { - "chownr": "^2.0.0", - "fs-minipass": "^2.0.0", - "minipass": "^5.0.0", - "minizlib": "^2.1.1", - "mkdirp": "^1.0.3", - "yallist": "^4.0.0" - } - } + "which": "^6.0.0", + "yargs": "^17.7.2" } }, "color-convert": { @@ -1593,34 +1042,13 @@ "integrity": "sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==", "dev": true }, - "color-support": { - "version": "1.1.3", - "resolved": "https://registry.npmjs.org/color-support/-/color-support-1.1.3.tgz", - "integrity": "sha512-qiBjkpbMLO/HL68y+lh4q0/O1MZFj2RX6X/KmMa3+gJD3z+WwI1ZzDHysvqHGS3mP6mznPckpXmw1nI9cJjyRg==", - "dev": true - }, - "combined-stream": { - "version": "1.0.8", - "resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz", - "integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==", - "dev": true, - "requires": { - "delayed-stream": "~1.0.0" - } - }, - "console-control-strings": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/console-control-strings/-/console-control-strings-1.1.0.tgz", - "integrity": "sha512-ty/fTekppD2fIwRvnZAVdeOiGd1c7YXEixbgJTNzqcxJWKQnjJ/V1bNEEE6hygpM3WjwHFUVK6HTjWSzV4a8sQ==", - "dev": true - }, "debug": { - "version": "4.3.4", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz", - "integrity": "sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==", + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", "dev": true, "requires": { - "ms": "2.1.2" + "ms": "^2.1.3" } }, "deep-extend": { @@ -1649,34 +1077,11 @@ "object-keys": "^1.1.1" } }, - "delayed-stream": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz", - "integrity": 
"sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==", - "dev": true - }, - "delegates": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/delegates/-/delegates-1.0.0.tgz", - "integrity": "sha512-bd2L678uiWATM6m5Z1VzNCErI3jiGzt6HGY8OVICs40JQq/HALfbyNJmp0UDakEY4pMMaN0Ly5om/B1VI/+xfQ==", - "dev": true - }, "detect-node": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/detect-node/-/detect-node-2.1.0.tgz", "integrity": "sha512-T0NIuQpnTvFDATNuHN5roPwSBG83rFsuO+MXXH9/3N1eFbn4wcPjttvjMLEPWJ0RGUYgQE7cGgS3tNxbqCGM7g==" }, - "dunder-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", - "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", - "dev": true, - "requires": { - "call-bind-apply-helpers": "^1.0.1", - "es-errors": "^1.3.0", - "gopd": "^1.2.0" - } - }, "emoji-regex": { "version": "8.0.0", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", @@ -1702,36 +1107,15 @@ "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==" }, - "es-object-atoms": { - "version": "1.1.1", - "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", - "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", - "dev": true, - "requires": { - "es-errors": "^1.3.0" - } - }, - "es-set-tostringtag": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz", - "integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==", - "dev": true, - "requires": { - "es-errors": "^1.3.0", - "get-intrinsic": "^1.2.6", - "has-tostringtag": "^1.0.2", - "hasown": "^2.0.2" - } - }, "es6-error": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", "integrity": "sha512-Um/+FxMr9CISWh0bi5Zv0iOD+4cFh5qLeks1qhAopKVAJw3drgKbKySikp7wGhDL0HPeaja0P5ULZrxLkniUVg==" }, "escalade": { - "version": "3.1.1", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.1.1.tgz", - "integrity": "sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==", + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", + "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", "dev": true }, "escape-string-regexp": { @@ -1745,29 +1129,10 @@ "integrity": "sha512-W+KJc2dmILlPplD/H4K9l9LcAHAfPtP6BY84uVLXQ6Evcz9Lcg33Y2z1IVblT6xdY54PXYVHEv+0Wpq8Io6zkA==", "dev": true }, - "follow-redirects": { - "version": "1.15.6", - "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz", - "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==", - "dev": true - }, - "form-data": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", - "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", - "dev": true, - "requires": { - "asynckit": "^0.4.0", - "combined-stream": "^1.0.8", - "es-set-tostringtag": "^2.1.0", - "hasown": "^2.0.2", - "mime-types": "^2.1.12" - } - }, "fs-extra": { - "version": 
"10.1.0", - "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-10.1.0.tgz", - "integrity": "sha512-oRXApq54ETRj4eMiFzGnHWGy+zo5raudjuxN0b8H7s/RU2oW0Wvsx9O0ACRN/kRq9E8Vu/ReskGB5o3ji+FzHQ==", + "version": "11.3.3", + "resolved": "https://registry.npmjs.org/fs-extra/-/fs-extra-11.3.3.tgz", + "integrity": "sha512-VWSRii4t0AFm6ixFFmLLx1t7wS1gh+ckoa84aOeapGum0h+EZd1EhEumSB+ZdDLnEPuucsVB9oB7cxJHap6Afg==", "dev": true, "requires": { "graceful-fs": "^4.2.0", @@ -1775,82 +1140,12 @@ "universalify": "^2.0.0" } }, - "fs-minipass": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/fs-minipass/-/fs-minipass-2.1.0.tgz", - "integrity": "sha512-V/JgOLFCS+R6Vcq0slCuaeWEdNC3ouDlJMNIsacH2VtALiu9mV4LPrHc5cDl8k5aw6J8jwgWWpiTo5RYhmIzvg==", - "dev": true, - "requires": { - "minipass": "^3.0.0" - }, - "dependencies": { - "minipass": { - "version": "3.3.6", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-3.3.6.tgz", - "integrity": "sha512-DxiNidxSEK+tHG6zOIklvNOwm3hvCrbUrdtzY74U6HKTJxvIDfOUL5W5P2Ghd3DTkhhKPYGqeNUIh5qcM4YBfw==", - "dev": true, - "requires": { - "yallist": "^4.0.0" - } - } - } - }, - "function-bind": { - "version": "1.1.2", - "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", - "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", - "dev": true - }, - "gauge": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/gauge/-/gauge-4.0.4.tgz", - "integrity": "sha512-f9m+BEN5jkg6a0fZjleidjN51VE1X+mPFQ2DJ0uv1V39oCLCbsGe6yjbBnp7eK7z/+GAon99a3nHuqbuuthyPg==", - "dev": true, - "requires": { - "aproba": "^1.0.3 || ^2.0.0", - "color-support": "^1.1.3", - "console-control-strings": "^1.1.0", - "has-unicode": "^2.0.1", - "signal-exit": "^3.0.7", - "string-width": "^4.2.3", - "strip-ansi": "^6.0.1", - "wide-align": "^1.1.5" - } - }, "get-caller-file": { "version": "2.0.5", "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", "dev": true }, - "get-intrinsic": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", - "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", - "dev": true, - "requires": { - "call-bind-apply-helpers": "^1.0.2", - "es-define-property": "^1.0.1", - "es-errors": "^1.3.0", - "es-object-atoms": "^1.1.1", - "function-bind": "^1.1.2", - "get-proto": "^1.0.1", - "gopd": "^1.2.0", - "has-symbols": "^1.1.0", - "hasown": "^2.0.2", - "math-intrinsics": "^1.1.0" - } - }, - "get-proto": { - "version": "1.0.1", - "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", - "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", - "dev": true, - "requires": { - "dunder-proto": "^1.0.1", - "es-object-atoms": "^1.0.0" - } - }, "global-agent": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/global-agent/-/global-agent-3.0.0.tgz", @@ -1892,42 +1187,6 @@ "es-define-property": "^1.0.0" } }, - "has-symbols": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", - "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", - "dev": true - }, - "has-tostringtag": { - "version": "1.0.2", - "resolved": 
"https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz", - "integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==", - "dev": true, - "requires": { - "has-symbols": "^1.0.3" - } - }, - "has-unicode": { - "version": "2.0.1", - "resolved": "https://registry.npmjs.org/has-unicode/-/has-unicode-2.0.1.tgz", - "integrity": "sha512-8Rf9Y83NBReMnx0gFzA8JImQACstCYWUplepDa9xprwwtmgEZUF0h/i5xSA625zB/I37EtrswSST6OXxwaaIJQ==", - "dev": true - }, - "hasown": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", - "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", - "dev": true, - "requires": { - "function-bind": "^1.1.2" - } - }, - "inherits": { - "version": "2.0.4", - "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", - "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", - "dev": true - }, "ini": { "version": "1.3.8", "resolved": "https://registry.npmjs.org/ini/-/ini-1.3.8.tgz", @@ -1947,9 +1206,9 @@ "dev": true }, "isexe": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", - "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", + "version": "3.1.1", + "resolved": "https://registry.npmjs.org/isexe/-/isexe-3.1.1.tgz", + "integrity": "sha512-LpB/54B+/2J5hqQ7imZHfdU31OlgQqx7ZicVlkm9kzg9/w8GKLEcFfJl/t7DCEDueOyBAD6zCCwTO6Fzs0NoEQ==", "dev": true }, "json-parse-better-errors": { @@ -1978,35 +1237,21 @@ } }, "jsonfile": { - "version": "6.1.0", - "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.1.0.tgz", - "integrity": "sha512-5dgndWOriYSm5cnYaJNhalLNDKOqFwyDB/rr1E9ZsGciGvKPs8R2xYGCacuf3z6K1YKDz182fd+fY3cn3pMqXQ==", + "version": "6.2.0", + "resolved": "https://registry.npmjs.org/jsonfile/-/jsonfile-6.2.0.tgz", + "integrity": "sha512-FGuPw30AdOIUTRMC2OMRtQV+jkVj2cfPqSeWXv1NEAJ1qZ5zb1X6z1mFhbfOB/iy3ssJCD+3KuZ8r8C3uVFlAg==", "dev": true, "requires": { "graceful-fs": "^4.1.6", "universalify": "^2.0.0" } }, - "lodash.isplainobject": { - "version": "4.0.6", - "resolved": "https://registry.npmjs.org/lodash.isplainobject/-/lodash.isplainobject-4.0.6.tgz", - "integrity": "sha512-oSXzaWypCMHkPC3NvBEaPHf0KsA5mvPrOPgQWDsbg8n7orZ290M0BmC/jgRZ4vcJ6DTAhjrsSYgdsW/F+MFOBA==", - "dev": true - }, "long": { "version": "5.2.3", "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==", "dev": true }, - "lru-cache": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", - "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", - "requires": { - "yallist": "^4.0.0" - } - }, "matcher": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/matcher/-/matcher-3.0.0.tgz", @@ -2015,36 +1260,6 @@ "escape-string-regexp": "^4.0.0" } }, - "math-intrinsics": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", - "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", - "dev": true - }, - "memory-stream": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/memory-stream/-/memory-stream-1.0.0.tgz", - "integrity": 
"sha512-Wm13VcsPIMdG96dzILfij09PvuS3APtcKNh7M28FsCA/w6+1mjR7hhPmfFNoilX9xU7wTdhsH5lJAm6XNzdtww==", - "dev": true, - "requires": { - "readable-stream": "^3.4.0" - } - }, - "mime-db": { - "version": "1.52.0", - "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz", - "integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==", - "dev": true - }, - "mime-types": { - "version": "2.1.35", - "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz", - "integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==", - "dev": true, - "requires": { - "mime-db": "1.52.0" - } - }, "minimist": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/minimist/-/minimist-1.2.8.tgz", @@ -2052,11 +1267,20 @@ "dev": true }, "minipass": { - "version": "5.0.0", - "resolved": "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz", - "integrity": "sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ==", + "version": "7.1.2", + "resolved": "https://registry.npmjs.org/minipass/-/minipass-7.1.2.tgz", + "integrity": "sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==", "dev": true }, + "minizlib": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.1.0.tgz", + "integrity": "sha512-KZxYo1BUkWD2TVFLr0MQoM8vUUigWD3LlD83a/75BqC+4qE0Hb1Vo5v1FgcfaNXvfXzr+5EhQ6ing/CaBijTlw==", + "dev": true, + "requires": { + "minipass": "^7.1.2" + } + }, "mkdirp": { "version": "0.5.6", "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-0.5.6.tgz", @@ -2067,9 +1291,9 @@ } }, "ms": { - "version": "2.1.2", - "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz", - "integrity": "sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==", + "version": "2.1.3", + "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", + "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==", "dev": true }, "node-addon-api": { @@ -2079,23 +1303,11 @@ "dev": true }, "node-api-headers": { - "version": "0.0.2", - "resolved": "https://registry.npmjs.org/node-api-headers/-/node-api-headers-0.0.2.tgz", - "integrity": "sha512-YsjmaKGPDkmhoNKIpkChtCsPVaRE0a274IdERKnuc/E8K1UJdBZ4/mvI006OijlQZHCfpRNOH3dfHQs92se8gg==", + "version": "1.8.0", + "resolved": "https://registry.npmjs.org/node-api-headers/-/node-api-headers-1.8.0.tgz", + "integrity": "sha512-jfnmiKWjRAGbdD1yQS28bknFM1tbHC1oucyuMPjmkEs+kpiu76aRs40WlTmBmyEgzDM76ge1DQ7XJ3R5deiVjQ==", "dev": true }, - "npmlog": { - "version": "6.0.2", - "resolved": "https://registry.npmjs.org/npmlog/-/npmlog-6.0.2.tgz", - "integrity": "sha512-/vBvz5Jfr9dT/aFWd0FIRf+T/Q2WBsLENygUaFUqstqsycmZAP/t5BvFJTK0viFmSUxiUKTUplWy5vt+rvKIxg==", - "dev": true, - "requires": { - "are-we-there-yet": "^3.0.0", - "console-control-strings": "^1.1.0", - "gauge": "^4.0.3", - "set-blocking": "^2.0.0" - } - }, "object-keys": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/object-keys/-/object-keys-1.1.1.tgz", @@ -2138,12 +1350,6 @@ "long": "^5.0.0" } }, - "proxy-from-env": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz", - "integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==", - "dev": true - }, "rc": { "version": "1.2.8", "resolved": 
"https://registry.npmjs.org/rc/-/rc-1.2.8.tgz", @@ -2164,17 +1370,6 @@ } } }, - "readable-stream": { - "version": "3.6.1", - "resolved": "https://registry.npmjs.org/readable-stream/-/readable-stream-3.6.1.tgz", - "integrity": "sha512-+rQmrWMYGA90yenhTYsLWAsLsqVC8osOw6PKE1HDYiO0gdPeKe/xDHNzIAIn4C91YQ6oenEhfYqqc1883qHbjQ==", - "dev": true, - "requires": { - "inherits": "^2.0.3", - "string_decoder": "^1.1.1", - "util-deprecate": "^1.0.1" - } - }, "require-directory": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", @@ -2194,19 +1389,10 @@ "sprintf-js": "^1.1.2" } }, - "safe-buffer": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz", - "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==", - "dev": true - }, "semver": { - "version": "7.5.4", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", - "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", - "requires": { - "lru-cache": "^6.0.0" - } + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==" }, "semver-compare": { "version": "1.0.0", @@ -2221,32 +1407,11 @@ "type-fest": "^0.13.1" } }, - "set-blocking": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/set-blocking/-/set-blocking-2.0.0.tgz", - "integrity": "sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==", - "dev": true - }, - "signal-exit": { - "version": "3.0.7", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-3.0.7.tgz", - "integrity": "sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==", - "dev": true - }, "sprintf-js": { "version": "1.1.3", "resolved": "https://registry.npmjs.org/sprintf-js/-/sprintf-js-1.1.3.tgz", "integrity": "sha512-Oo+0REFV59/rz3gfJNKQiBlwfHaSESl1pcGyABQsnnIfWOFt6JNj5gCog2U6MLZ//IGYD+nA8nI+mTShREReaA==" }, - "string_decoder": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz", - "integrity": "sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==", - "dev": true, - "requires": { - "safe-buffer": "~5.2.0" - } - }, "string-width": { "version": "4.2.3", "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", @@ -2279,15 +1444,28 @@ "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==", "dev": true }, + "tar": { + "version": "7.5.7", + "resolved": "https://registry.npmjs.org/tar/-/tar-7.5.7.tgz", + "integrity": "sha512-fov56fJiRuThVFXD6o6/Q354S7pnWMJIVlDBYijsTNx6jKSE4pvrDTs6lUnmGvNyfJwFQQwWy3owKz1ucIhveQ==", + "dev": true, + "requires": { + "@isaacs/fs-minipass": "^4.0.0", + "chownr": "^3.0.0", + "minipass": "^7.1.2", + "minizlib": "^3.1.0", + "yallist": "^5.0.0" + } + }, "type-fest": { "version": "0.13.1", "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-0.13.1.tgz", "integrity": "sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==" }, "universalify": { - "version": "2.0.0", - "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.0.tgz", - "integrity": 
"sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/universalify/-/universalify-2.0.1.tgz", + "integrity": "sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==", "dev": true }, "url-join": { @@ -2296,28 +1474,13 @@ "integrity": "sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==", "dev": true }, - "util-deprecate": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", - "integrity": "sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==", - "dev": true - }, "which": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", - "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, - "requires": { - "isexe": "^2.0.0" - } - }, - "wide-align": { - "version": "1.1.5", - "resolved": "https://registry.npmjs.org/wide-align/-/wide-align-1.1.5.tgz", - "integrity": "sha512-eDMORYaPNZ4sQIuuYPDHdQvf4gyCF9rEEV/yPxGfwPkRodwEgiMUUXTx/dex+Me0wxx53S+NgUHaP7y3MGlDmg==", + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/which/-/which-6.0.0.tgz", + "integrity": "sha512-f+gEpIKMR9faW/JgAgPK1D7mekkFoqbmiwvNzuhsHetni20QSgzg9Vhn0g2JSJkkfehQnqdUAx7/e15qS1lPxg==", "dev": true, "requires": { - "string-width": "^1.0.2 || 2 || 3 || 4" + "isexe": "^3.1.1" } }, "wrap-ansi": { @@ -2338,14 +1501,15 @@ "dev": true }, "yallist": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", - "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", + "integrity": "sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==", + "dev": true }, "yargs": { - "version": "17.7.1", - "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.1.tgz", - "integrity": "sha512-cwiTb08Xuv5fqF4AovYacTFNxk62th7LKJ6BL9IGUpTJrWoU7/7WdQGTP2SjKf1dUNBGzDd28p/Yfs/GI6JrLw==", + "version": "17.7.2", + "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", + "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", "dev": true, "requires": { "cliui": "^8.0.1", @@ -2355,15 +1519,13 @@ "string-width": "^4.2.3", "y18n": "^5.0.5", "yargs-parser": "^21.1.1" - }, - "dependencies": { - "yargs-parser": { - "version": "21.1.1", - "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", - "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", - "dev": true - } } + }, + "yargs-parser": { + "version": "21.1.1", + "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", + "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", + "dev": true } } } diff --git a/js/node/package.json b/js/node/package.json index 67a371136946d..4d35ec8c424d5 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -37,7 +37,7 @@ ], "devDependencies": { "@types/minimist": "^1.2.2", - "cmake-js": "^7.2.1", + "cmake-js": "^8.0.0", "jsonc": "^2.0.0", "minimist": "^1.2.8", "node-addon-api": "^6.0.0", From 
f25513f5fa2fa63036446a8997a888a516f194d8 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 28 Jan 2026 14:54:01 -0800 Subject: [PATCH 13/23] Add API GetTensorElementTypeAndShapeDataReference (#27175) ### Description Adds C/C++ API named `GetTensorElementTypeAndShapeDataReference` that returns an OrtValue tensor's shape and type without allocating a new buffer for the shape data. ### Motivation and Context This new API function can be used instead of `OrtApi::GetTypeInfo()` or `OrtApi::GetTensorTypeAndShape` to decrease the number of heap allocations and thus improve inference latency for plugin EPs kernels that frequently retrieve tensor shapes during inference. (e.g., WebGPU plugin EP) --- .../core/session/onnxruntime_c_api.h | 25 ++++++ .../core/session/onnxruntime_cxx_api.h | 13 +++ .../core/session/onnxruntime_cxx_inline.h | 7 ++ .../core/framework/tensor_type_and_shape.cc | 58 ++++++++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 3 +- onnxruntime/core/session/ort_apis.h | 5 ++ onnxruntime/test/shared_lib/test_inference.cc | 88 +++++++++++++++++++ 7 files changed, 198 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 3f058a38d9bfb..dd2736c4a7598 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7196,6 +7196,31 @@ struct OrtApi { */ ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream); + /** \brief Get the element data type and shape for an OrtValue that represents a Tensor (scalar, dense, or sparse). + * + * \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for + * the shape data. The OrtValue instance's internal shape data is returned directly. + * + * \note Returns an error if the underlying OrtValue is not a Tensor. + * + * \param[in] value The OrtValue instance. + * \param[out] elem_type Output parameter set to the tensor element data type. + * \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array. + * For a scalar, `shape_data` is NULL and `shape_data_count` is 0. + * Must not be released as it is owned by the OrtValue instance. This pointer becomes invalid + * when the OrtValue is released or if the underlying shape data is updated or reallocated. + * \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`. + * `shape_data_count` is 0 for a scalar. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count); + /** \brief Enable profiling for this run * * \param[in] options diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 3e7ddf0075adb..2c1d52894e7f3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2233,6 +2233,19 @@ struct ConstValueImpl : Base { const R* GetSparseTensorValues() const; #endif + + /// + /// Returns the tensor's element type and a reference to the tensor's internal shape data. 
The shape data is owned
+  /// by the Ort::Value and becomes invalid when the Ort::Value is destroyed or if the underlying shape data is
+  /// updated or reallocated.
+  ///
+  /// For a scalar, shape.shape is nullptr and shape.shape_len is 0.
+  ///
+  /// Wraps OrtApi::GetTensorElementTypeAndShapeDataReference.
+  ///
+  /// Output parameter set to the element's data type.
+  /// Output parameter set to the OrtValue instance's shape data and number of elements.
+  void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, Shape& shape) const;
 };
 
 template <typename T>
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index e416e470f8144..745128fe6c7b4 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -2387,6 +2387,13 @@ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
 
 #endif
 
+template <typename T>
+void ConstValueImpl<T>::GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type,
+                                                                  Shape& shape) const {
+  ThrowOnError(GetApi().GetTensorElementTypeAndShapeDataReference(this->p_, &elem_type, &shape.shape,
+                                                                  &shape.shape_len));
+}
+
 template <typename T>
 void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
   ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc
index 0bac24a2c3aa0..16817ba1707bd 100644
--- a/onnxruntime/core/framework/tensor_type_and_shape.cc
+++ b/onnxruntime/core/framework/tensor_type_and_shape.cc
@@ -310,6 +310,64 @@ std::unique_ptr<OrtTensorTypeAndShapeInfo> OrtTensorTypeAndShapeInfo::GetTensorS
   return GetTensorShapeAndTypeHelper(type, shape, dim_params);
 }
 
+ORT_API_STATUS_IMPL(OrtApis::GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value,
+                    _Out_ ONNXTensorElementDataType* elem_type,
+                    _Outptr_result_maybenull_ const int64_t** shape_data,
+                    _Out_ size_t* shape_data_count) {
+  API_IMPL_BEGIN
+  if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Input parameter `value` must contain a constructed tensor or sparse tensor");
+  }
+
+  if (elem_type == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Output parameter `elem_type` must not be NULL");
+  }
+
+  if (shape_data == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Output parameter `shape_data` must not be NULL");
+  }
+
+  if (shape_data_count == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "Output parameter `shape_data_count` must not be NULL");
+  }
+
+  gsl::span<const int64_t> shape_span;
+  onnxruntime::MLDataType ml_data_type = nullptr;
+  ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+
+  if (value->IsTensor()) {
+    const Tensor& tensor = value->Get<Tensor>();
+    ml_data_type = tensor.DataType();
+    shape_span = tensor.Shape().GetDims();
+  } else {
+#if !defined(DISABLE_SPARSE_TENSORS)
+    const SparseTensor& tensor = value->Get<SparseTensor>();
+    ml_data_type = tensor.DataType();
+    shape_span = tensor.DenseShape().GetDims();
+#else
+    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build.");
+#endif
+  }
+
+  if (ml_data_type != nullptr) {
+    type = MLDataTypeToOnnxRuntimeTensorElementDataType(ml_data_type);
+  }
+
+  if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
+    return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or
supported tensor element data type"); + } + + *elem_type = type; + *shape_data = shape_span.empty() ? nullptr : shape_span.data(); + *shape_data_count = shape_span.size(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5d8aeb521be08..7a027c8eafb81 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4802,6 +4802,7 @@ static constexpr OrtApi ort_api_1_to_25 = { &OrtApis::EpAssignedNode_GetDomain, &OrtApis::EpAssignedNode_GetOperatorType, &OrtApis::RunOptionsSetSyncStream, + &OrtApis::GetTensorElementTypeAndShapeDataReference, // End of Version 24 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::RunOptionsEnableProfiling, @@ -4842,7 +4843,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); -static_assert(offsetof(OrtApi, RunOptionsSetSyncStream) / sizeof(void*) == 413, "Size of version 24 API cannot change"); +static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.25.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index fad7e8e9c31bb..3d990909cfb41 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -810,4 +810,9 @@ ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgrap ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + +ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count); } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index a96a2c48b4ca6..4e991716dd108 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -480,6 +480,94 @@ TEST(CApiTest, dim_param) { ASSERT_EQ(strcmp(dim_param, ""), 0); } +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a dense OrtValue tensor. 
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_DenseTensor) {
+  Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
+
+  const std::array<int64_t, 2> x_shape = {3, 2};
+  std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(),
+                                                x_shape.data(), x_shape.size());
+  Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo();
+
+  ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+  Ort::Value::Shape shape{};
+  x_value.GetTensorElementTypeAndShapeDataReference(elem_type, shape);
+
+  ASSERT_EQ(elem_type, type_shape_info.GetElementType());
+
+  std::vector expected_shape = type_shape_info.GetShape();
+  gsl::span actual_shape(shape.shape, shape.shape_len);
+  ASSERT_EQ(actual_shape, gsl::span(expected_shape));
+}
+
+// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a scalar OrtValue tensor.
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_Scalar) {
+  Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault);
+
+  std::vector<int64_t> x_shape = {};  // Scalar (no shape)
+  std::array x_values = {1.0f};
+  Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(),
+                                                x_shape.data(), x_shape.size());
+  Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo();
+
+  ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+  Ort::Value::Shape shape{};
+  x_value.GetTensorElementTypeAndShapeDataReference(elem_type, shape);
+
+  ASSERT_EQ(elem_type, type_shape_info.GetElementType());
+
+  std::vector expected_shape = type_shape_info.GetShape();
+  gsl::span actual_shape(shape.shape, shape.shape_len);
+  ASSERT_EQ(actual_shape, gsl::span(expected_shape));
+  ASSERT_EQ(shape.shape, nullptr);
+  ASSERT_EQ(shape.shape_len, 0);
+}
+
+#if !defined(DISABLE_SPARSE_TENSORS)
+// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a sparse OrtValue tensor.
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_SparseTensor) {
+  std::vector<int64_t> common_shape{9, 9};
+  std::vector A_values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
+                       10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0,
+                       18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0,
+                       26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0,
+                       34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0,
+                       42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0,
+                       50.0, 51.0, 52.0, 53.0};
+
+  // 2-D index
+  std::vector<int64_t> indices_shape{gsl::narrow<int64_t>(A_values.size()), 2};
+  std::vector<int64_t> A_indices{0, 1, 0, 2, 0, 6, 0, 7, 0, 8, 1, 0, 1,
+                                 1, 1, 2, 1, 6, 1, 7, 1, 8, 2, 0, 2, 1,
+                                 2, 2, 2, 6, 2, 7, 2, 8, 3, 3, 3, 4, 3,
+                                 5, 3, 6, 3, 7, 3, 8, 4, 3, 4, 4, 4, 5,
+                                 4, 6, 4, 7, 4, 8, 5, 3, 5, 4, 5, 5, 5,
+                                 6, 5, 7, 5, 8, 6, 0, 6, 1, 6, 2, 6, 3,
+                                 6, 4, 6, 5, 7, 0, 7, 1, 7, 2, 7, 3, 7,
+                                 4, 7, 5, 8, 0, 8, 1, 8, 2, 8, 3, 8, 4,
+                                 8, 5};
+
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  Ort::Value::Shape ort_dense_shape{common_shape.data(), common_shape.size()};
+  Ort::Value::Shape ort_values_shape{&indices_shape[0], 1U};
+  auto value_sparse = Ort::Value::CreateSparseTensor(info, A_values.data(), ort_dense_shape, ort_values_shape);
+  value_sparse.UseCooIndices(A_indices.data(), A_indices.size());
+
+  Ort::TensorTypeAndShapeInfo type_shape_info = value_sparse.GetTensorTypeAndShapeInfo();
+
+  ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
+  Ort::Value::Shape shape{};
+  value_sparse.GetTensorElementTypeAndShapeDataReference(elem_type, shape);
+
+  ASSERT_EQ(elem_type, type_shape_info.GetElementType());
+
+  std::vector expected_shape = type_shape_info.GetShape();
+  gsl::span actual_shape(shape.shape, shape.shape_len);
+  ASSERT_EQ(actual_shape, gsl::span(expected_shape));
+}
+#endif  // !defined(DISABLE_SPARSE_TENSORS)
+
 static std::pair LoadAndGetInputShapePresent(const ORTCHAR_T* const model_url) {
   Ort::Session session(*ort_env, model_url, Ort::SessionOptions{});
   const auto input_num = session.GetInputCount();

From 7d017ba7ab86df77046b187b05ac562a3228b64d Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Wed, 28 Jan 2026 15:59:31 -0800
Subject: [PATCH 14/23] Fix: Replace pkg_resources with importlib.metadata in
 machine_info.py (#27157)

Replaces the deprecated pkg_resources library with importlib.metadata to fix
ModuleNotFoundError.
--- .../python/tools/transformers/machine_info.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/transformers/machine_info.py b/onnxruntime/python/tools/transformers/machine_info.py index 55f71278dd458..f5c7a03fae91c 100644 --- a/onnxruntime/python/tools/transformers/machine_info.py +++ b/onnxruntime/python/tools/transformers/machine_info.py @@ -6,6 +6,7 @@ # It is used to dump machine information for Notebooks import argparse +import importlib.metadata import json import logging import platform @@ -122,10 +123,7 @@ def get_gpu_info_by_nvml(self) -> dict: return result def get_related_packages(self) -> list[str]: - import pkg_resources # noqa: PLC0415 - - installed_packages = pkg_resources.working_set - related_packages = [ + related_packages = { "onnxruntime-gpu", "onnxruntime", "onnx", @@ -137,8 +135,12 @@ def get_related_packages(self) -> list[str]: "flatbuffers", "numpy", "onnxconverter-common", - ] - related_packages_list = {i.key: i.version for i in installed_packages if i.key in related_packages} + } + related_packages_list = {} + for dist in importlib.metadata.distributions(): + if dist.metadata["Name"].lower() in related_packages: + related_packages_list[dist.metadata["Name"].lower()] = dist.version + return related_packages_list def get_onnxruntime_info(self) -> dict: From a70ce6274daed4ffc24c95235fb54ecc4ea9ca36 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 28 Jan 2026 16:17:07 -0800 Subject: [PATCH 15/23] Bump lodash from 4.17.21 to 4.17.23 in /js/web (#27105) Bumps [lodash](https://github.com/lodash/lodash) from 4.17.21 to 4.17.23.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/web/package-lock.json | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 324b533a0d436..6c000515dbc74 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -635,7 +635,6 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-4.3.7.tgz", "integrity": "sha512-HLnAzZ2iupm25PlN0xFreAlBA5zaBSv3og0DdeGA4Ar6h6rJ3A0rolRUKJhSF2V10GZKDgWF/VmAEsNWjCRB+A==", "dev": true, - "peer": true, "dependencies": { "assertion-error": "^1.1.0", "check-error": "^1.0.2", @@ -2082,7 +2081,6 @@ "resolved": "https://registry.npmjs.org/karma/-/karma-6.4.1.tgz", "integrity": "sha512-Cj57NKOskK7wtFWSlMvZf459iX+kpYIPXmkNUzP2WAFcA7nhr/ALn5R7sw3w+1udFDcpMx/tuB8d5amgm3ijaA==", "dev": true, - "peer": true, "dependencies": { "@colors/colors": "1.5.0", "body-parser": "^1.19.0", @@ -2296,9 +2294,9 @@ } }, "node_modules/lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true }, "node_modules/lodash.memoize": { @@ -4203,7 +4201,6 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-4.3.7.tgz", "integrity": "sha512-HLnAzZ2iupm25PlN0xFreAlBA5zaBSv3og0DdeGA4Ar6h6rJ3A0rolRUKJhSF2V10GZKDgWF/VmAEsNWjCRB+A==", "dev": true, - "peer": true, "requires": { "assertion-error": "^1.1.0", "check-error": "^1.0.2", @@ -5332,7 +5329,6 @@ "resolved": "https://registry.npmjs.org/karma/-/karma-6.4.1.tgz", "integrity": "sha512-Cj57NKOskK7wtFWSlMvZf459iX+kpYIPXmkNUzP2WAFcA7nhr/ALn5R7sw3w+1udFDcpMx/tuB8d5amgm3ijaA==", "dev": true, - "peer": true, "requires": { "@colors/colors": "1.5.0", "body-parser": "^1.19.0", @@ -5513,9 +5509,9 @@ } }, "lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true }, "lodash.memoize": { From c5d5802fd01941131ca01496b24d7b3302ca3e48 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 29 Jan 2026 00:52:45 +0000 Subject: [PATCH 16/23] Bump lodash from 4.17.21 to 4.17.23 in /onnxruntime/test/wasm (#27106) Bumps [lodash](https://github.com/lodash/lodash) from 4.17.21 to 4.17.23.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- onnxruntime/test/wasm/package-lock.json | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/wasm/package-lock.json b/onnxruntime/test/wasm/package-lock.json index ecdfa07f4b447..9eb1a76e16756 100644 --- a/onnxruntime/test/wasm/package-lock.json +++ b/onnxruntime/test/wasm/package-lock.json @@ -962,9 +962,9 @@ } }, "node_modules/lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true }, "node_modules/log4js": { @@ -2495,9 +2495,9 @@ } }, "lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true }, "log4js": { From e39799cc577719e078209691f90007affcd5c783 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 28 Jan 2026 19:26:51 -0800 Subject: [PATCH 17/23] Fix API doc comment for OrtApi::RunOptionsEnableProfiling (#27195) ### Description Fixes C++ documentation generation by replacing `<` and `>` with `[` and `]`. Angle brackets are mistaken as html tags. Successful run: https://github.com/microsoft/onnxruntime/actions/runs/21456738258 ### Motivation and Context Allow C++ document generation to succeed. --- include/onnxruntime/core/session/onnxruntime_c_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index dd2736c4a7598..221f3673f2027 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7225,7 +7225,7 @@ struct OrtApi { * * \param[in] options * \param[in] profile_file_prefix The prefix for the profile file. The actual filename will be: - * _.json + * [profile_file_prefix]_[timestamp].json * * \snippet{doc} snippets.dox OrtStatus Return Value * From eafef0da68453c602c6fc95fbc1f80c6f621991e Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 28 Jan 2026 23:27:43 -0800 Subject: [PATCH 18/23] [MLAS] Fix Flaky LuT GEMM Tests by Replacing Gather with Shuffle (#27174) ## Problem Description The `MatMulNBitsLutGemm` test suite, specifically `Float32_2Bits_Symmetric_256x256_BlkLen64`, was observing intermittent failures (flakiness). The failure manifested as numerical mismatches exceeding the tolerance, suggesting non-deterministic behavior in the kernel execution. 
## Root Cause Analysis

The issue was traced to the usage of `_mm256_i32gather_ps` in `sqnbitgemm_lut_kernel_avx2.cpp`. While the gather indices were technically calculating addresses within the bounds of the allocated buffer, gather instructions on certain AVX2 hardware implementations can exhibit non-deterministic behavior or subtle performance/prefetching artifacts when operating on specific stride patterns (in this case, gathering with a stride of 4 floats).

## Solution

This PR replaces the `_mm256_i32gather_ps` instruction with a sequence of **contiguous loads (`_mm256_loadu_ps`) followed by deterministic shuffles**.

### How it works:
1. **Contiguous Load**: We load 4 contiguous vectors of 8 float elements each using `_mm256_loadu_ps`. This is always memory-safe and deterministic.
2. **Deterministic Shuffle**: We apply a verified sequence of `unpack` and `permutevar8x32` instructions to rearrange these 32 linearly loaded elements into the exact same stride-4 layout that the gather instruction produced.

### Benefits:
* **Stability**: Eliminates the hardware-dependent non-determinism of gather.
* **Safety**: Using `loadu` guarantees we only touch memory within the explicit range of the 32 elements we intend to load.
* **Correctness**: The shuffle logic was verified against the reference gather behavior using a C++ reproduction script to ensure bit-exact layout equivalence.

### Performance
Micro-benchmark on MatMulNBitsLutGemm (256x256, BlkLen=64):
* Original (Gather): ~55.55 us
* Fixed (Load+Shuffle): ~57.79 us
* Delta: +2.24 us (~4% slower)

The slight performance regression is expected because replacing a single hardware gather instruction with a sequence of loadu, unpack, and permute instructions adds instruction count overhead. However, this is a necessary tradeoff to ensure deterministic behavior and memory safety across all AVX2 implementations.

## Verification
* **Tests**: All 9 tests in `MatMulNBitsLutGemm` passed successfully (including the previously flaky `BlkLen64` case).
---
 .../mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp   | 54 ++++++++++++++-----
 .../test/contrib_ops/matmul_2bits_test.cc     |  6 ++-
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
index b54f051ca1504..a89993d4515b8 100644
--- a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
+++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp
@@ -187,21 +187,53 @@ get_bias_scale()
     return 3;
 }
 
+static inline void
+MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3)
+{
+    // Process 32 activations contiguously using loadu + shuffle.
+    // This allows us to mix neighbors (src[4i], src[4i+1], src[4i+2], src[4i+3]) across lanes,
+    // which matches the T-MAC weight packing.
+    // We use loadu + shuffle instead of gather to avoid potential issues with gather
+    // on some hardware and ensure deterministic behavior.
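+    // Resulting stride-4 deinterleave layout (equivalent to the old gather with
+    // a 16-byte stride), shown here for reference:
+    //   v0 = { src[0], src[4], src[8],  ..., src[28] }
+    //   v1 = { src[1], src[5], src[9],  ..., src[29] }
+    //   v2 = { src[2], src[6], src[10], ..., src[30] }
+    //   v3 = { src[3], src[7], src[11], ..., src[31] }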
+ __m256 vec_b0 = _mm256_loadu_ps(src + 0); + __m256 vec_b1 = _mm256_loadu_ps(src + 8); + __m256 vec_b2 = _mm256_loadu_ps(src + 16); + __m256 vec_b3 = _mm256_loadu_ps(src + 24); + + __m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1); + __m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1); + __m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3); + __m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3); + + __m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2))); + __m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2))); + __m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3))); + __m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3))); + + const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + v0 = _mm256_permutevar8x32_ps(u0, perm_idx); + v1 = _mm256_permutevar8x32_ps(u1, perm_idx); + v2 = _mm256_permutevar8x32_ps(u2, perm_idx); + v3 = _mm256_permutevar8x32_ps(u3, perm_idx); +} + void partial_max_g4_int8_k8(float* lut_scales, const float* b) { - // TODO(vraspar): add support for arm neon - const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); - __m256 vec_b0 = _mm256_i32gather_ps(b + 0, vec_bi, 1); - __m256 vec_b1 = _mm256_i32gather_ps(b + 1, vec_bi, 1); - __m256 vec_b2 = _mm256_i32gather_ps(b + 2, vec_bi, 1); - __m256 vec_b3 = _mm256_i32gather_ps(b + 3, vec_bi, 1); + __m256 vec_b0, vec_b1, vec_b2, vec_b3; + MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3); + const __m256 vec_sign = _mm256_set1_ps(-0.0f); __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0); __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1); __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2); __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3); + + // The upper bound for the LUT values (mixtures of 4 activations) is the sum + // of their absolute values. __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3)); + + // Reduce max across lanes to find the global maximum sum in this chunk. __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum)); max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); @@ -222,16 +254,14 @@ lut_ctor_g4_int8_impl( ) { __m256 vec_lut[16]; - float biases = 0.0; - const __m256i vec_bi = _mm256_set_epi32(112, 96, 80, 64, 48, 32, 16, 0); + float biases = 0.0f; float scales = *lut_scales; float t_scales = scales ? 
1.0f / scales : 0.0f; for (int k = 0; k < act_k / 32; ++k) { - __m256 vec_b0 = _mm256_i32gather_ps(b + k * 32 + 0, vec_bi, 1); - __m256 vec_b1 = _mm256_i32gather_ps(b + k * 32 + 1, vec_bi, 1); - __m256 vec_b2 = _mm256_i32gather_ps(b + k * 32 + 2, vec_bi, 1); - __m256 vec_b3 = _mm256_i32gather_ps(b + k * 32 + 3, vec_bi, 1); + const float* b_chunk = b + k * 32; + __m256 vec_b0, vec_b1, vec_b2, vec_b3; + MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3); PRAGMA_UNROLL for (int g = 1; g < 16; g += 2) { diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 853458312cd1f..3d5e3e5f360b4 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -371,8 +371,10 @@ TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256) { TestMatMul2BitsLutGemm(1, 256, 256, 32, false); } -// TODO: Re-enable once LUT GEMM asymmetric quantization accuracy issue is resolved -TEST(MatMulNBitsLutGemm, DISABLED_Float32_2Bits_Asymmetric_256x256) { +// This test was previously disabled due to accuracy issues related to non-deterministic +// gather operations. It is now re-enabled after replacing gather with deterministic +// load+shuffle operations to improve determinism and stability. +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256) { TestMatMul2BitsLutGemm(1, 256, 256, 32, true); } From 82d0bc863e630aa7e9a1e7f20a667cf700f5ba71 Mon Sep 17 00:00:00 2001 From: BODAPATIMAHESH <148746454+BODAPATIMAHESH@users.noreply.github.com> Date: Fri, 30 Jan 2026 00:34:29 +0530 Subject: [PATCH 19/23] POWER : Fix build failure due to unsupported cpuinfo on ppc64le (#27120) Description Conditionally disable linking of cpuinfo for onnxruntime_runtime_path_test_shared_library on targets, where cpuinfo is not supported. Motivation and Context Recent changes enabling onnxruntime_autoep_test and related shared library tests on non-Windows platforms exposed a transitive dependency issue. cpuinfo was being linked unconditionally on Linux, leading to linker failures on ppc64le (cannot find -lcpuinfo). Solution Add CPUINFO_SUPPORTED guards to exclude cpuinfo from the link list while preserving existing behavior. --- cmake/onnxruntime_unittests.cmake | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 2c88eaceda8b5..988b98fb29d48 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1520,8 +1520,13 @@ endif() onnxruntime_common ${CMAKE_DL_LIBS}) set_target_properties(onnxruntime_runtime_path_test_shared_library PROPERTIES AIX_SHARED_LIBRARY_ARCHIVE OFF) else() - target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE - onnxruntime_common cpuinfo ${CMAKE_DL_LIBS}) + if (CPUINFO_SUPPORTED) + target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE + onnxruntime_common cpuinfo ${CMAKE_DL_LIBS}) + else() + target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE + onnxruntime_common ${CMAKE_DL_LIBS}) + endif() endif() target_include_directories(onnxruntime_runtime_path_test_shared_library PRIVATE ${ONNXRUNTIME_ROOT}) From 00a5d1a7cd6d98d57c7701bec53355be3fb3a422 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 30 Jan 2026 13:31:20 -0800 Subject: [PATCH 20/23] [MLAS] Fix rotary interleaved NEON kernel (#26390) The logic of interleaved NEON kernel is not correct from code review: 1. 
**Test Code Logic:** The test code `test_rope.h` allocates the `sin` and
`cos` tables based on the `interleaved` flag:
```c++
size_t table_len = interleaved ? rotary_emb_dim / 2 : rotary_emb_dim;
std::vector<float> sin_data(table_len);
std::vector<float> cos_data(table_len);
```
For the `interleaved = true` case, the test creates `sin` and `cos` tables of
length `rotary_emb_dim / 2`.

2. **AVX2 (fp32) Kernel Logic (`interleaved = true`):** This kernel loads the
`sin`/`cos` data using an index of `i / 2`:
```c++
__m256 sin_val = _mm256_loadu_ps(sin_data + i / 2);
__m256 cos_val = _mm256_loadu_ps(cos_data + i / 2);
```
This logic expects a `sin`/`cos` table of length `rotary_emb_dim / 2`.
**Conclusion: The AVX2 (fp32) kernel is consistent with the test code.**

3. **NEON (fp16) Kernel Logic (`interleaved = true`):** This kernel loads the
`sin`/`cos` data using an index of `i`:
```c++
// Enters loop with
sin_val = MlasLoadFloat16x8(sin + i);
//...
// Inside loop, for next iteration:
sin_val = MlasLoadFloat16x8(sin + i + 16);
```
This logic expects a `sin`/`cos` table of length `rotary_emb_dim`.
**Conclusion: The NEON (fp16) kernel is NOT consistent with the test code.**

### Regression Test

```
cmake --build build/Linux/Release --config Release --target onnxruntime_mlas_test && ./build/Linux/Release/onnxruntime_mlas_test --gtest_filter=NeonFp16RoPE*
```

Before applying the fix, the test failed:

```
[  FAILED  ] NeonFp16RoPE.ShortExecute (13 ms)
onnxruntime/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp:66: Failure
Value of: CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat())
  Actual: false
Expected: true
Expected bits: 19491 (16.546875) Actual bits: 56596 (-325) @[16], rotary_emb_dim=24, interleaved=true
```

After applying the fix, the test passed.

### Summary

The `RopeKernel_Avx2_fp32_Impl` kernel correctly aligns with the test code
(and the fallback implementation) by expecting a `sin`/`cos` table of length
`rotary_emb_dim / 2`. The `RopeKernel_Fp16_Impl` (NEON) kernel incorrectly
expects a table of length `rotary_emb_dim`. When run against the provided
test, the NEON kernel will read past the end of the `sin_data` and `cos_data`
vectors.
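For reference, the indexing convention the fix restores can be written as a
short scalar sketch (illustrative only; the function name is invented for
this sketch and is not part of the patch):

```c++
#include <cstddef>

// Scalar reference for the interleaved RoPE layout: element pair (i, i + 1)
// is rotated using table index i / 2, so sin/cos tables of
// rotary_emb_dim / 2 entries are sufficient.
void RopeInterleavedReference(const float* input, const float* sin_table,
                              const float* cos_table, std::size_t rotary_emb_dim,
                              float* output) {
  for (std::size_t i = 0; i < rotary_emb_dim; i += 2) {
    float re = input[i];
    float im = input[i + 1];
    float s = sin_table[i / 2];
    float c = cos_table[i / 2];
    output[i] = re * c - im * s;      // matches the kernel's vfms: real*cos - imag*sin
    output[i + 1] = re * s + im * c;  // matches the kernel's vfma: real*sin + imag*cos
  }
}
```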
--------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../lib/rotary_embedding_kernel_neon_fp16.cpp | 36 +++--- .../mlas/unittest/test_rope_neon_fp16.cpp | 104 ++++++++++++++++++ 2 files changed, 122 insertions(+), 18 deletions(-) create mode 100644 onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp index 3a93723fc3b52..e611009733fbf 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -150,8 +150,8 @@ RopeKernel_Fp16_Impl( if (i + 15 < dim) { float16x8_t x0 = MlasLoadFloat16x8(input + i); float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); - float16x8_t sin_val = MlasLoadFloat16x8(sin + i); - float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i / 2); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i / 2); for (; i + 31 < dim; i += 16) { float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -163,8 +163,8 @@ RopeKernel_Fp16_Impl( MlasStoreFloat16x8(output + i + 8, y1); x0 = MlasLoadFloat16x8(input + i + 16); x1 = MlasLoadFloat16x8(input + i + 24); - sin_val = MlasLoadFloat16x8(sin + i + 16); - cos_val = MlasLoadFloat16x8(cos + i + 16); + sin_val = MlasLoadFloat16x8(sin + (i + 16) / 2); + cos_val = MlasLoadFloat16x8(cos + (i + 16) / 2); } float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -181,8 +181,8 @@ RopeKernel_Fp16_Impl( float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); float16x4_t real = vuzp1_f16(x0, x1); float16x4_t imag = vuzp2_f16(x0, x1); - float16x4_t sin_val = MlasLoadFloat16x4(sin + i); - float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i / 2); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i / 2); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); float16x4_t y0 = vzip1_f16(real_out, imag_out); @@ -201,12 +201,12 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); - cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i / 2 + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i / 2 + 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -224,10 +224,10 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - 
sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val);
-        cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
-        cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
+        sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
+        cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val);
         float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
         float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
         MlasStoreLaneFloat16x4<0>(output + i, real_out);
@@ -241,8 +241,8 @@ RopeKernel_Fp16_Impl(
         float16x4_t cos_val = MlasZeroFloat16x4();
         real = MlasLoadLaneFloat16x4<0>(input + i, real);
         imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag);
-        sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val);
-        cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val);
+        sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val);
+        cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val);
         float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val);
         float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val);
         MlasStoreLaneFloat16x4<0>(output + i, real_out);
diff --git a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp
new file mode 100644
index 0000000000000..3ff4fee69eac9
--- /dev/null
+++ b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp
@@ -0,0 +1,104 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    test_rope_neon_fp16.cpp
+
+Abstract:
+
+    Tests for MLAS fp16 RoPE on NEON.
+
+--*/
+
+#include <cmath>
+#include <vector>
+
+#include "core/mlas/inc/mlas.h"
+
+#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)
+
+#include "test_util.h"
+#include "core/mlas/lib/mlasi.h"
+#include "core/mlas/lib/rotary_embedding.h"
+#include "core/mlas/lib/rotary_embedding_kernel_neon.h"
+
+class MlasNeonFp16RoPETest : public MlasTestBase {
+ private:
+  const float Pi = 2 * std::acos(0.0f);
+
+  void Test(size_t rotary_emb_dim, bool interleaved) {
+    // Per kernel logic (both fallback and optimized), the sin/cos tables
+    // are always half the rotary embedding dimension.
+    const size_t table_len = rotary_emb_dim / 2;
+
+    std::vector<MLAS_FP16> input(rotary_emb_dim);
+    std::vector<MLAS_FP16> sin_data(table_len);
+    std::vector<MLAS_FP16> cos_data(table_len);
+    std::vector<MLAS_FP16> output_ref(rotary_emb_dim);
+    std::vector<MLAS_FP16> output_impl(rotary_emb_dim);
+
+    // Initialize input data
+    for (size_t i = 0; i < rotary_emb_dim; ++i) {
+      input[i] = MLAS_FP16(static_cast<float>(i + 1));
+    }
+
+    // Initialize sin/cos tables
+    for (size_t i = 0; i < table_len; ++i) {
+      float theta = static_cast<float>(i) / 1000.0f * Pi;
+      sin_data[i] = MLAS_FP16(std::sin(theta));
+      cos_data[i] = MLAS_FP16(std::cos(theta));
+    }
+
+    // Call fallback implementation
+    MlasRotaryEmbedOneRow_FallBack(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data());
+
+    // Call dispatched implementation (which should pick up the NEON kernel)
+    MlasRotaryEmbedOneRow(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data());
+
+    // Compare results
+    for (size_t i = 0; i < rotary_emb_dim; i++) {
+      ASSERT_TRUE(CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat()))
+          << "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")"
+          << " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")"
+          << " @[" << i << "], "
+          << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved;
+    }
+  }
+
+ public:
+  static const char* GetTestSuiteName() {
+    return "NeonFp16RoPE";
+  }
+
+  void ExecuteShort(void) override {
+    // Test dimensions that cover main loops and various remainders
+    Test(6, false);
+    Test(6, true);
+    Test(16, false);
+    Test(16, true);
+    Test(24, false);
+    Test(24, true);
+    Test(32, false);
+    Test(32, true);
+    Test(42, false);
+    Test(42, true);
+    Test(64, false);
+    Test(64, true);
+    Test(70, false);
+    Test(70, true);
+  }
+};
+
+static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
+  size_t count = 0;
+  if (is_short_execute) {
+    count += MlasDirectShortExecuteTests<MlasNeonFp16RoPETest>::RegisterShortExecute();
+  }
+  return count;
+});
+
+#endif  // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)

From 67b8bc16c491a2c0db6a0565e0f1d8c4cfc3b8c2 Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Sun, 1 Feb 2026 20:31:02 -0800
Subject: [PATCH 21/23] [web] use shorter memory info name for WebGPU buffer
 and WebNN tensor (#27207)

### Description

This PR renames the following existing names for MemoryInfo:

- `WebGPU_Buffer` -> `WebGPU_Buf`
- `WebNN_Tensor` -> `WebNN_Ten`

### Motivation and Context

The `OrtMemoryInfo` uses a `std::string` to store the name. Modern C++
compilers use "small string optimization" (SSO) to avoid an extra memory
allocation if the string is small enough.
While different compilers may have different implementations, the following
test program can be used to find out what the exact limit is for a given
compiler:

```c++
#include <cstdio>
#include <string>

int main() {
  std::string webgpu0 = "WebGPU_Buf";
  std::string webgpu1 = "WebGPU_Buff";
  std::string webgpu2 = "WebGPU_Buffe";
  std::string webgpu3 = "WebGPU_Buffer";

  printf("=========== %s\n string address: %p\n data address : %p\n\n", webgpu0.c_str(), (void*)&webgpu0, (void*)webgpu0.data());
  printf("=========== %s\n string address: %p\n data address : %p\n\n", webgpu1.c_str(), (void*)&webgpu1, (void*)webgpu1.data());
  printf("=========== %s\n string address: %p\n data address : %p\n\n", webgpu2.c_str(), (void*)&webgpu2, (void*)webgpu2.data());
  printf("=========== %s\n string address: %p\n data address : %p\n\n", webgpu3.c_str(), (void*)&webgpu3, (void*)webgpu3.data());

  return 0;
}
```

When using emscripten (targeting wasm32), the runtime result looks like this:

```
=========== WebGPU_Buf
 string address: 0x10db0
 data address : 0x10db0

=========== WebGPU_Buff
 string address: 0x10da4
 data address : 0x10dc8

=========== WebGPU_Buffe
 string address: 0x10d98
 data address : 0x10de0

=========== WebGPU_Buffer
 string address: 0x10d8c
 data address : 0x10df8
```

This shows that the string needs to be no more than 10 bytes (excluding the
trailing '\0') to enable SSO.
---
 include/onnxruntime/core/framework/allocator.h    | 4 ++--
 js/node/src/inference_session_wrap.cc             | 2 +-
 js/node/src/tensor_helper.cc                      | 2 +-
 .../webgpu/quantization/gather_block_quantized.cc | 2 +-
 onnxruntime/wasm/api.cc                           | 8 ++++----
 5 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h
index 983be1f9efd5c..383562bc5a405 100644
--- a/include/onnxruntime/core/framework/allocator.h
+++ b/include/onnxruntime/core/framework/allocator.h
@@ -85,8 +85,8 @@ constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
 constexpr const char* OpenVINO_RT = "OpenVINO_RT";
 constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU";
 constexpr const char* QNN_HTP_SHARED = "QnnHtpShared";
-constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer";
-constexpr const char* WEBNN_TENSOR = "WebNN_Tensor";
+constexpr const char* WEBGPU_BUFFER = "WebGPU_Buf";  // limited to 10 chars to ensure std::string SSO for web
+constexpr const char* WEBNN_TENSOR = "WebNN_Ten";    // limited to 10 chars to ensure std::string SSO for web

 constexpr size_t kAllocAlignment = 256;
 constexpr const size_t kAlloc4KAlignment = 4096;
diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc
index 14f19ef352cd1..e82ab0c3c7498 100644
--- a/js/node/src/inference_session_wrap.cc
+++ b/js/node/src/inference_session_wrap.cc
@@ -181,7 +181,7 @@ Napi::Value InferenceSessionWrap::Run(const Napi::CallbackInfo& info) {
   size_t inputIndex = 0;
   size_t outputIndex = 0;
   Ort::MemoryInfo cpuMemoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
-  Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault};
+  Ort::MemoryInfo gpuBufferMemoryInfo{"WebGPU_Buf", OrtDeviceAllocator, 0, OrtMemTypeDefault};

   try {
     for (auto& name : inputNames_) {
diff --git a/js/node/src/tensor_helper.cc b/js/node/src/tensor_helper.cc
index 0630386cfc645..f6b9f3132ec31 100644
--- a/js/node/src/tensor_helper.cc
+++ b/js/node/src/tensor_helper.cc
@@ -251,7 +251,7 @@ Napi::Value OrtValueToNapiValue(Napi::Env env, Ort::Value&& value) {
   // location
   auto memoryInfo = value.GetTensorMemoryInfo();
  bool isGpuBuffer = memoryInfo.GetDeviceType() == OrtMemoryInfoDeviceType_GPU &&
-                     memoryInfo.GetAllocatorName() == "WebGPU_Buffer";
+                     memoryInfo.GetAllocatorName() == "WebGPU_Buf";

   // size
   auto size = tensorTypeAndShapeInfo.GetElementCount();
diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
index a1f684311ec82..62bfbcbbf9f5a 100755
--- a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
+++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc
@@ -149,7 +149,7 @@ Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const {
   TensorShape data_representation_4bit_shape{x->Shape()};
   MLDataType new_dtype = (x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) ? DataTypeImpl::GetType<UInt4x2>() : DataTypeImpl::GetType<Int4x2>();
   auto memory_info = OrtMemoryInfo{
-      "WebGPU_Buffer",
+      WEBGPU_BUFFER,
       OrtDeviceAllocator,
       OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}};

diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc
index 147eab7116d94..0494c7471f2ac 100644
--- a/onnxruntime/wasm/api.cc
+++ b/onnxruntime/wasm/api.cc
@@ -393,10 +393,10 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t*
   OrtMemoryInfo* memory_info = nullptr;
   switch (data_location) {
     case DATA_LOCATION_GPU_BUFFER:
-      RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
+      RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buf", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
       break;
     case DATA_LOCATION_ML_TENSOR:
-      RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Tensor", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
+      RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Ten", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
       break;
     default:
       RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info);
@@ -563,9 +563,9 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding,
   if (output_location != DATA_LOCATION_GPU_BUFFER && output_location != DATA_LOCATION_ML_TENSOR) {
     RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info);
   } else if (output_location == DATA_LOCATION_ML_TENSOR) {
-    RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Tensor", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
+    RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Ten", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
   } else {
-    RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
+    RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buf", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info);
   }
   REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info);
   return CHECK_STATUS(BindOutputToDevice, io_binding, name, memory_info);

From fa933c9c8e29ed8df78256a5fd8ac3a20201e9f5 Mon Sep 17 00:00:00 2001
From: milpuz01
Date: Wed, 21 Jan 2026 13:19:09 +0000
Subject: [PATCH 22/23] mlas/arm64: add NEON conv asm kernels and tune NCHWC
 kernel selection

Signed-off-by: Milos Puzovic
---
 cmake/onnxruntime_mlas.cmake                                 | 2 ++
 onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 59abea26e4f60..80034f909a155 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -466,6 +466,8 @@ else()
     ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S
     ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S
     ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S
+    ${MLAS_SRC_DIR}/aarch64/SconvDepthwiseKernelNeon.S
+    ${MLAS_SRC_DIR}/aarch64/SconvPointwiseKernelNeon.S
     ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S
     ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S
     ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S
diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S
index 6521c50eb40a2..121d0f09d3110 100644
--- a/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S
+++ b/onnxruntime/core/mlas/lib/aarch64/SconvDepthwiseKernelNeon.S
@@ -18,6 +18,6 @@ Abstract:
     * When an output position touches padding, only the affected 4-wide
       lanes are checked individually and loaded; others are zeroed. This
-      mirrors the behavior of the C++ helper LoadInputVectorWithBounds.
+      mirrors the behaviour of the C++ helper LoadInputVectorWithBounds.
     * Keep the multiply/accumulate operations tightly scheduled to hide the
       load latency.

From 5cba128f0f6e497bf5654f858ea808826dab6281 Mon Sep 17 00:00:00 2001
From: Milos Puzovic
Date: Wed, 4 Feb 2026 09:38:13 +0000
Subject: [PATCH 23/23] Address the comments from reviewers, fix failing tests
 and reduce stack spill

Signed-off-by: Milos Puzovic
---
 .../core/mlas/lib/aarch64/SconvKernelNeon.S |  88 ++++++--
 .../lib/aarch64/SconvPointwiseKernelNeon.S  | 201 +++++++++---------
 onnxruntime/core/mlas/lib/platform.cpp      |  10 +-
 onnxruntime/core/mlas/lib/snchwc.cpp        |  10 +-
 4 files changed, 187 insertions(+), 122 deletions(-)

diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S
index 643f537834663..fd0a57b7c314a 100644
--- a/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S
+++ b/onnxruntime/core/mlas/lib/aarch64/SconvKernelNeon.S
@@ -19,12 +19,20 @@ Abstract:

 // integer parameters follow the AArch64 calling convention and are in
 // x0-x7. Remaining parameters are spilled by the C++ caller.
- .equ .LFrame_SavedRegs, (5*16) + .equ .LFrame_SavedRegs, (5*16 + 4*32 + 16 + 32) .equ .LFrame_x19_x20, 0 .equ .LFrame_x21_x22, 16 .equ .LFrame_x23_x24, 32 .equ .LFrame_x25_x26, 48 .equ .LFrame_x27_x28, 64 + .equ .LFrame_q8_q9, 80 + .equ .LFrame_q10_q11, 112 + .equ .LFrame_q12_q13, 144 + .equ .LFrame_q14_q15, 176 + .equ .LFrame_InputBaseSaved, 208 + .equ .LFrame_FilterBaseSaved, 216 + .equ .LFrame_OutputBaseSaved, 224 + .equ .LFrame_LrSaved, 232 .equ .KO_OutputStride, (0 + .LFrame_SavedRegs) .equ .KO_KernelHeight, (8 + .LFrame_SavedRegs) @@ -422,7 +430,6 @@ Abstract: str q13,[x24,#16] str q14,[x24,#32] str q15,[x24,#48] - add x27,x27,x8 add x27,x27,x8 // set2 base add x24,x27,#64 tbz w6,#0,4f @@ -488,10 +495,10 @@ Abstract: add x19,x1,x7 .endif .if \N >= 3 - add x20,x19,x7 + add x24,x19,x7 .endif .if \N >= 4 - add x26,x20,x7 + add x26,x24,x7 .endif mov x21,x0 ldr x9,[sp,#.KO_KernelHeight] @@ -525,7 +532,7 @@ Abstract: fmla v7.4s,v29.4s,v24.4s .endif .if \N >= 3 - ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x24],#64 fmla v8.4s,v26.4s,v24.4s fmla v9.4s,v27.4s,v24.4s fmla v10.4s,v28.4s,v24.4s @@ -558,10 +565,10 @@ Abstract: add x19,x1,x7 .endif .if \N >= 3 - add x20,x19,x7 + add x24,x19,x7 .endif .if \N >= 4 - add x26,x20,x7 + add x26,x24,x7 .endif mov x21,x0 ldr x9,[sp,#.KO_KernelHeight] @@ -584,7 +591,7 @@ Abstract: fmla v7.4s,v29.4s,v24.4s .endif .if \N >= 3 - ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x24],#64 fmla v8.4s,v26.4s,v24.4s fmla v9.4s,v27.4s,v24.4s fmla v10.4s,v28.4s,v24.4s @@ -617,10 +624,13 @@ Abstract: add x19,x1,x7 .endif .if \N >= 3 - add x20,x19,x7 + add x24,x19,x7 .endif ldr x11,[sp,#.KO_KernelWidth] ldr x9,[sp,#.KO_KernelHeight] + ldr x12,[sp,#.KO_DilatedInputWidth] + mul x13,x11,x4 // kernel_width * dilation_width + sub x12,x12,x13 // input stride between kernel rows 1: mov x10,x11 2: prfm pldl1keep,[x28,#192] ld1r {v24.4s},[x25] @@ -646,7 +656,7 @@ Abstract: fmla v15.4s,v29.4s,v25.4s .endif .if \N >= 3 - ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x20],#64 + ld1 {v26.4s,v27.4s,v28.4s,v29.4s},[x24],#64 fmla v16.4s,v26.4s,v24.4s fmla v17.4s,v27.4s,v24.4s fmla v18.4s,v28.4s,v24.4s @@ -660,7 +670,6 @@ Abstract: add x26,x26,x4 subs x10,x10,#1 b.ne 2b - ldr x12,[sp,#.KO_DilatedInputWidth] add x25,x25,x12 add x26,x26,x12 subs x9,x9,#1 @@ -680,13 +689,52 @@ Abstract: stp x23,x24,[sp,#.LFrame_x23_x24] stp x25,x26,[sp,#.LFrame_x25_x26] stp x27,x28,[sp,#.LFrame_x27_x28] + stp q8,q9,[sp,#.LFrame_q8_q9] + stp q10,q11,[sp,#.LFrame_q10_q11] + stp q12,q13,[sp,#.LFrame_q12_q13] + stp q14,q15,[sp,#.LFrame_q14_q15] + str x0,[sp,#.LFrame_InputBaseSaved] + str x1,[sp,#.LFrame_FilterBaseSaved] + str x2,[sp,#.LFrame_OutputBaseSaved] + str x30,[sp,#.LFrame_LrSaved] + + cmp x5,#4 + b.ne .LRunKernelSingle + + // Split the 4-filter case into two 2-filter passes to enable the + // two-output inner loop without inflating register pressure. 
+ mov x5,#2 + ldr x1,[sp,#.LFrame_FilterBaseSaved] + ldr x2,[sp,#.LFrame_OutputBaseSaved] + ldr x17,[sp,#.KO_Bias] + bl .LKernelBody + + ldr x1,[sp,#.LFrame_FilterBaseSaved] + add x1,x1,x7,lsl #1 + ldr x2,[sp,#.LFrame_OutputBaseSaved] + ldr x8,[sp,#.KO_OutputStride] + add x2,x2,x8,lsl #1 + ldr x17,[sp,#.KO_Bias] + cbz x17,1f + add x17,x17,#128 +1: + mov x5,#2 + bl .LKernelBody + + b .LDoneEntry + +.LRunKernelSingle: + ldr x17,[sp,#.KO_Bias] + bl .LKernelBody + b .LDoneEntry + +.LKernelBody: // frequently used parameters ldr x8,[sp,#.KO_OutputStride] ldr x14,[sp,#.KO_OutputCountLeftPad] ldr x15,[sp,#.KO_OutputCount] ldr x16,[sp,#.KO_OutputCountRightPad] - ldr x17,[sp,#.KO_Bias] ldr w6,[sp,#.KO_Flags] // kernel flags eor v31.16b,v31.16b,v31.16b @@ -695,6 +743,7 @@ Abstract: cbz x14,.LInterior mov x20,#0 .LLeftLoop: + ldr x0,[sp,#.LFrame_InputBaseSaved] madd x21,x20,x3,x0 // input pointer for this output lsl x22,x20,#6 // (index*64) add x27,x2,x22 @@ -725,6 +774,7 @@ Abstract: .LInterior: cbz x15,.LRight + ldr x0,[sp,#.LFrame_InputBaseSaved] mul x24,x14,x3 // input byte offset after left pad add x21,x0,x24 // base input for interior lsl x23,x14,#6 // output offset (bytes) @@ -795,9 +845,11 @@ Abstract: // right padded region ------------------------------------------------- .LRight: - cbz x16,.LDone + cbz x16,.LKernelReturn + ldr x0,[sp,#.LFrame_InputBaseSaved] mov x20,#0 .LRightLoop: + ldr x0,[sp,#.LFrame_InputBaseSaved] add x24,x20,x14 add x24,x24,x15 madd x21,x24,x3,x0 @@ -825,11 +877,19 @@ Abstract: cmp x20,x16 blo .LRightLoop -.LDone: +.LKernelReturn: + ret + +.LDoneEntry: + ldp q14,q15,[sp,#.LFrame_q14_q15] + ldp q12,q13,[sp,#.LFrame_q12_q13] + ldp q10,q11,[sp,#.LFrame_q10_q11] + ldp q8,q9,[sp,#.LFrame_q8_q9] ldp x27,x28,[sp,#.LFrame_x27_x28] ldp x25,x26,[sp,#.LFrame_x25_x26] ldp x23,x24,[sp,#.LFrame_x23_x24] ldp x21,x22,[sp,#.LFrame_x21_x22] + ldr x30,[sp,#.LFrame_LrSaved] ldp x19,x20,[sp],#.LFrame_SavedRegs ret diff --git a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S index 8caae4fc080ac..7b56157094a2a 100644 --- a/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S +++ b/onnxruntime/core/mlas/lib/aarch64/SconvPointwiseKernelNeon.S @@ -42,12 +42,12 @@ Abstract: // Compute four outputs for a single input channel block. The accumulators // for the four outputs are held in v16-v31. .macro CPK4_FmlaStep - ldp q0,q1,[x24],#32 - ldp q2,q3,[x24],#32 - ld1r {v4.4s},[x25],#4 - ld1r {v5.4s},[x26],#4 - ld1r {v6.4s},[x27],#4 - ld1r {v7.4s},[x28],#4 + ldp q0,q1,[x0],#32 + ldp q2,q3,[x0],#32 + ld1r {v4.4s},[x1],#4 + ld1r {v5.4s},[x2],#4 + ld1r {v6.4s},[x3],#4 + ld1r {v7.4s},[x4],#4 fmla v16.4s,v0.4s,v4.4s fmla v17.4s,v1.4s,v4.4s fmla v18.4s,v2.4s,v4.4s @@ -70,8 +70,8 @@ Abstract: // output are loaded to v0-v3 and each lane is multiplied with a block of // filter coefficients. .macro CPK1_FmlaWithLane lane, AReg - ldp q4,q5,[x24],#32 - ldp q6,q7,[x24],#32 + ldp q4,q5,[x0],#32 + ldp q6,q7,[x0],#32 fmla v16.4s,v4.4s,\AReg\().s[\lane] fmla v17.4s,v5.4s,\AReg\().s[\lane] fmla v18.4s,v6.4s,\AReg\().s[\lane] @@ -80,16 +80,16 @@ Abstract: // Compute a single output position. Results are returned in v16-v19. 
.macro CPK_ComputeOneOutput - mov x22,#0 + mov x5,#0 eor v16.16b,v16.16b,v16.16b eor v17.16b,v17.16b,v17.16b eor v18.16b,v18.16b,v18.16b eor v19.16b,v19.16b,v19.16b .Lpw_ic_loop1: - madd x23,x22,x6,x15 - ldp q0,q1,[x23] - ldp q2,q3,[x23,#32] - add x24,x17,x22,lsl #10 + madd x1,x5,x7,x15 + ldp q0,q1,[x1] + ldp q2,q3,[x1,#32] + add x0,x17,x5,lsl #10 CPK1_FmlaWithLane 0, v0 CPK1_FmlaWithLane 1, v0 CPK1_FmlaWithLane 2, v0 @@ -106,8 +106,8 @@ Abstract: CPK1_FmlaWithLane 1, v3 CPK1_FmlaWithLane 2, v3 CPK1_FmlaWithLane 3, v3 - add x22,x22,#1 - cmp x22,x4 + add x5,x5,#1 + cmp x5,x9 blt .Lpw_ic_loop1 .endm @@ -123,31 +123,40 @@ Abstract: ldr x10,[sp,#.LPW_Bias] ldr w11,[sp,#.LPW_Flags] - // Preserve callee saved registers that are used by this routine. - stp x22,x23,[sp,#-16]! - stp x24,x25,[sp,#-16]! - stp x26,x27,[sp,#-16]! - stp x28,x19,[sp,#-16]! + // Spill base arguments so caller-saved registers can be reused freely. + sub sp,sp,#96 + stp x0,x1,[sp,#0] + stp x2,x3,[sp,#16] + stp x4,x5,[sp,#32] + stp x6,x7,[sp,#48] + str x10,[sp,#64] // bias base + str x9,[sp,#72] // output count - mov x14,#0 // current filter set + mov x12,#0 // current filter set cbz x5,.Lpw_exit // nothing to do .Lpw_filter_loop: // Compute the base pointers for this filter block. - madd x16,x14,x8,x2 // output pointer for this filter - madd x17,x14,x7,x1 // filter pointer for this filter - add x18,x10,x14,lsl #6 // bias pointer (if used) - - mov x15,x0 // input base for this iteration - mov x20,x16 // running output pointer - mov x12,x9 // output count this iteration - - lsr x13,x12,#2 // number of groups of four outputs - cbz x13,.Lpw_process_remainder + ldr x15,[sp,#0] // input base + ldr x16,[sp,#16] // output base + madd x16,x12,x8,x16 // output pointer for this filter + ldr x17,[sp,#8] // filter base + ldr x0,[sp,#56] // filter set stride + madd x17,x12,x0,x17 // filter pointer for this filter + ldr x10,[sp,#64] // bias base + add x10,x10,x12,lsl #6 // bias pointer (if used) + ldr x6,[sp,#24] // input row stride + ldr x7,[sp,#48] // input channel stride + ldr x13,[sp,#72] // output count + lsr x14,x13,#2 // number of groups of four outputs + and x13,x13,#3 // remaining outputs + ldr x9,[sp,#32] // input channel blocks + cbz x14,.Lpw_process_remainder // ------------------------------------------------------------------ // Main loop processing 4 outputs at a time. // ------------------------------------------------------------------ + .p2align 4 .Lpw_groups: // Clear accumulators for 4 outputs (16 vectors total). eor v16.16b,v16.16b,v16.16b @@ -167,14 +176,13 @@ Abstract: eor v30.16b,v30.16b,v30.16b eor v31.16b,v31.16b,v31.16b - mov x22,#0 // current input channel block + mov x5,#0 // current input channel block .Lpw_ic_loop4: - madd x23,x22,x6,x15 // input for this block - mov x25,x23 // four rows starting positions - add x26,x23,x3 - add x27,x26,x3 - add x28,x27,x3 - add x24,x17,x22,lsl #10 // filter for this block + madd x1,x5,x7,x15 // input for this block + add x2,x1,x6 // four rows starting positions + add x3,x2,x6 + add x4,x3,x6 + add x0,x17,x5,lsl #10 // filter for this block // The block size is 16 so unroll 16 steps. CPK4_FmlaStep @@ -194,8 +202,8 @@ Abstract: CPK4_FmlaStep CPK4_FmlaStep - add x22,x22,#1 - cmp x22,x4 + add x5,x5,#1 + cmp x5,x9 blt .Lpw_ic_loop4 // ----------------------------------------------------------------- @@ -209,12 +217,12 @@ Abstract: // Accumulation path. Load bias once as it is re-used for all four // stores when present. 
tbz w11,#1,1f - ldp q4,q5,[x18] - ldp q6,q7,[x18,#32] + ldp q4,q5,[x10] + ldp q6,q7,[x10,#32] 1: // ---- output 0 ---- - ldp q0,q1,[x20] - ldp q2,q3,[x20,#32] + ldp q0,q1,[x16] + ldp q2,q3,[x16,#32] tbz w11,#1,2f fadd v0.4s,v0.4s,v4.4s fadd v1.4s,v1.4s,v5.4s @@ -232,13 +240,13 @@ Abstract: fmax v18.4s,v18.4s,v0.4s fmax v19.4s,v19.4s,v0.4s 3: - stp q16,q17,[x20] - stp q18,q19,[x20,#32] + stp q16,q17,[x16] + stp q18,q19,[x16,#32] // ---- output 1 ---- - add x22,x20,#.LPW_BlockBytes - ldp q0,q1,[x22] - ldp q2,q3,[x22,#32] + add x0,x16,#.LPW_BlockBytes + ldp q0,q1,[x0] + ldp q2,q3,[x0,#32] tbz w11,#1,4f fadd v0.4s,v0.4s,v4.4s fadd v1.4s,v1.4s,v5.4s @@ -256,13 +264,13 @@ Abstract: fmax v22.4s,v22.4s,v0.4s fmax v23.4s,v23.4s,v0.4s 5: - stp q20,q21,[x22] - stp q22,q23,[x22,#32] + stp q20,q21,[x0] + stp q22,q23,[x0,#32] // ---- output 2 ---- - add x22,x22,#.LPW_BlockBytes - ldp q0,q1,[x22] - ldp q2,q3,[x22,#32] + add x0,x0,#.LPW_BlockBytes + ldp q0,q1,[x0] + ldp q2,q3,[x0,#32] tbz w11,#1,6f fadd v0.4s,v0.4s,v4.4s fadd v1.4s,v1.4s,v5.4s @@ -280,13 +288,13 @@ Abstract: fmax v26.4s,v26.4s,v0.4s fmax v27.4s,v27.4s,v0.4s 7: - stp q24,q25,[x22] - stp q26,q27,[x22,#32] + stp q24,q25,[x0] + stp q26,q27,[x0,#32] // ---- output 3 ---- - add x22,x22,#.LPW_BlockBytes - ldp q0,q1,[x22] - ldp q2,q3,[x22,#32] + add x0,x0,#.LPW_BlockBytes + ldp q0,q1,[x0] + ldp q2,q3,[x0,#32] tbz w11,#1,8f fadd v0.4s,v0.4s,v4.4s fadd v1.4s,v1.4s,v5.4s @@ -304,15 +312,15 @@ Abstract: fmax v30.4s,v30.4s,v0.4s fmax v31.4s,v31.4s,v0.4s 9: - stp q28,q29,[x22] - stp q30,q31,[x22,#32] + stp q28,q29,[x0] + stp q30,q31,[x0,#32] b .Lpw_advance_group // Non-accumulating path: add bias directly to the results if requested .Lpw_store_nacc: tbz w11,#1,10f - ldp q4,q5,[x18] - ldp q6,q7,[x18,#32] + ldp q4,q5,[x10] + ldp q6,q7,[x10,#32] fadd v16.4s,v16.4s,v4.4s fadd v17.4s,v17.4s,v5.4s fadd v18.4s,v18.4s,v6.4s @@ -349,40 +357,39 @@ Abstract: fmax v30.4s,v30.4s,v0.4s fmax v31.4s,v31.4s,v0.4s 11: - stp q16,q17,[x20] - stp q18,q19,[x20,#32] - add x22,x20,#.LPW_BlockBytes - stp q20,q21,[x22] - stp q22,q23,[x22,#32] - add x22,x22,#.LPW_BlockBytes - stp q24,q25,[x22] - stp q26,q27,[x22,#32] - add x22,x22,#.LPW_BlockBytes - stp q28,q29,[x22] - stp q30,q31,[x22,#32] + stp q16,q17,[x16] + stp q18,q19,[x16,#32] + add x0,x16,#.LPW_BlockBytes + stp q20,q21,[x0] + stp q22,q23,[x0,#32] + add x0,x0,#.LPW_BlockBytes + stp q24,q25,[x0] + stp q26,q27,[x0,#32] + add x0,x0,#.LPW_BlockBytes + stp q28,q29,[x0] + stp q30,q31,[x0,#32] .Lpw_advance_group: - add x15,x15,x3,lsl #2 - add x20,x20,#(.LPW_BlockBytes*4) - subs x13,x13,#1 + add x15,x15,x6,lsl #2 + add x16,x16,#(.LPW_BlockBytes*4) + subs x14,x14,#1 b.ne .Lpw_groups // ------------------------------------------------------------------ // Handle the leftover (0..3) output positions. // ------------------------------------------------------------------ .Lpw_process_remainder: - ands x12,x12,#3 - beq .Lpw_after_filter + cbz x13,.Lpw_after_filter .Lpw_left_loop: CPK_ComputeOneOutput // Accumulate? 
tbz w11,#0,.Lpw_left_noacc - ldp q0,q1,[x20] - ldp q2,q3,[x20,#32] + ldp q0,q1,[x16] + ldp q2,q3,[x16,#32] tbz w11,#1,12f - ldp q4,q5,[x18] - ldp q6,q7,[x18,#32] + ldp q4,q5,[x10] + ldp q6,q7,[x10,#32] fadd v0.4s,v0.4s,v4.4s fadd v1.4s,v1.4s,v5.4s fadd v2.4s,v2.4s,v6.4s @@ -399,14 +406,14 @@ Abstract: fmax v18.4s,v18.4s,v0.4s fmax v19.4s,v19.4s,v0.4s 13: - stp q16,q17,[x20] - stp q18,q19,[x20,#32] + stp q16,q17,[x16] + stp q18,q19,[x16,#32] b 14f .Lpw_left_noacc: tbz w11,#1,15f - ldp q4,q5,[x18] - ldp q6,q7,[x18,#32] + ldp q4,q5,[x10] + ldp q6,q7,[x10,#32] fadd v16.4s,v16.4s,v4.4s fadd v17.4s,v17.4s,v5.4s fadd v18.4s,v18.4s,v6.4s @@ -419,23 +426,21 @@ Abstract: fmax v18.4s,v18.4s,v0.4s fmax v19.4s,v19.4s,v0.4s 16: - stp q16,q17,[x20] - stp q18,q19,[x20,#32] + stp q16,q17,[x16] + stp q18,q19,[x16,#32] 14: - add x15,x15,x3 - add x20,x20,#.LPW_BlockBytes - subs x12,x12,#1 + add x15,x15,x6 + add x16,x16,#.LPW_BlockBytes + subs x13,x13,#1 b.ne .Lpw_left_loop .Lpw_after_filter: - add x14,x14,#1 - cmp x14,x5 + add x12,x12,#1 + ldr x0,[sp,#40] // output channel blocks + cmp x12,x0 blt .Lpw_filter_loop .Lpw_exit: - ldp x28,x19,[sp],#16 - ldp x26,x27,[sp],#16 - ldp x24,x25,[sp],#16 - ldp x22,x23,[sp],#16 + add sp,sp,#96 ret .end diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 88fb59ac47093..c3548b79dcb66 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -571,15 +571,15 @@ Return Value: this->EltwiseDispatch = &MlasEltwiseDispatchNeon; #if defined(MLAS_USE_ARM_NEON_NCHWC) - // Prefer the hand written micro-kernel for the NCHW convolution path. It + // Prefer the hand written micro-kernel for the NCHW convolution path. It // offers a tighter schedule and a specialised two-output inner loop that - // reduces pressure on the memory system compared + // reduces pressure on the memory system compared to the generic kernel. this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeonAsm; // Prefer the hand written AArch64 micro-kernel for pointwise convolution // as it computes multiple output positions at once and significantly - // reduces memory traffic. The AArch64 assembly kernel is picked up by - // heuristics in platform.cpp to avoid regressions on small convolutions. - // So here we set the default to the intrinsics version + // reduces memory traffic. The AArch64 assembly kernel is selected by + // heuristics in snchwc.cpp to avoid regressions on small convolutions, so + // we set the default to the intrinsics version here. 
this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 0c4d95400a8d1..ef59b22c8a2b4 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -926,9 +926,8 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const float* filter = Filter; float* output = Output + BlockSize * ph * OutputWidth; - size_t InputChannelBatch; - - for (size_t ic = 0; ic < InputChannels; ic += InputChannelBatch) { + size_t InputChannelBatch = 0; + for (size_t ic = 0; ic < InputChannels; ) { constexpr size_t MaximumInputChannelBatch = 128; @@ -959,8 +958,9 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM DoActivation(output, FilterCount, BlockSize * OutputThisIteration); } - input += MaximumInputChannelBatch * InputSize; - filter += BlockSize * MaximumInputChannelBatch; + input += InputChannelBatch * InputSize; + filter += BlockSize * InputChannelBatch; + ic += InputChannelBatch; } //