diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 3d94d30947c76..6efa8a5592337 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -58,6 +58,90 @@ jobs: --use_xnnpack \ --use_binskim_compliant_compile_flags + Vcpkg: + runs-on: macos-13 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: "Run vcpkg(x64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-osx or arm64-osx + export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" + export PATH="$FLATC_DIR:$PATH" + flatc --version + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" + + - name: "Detect protoc" + id: protoc-detect + run: | + export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" + export PATH="$PROTOC_DIR:$PATH" + protoc --version + echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" + + - name: "Run build.py(x64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/x64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch x86_64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + - name: "Run vcpkg(arm64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/arm64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch arm64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + Objective-C-StaticAnalysis: runs-on: macos-14 diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index b097cdbd9a55c..6c4dc43847d0b 100644 --- a/.github/workflows/publish-c-apidocs.yml 
+++ b/.github/workflows/publish-c-apidocs.yml @@ -49,4 +49,4 @@ jobs: with: name: onnxruntime-c-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 02aa053365790..862a7a70e33a2 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -59,4 +59,4 @@ jobs: with: name: onnxruntime-csharp-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 3e553049a186e..9e42dca708a17 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -47,4 +47,4 @@ jobs: with: name: onnxruntime-java-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index db021106a6554..cec4a52d39c93 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -47,4 +47,4 @@ jobs: with: name: onnxruntime-node-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index ebacd38f1f882..a8b81c8d5cf84 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -48,4 +48,4 @@ jobs: with: name: onnxruntime-objectivec-apidocs path: ./_site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index e98d22450c5b0..8b2f72d80bacf 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -53,4 +53,4 @@ jobs: with: name: onnxruntime-python-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index b77e48942ec44..d276877b7ad47 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -42,3 +42,90 @@ jobs: # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. 
- name: Build code run: python tools\ci_build\build.py --windows_sdk_version 10.0.22621.0 --enable_training --build_java --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + + Vcpkg: + runs-on: "windows-latest" + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11.x' + architecture: 'x64' + + - name: "Run vcpkg(x64-windows)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "C:/vcpkg" # use VCPKG_INSTALLATION_ROOT of the image + doNotUpdateVcpkg: true + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-windows" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-windows or arm64-windows + $FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-windows/tools/flatbuffers" + $env:PATH="$FLATC_DIR;$env:PATH" + flatc --version + $FLATC_PATH = Join-Path "$FLATC_DIR" "flatc.exe" + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$FLATC_PATH" + shell: pwsh + + - name: "Detect protoc" + id: protoc-detect + run: | + $PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-windows/tools/protobuf" + $env:PATH="$PROTOC_DIR;$env:PATH" + protoc --version + $PROTOC_PATH = Join-Path "$PROTOC_DIR" "protoc.exe" + "protoc_path=$PROTOC_PATH" >> $env:GITHUB_OUTPUT + shell: pwsh + + - name: "Run build.py(x64-windows)" + run: | + python tools\ci_build\build.py ` + --build_dir "cmake_build/x64-windows" ` + --skip_submodule_sync ` + --skip_tests ` + --compile_no_warning_as_error ` + --parallel ` + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" ` + --use_vcpkg ` + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=C:/vcpkg/scripts/buildsystems/vcpkg.cmake" ` + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-windows" ` + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" ` + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: pwsh + + - name: "Run vcpkg(arm64-windows)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "C:/vcpkg" # use VCPKG_INSTALLATION_ROOT of the image + doNotUpdateVcpkg: true + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-windows" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-windows)" + run: | + python tools\ci_build\build.py ` + --build_dir "cmake_build/arm64-windows" --arm64 ` + --skip_submodule_sync ` + --skip_tests ` + --compile_no_warning_as_error ` + --parallel ` + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" ` + --use_vcpkg ` + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=C:/vcpkg/scripts/buildsystems/vcpkg.cmake" ` + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-windows" ` + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" ` + --cmake_extra_defines 
"VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: pwsh diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f8589598c7571..654099958b21b 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -146,7 +146,7 @@ "component": { "type": "git", "git": { - "commitHash": "0da379fc4808f9601faef392352018c741c0f297", + "commitHash": "309b75c9e56e0a674bf78d59872ce131f814dfb6", "repositoryUrl": "https://github.com/google/XNNPACK.git" }, "comments": "googlexnnpack" diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 87d17c9788ddc..208bcb67ca731 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. # Minimum CMake required @@ -38,6 +39,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 if(UNIX) @@ -69,6 +71,7 @@ if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_ endif() # Options +option(onnxruntime_USE_VCPKG "Build with the vcpkg package manager" OFF) option(onnxruntime_RUN_ONNX_TESTS "Enable ONNX Compatibility Testing" OFF) option(onnxruntime_GENERATE_TEST_REPORTS "Enable test report generation" OFF) option(onnxruntime_ENABLE_STATIC_ANALYSIS "Enable static analysis" OFF) @@ -130,11 +133,6 @@ option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF) option(onnxruntime_USE_WINML "Build with WinML support" OFF) option(onnxruntime_USE_ACL "Build with ACL support" OFF) -option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) -option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) -option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) -option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF) -option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF) option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF) option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON) option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON) @@ -595,6 +593,7 @@ get_filename_component(ORTTRAINING_ROOT "${ORTTRAINING_ROOT}" ABSOLUTE) get_filename_component(REPO_ROOT "${REPO_ROOT}" ABSOLUTE) set(ONNXRUNTIME_INCLUDE_DIR ${REPO_ROOT}/include/onnxruntime) +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external) include(external/onnxruntime_external_deps.cmake) set(ORT_WARNING_FLAGS) @@ -1204,25 +1203,8 @@ function(onnxruntime_add_include_to_target dst_target) endfunction() # ACL -if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308) +if (onnxruntime_USE_ACL) set(onnxruntime_USE_ACL ON) - if (onnxruntime_USE_ACL_1902) - add_definitions(-DACL_1902=1) - else() - if (onnxruntime_USE_ACL_1908) - add_definitions(-DACL_1908=1) - else() - if (onnxruntime_USE_ACL_2002) - add_definitions(-DACL_2002=1) - else() - if (onnxruntime_USE_ACL_2308) - add_definitions(-DACL_2308=1) - else() - add_definitions(-DACL_1905=1) - endif() - endif() - endif() - 
endif() if (NOT ${onnxruntime_ACL_LIBS} STREQUAL "") add_library(arm_compute SHARED IMPORTED) @@ -1230,18 +1212,13 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") - add_library(arm_compute_core SHARED IMPORTED) - set_target_properties(arm_compute_core PROPERTIES - IMPORTED_NO_SONAME 1 - IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") - add_library(arm_compute_graph SHARED IMPORTED) set_target_properties(arm_compute_graph PROPERTIES IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_graph.so") endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_core arm_compute_graph) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_graph) endif() @@ -1260,11 +1237,6 @@ if (onnxruntime_USE_ARMNN) IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") - add_library(arm_compute_core SHARED IMPORTED) - set_target_properties(arm_compute_core PROPERTIES - IMPORTED_NO_SONAME 1 - IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") - add_library(arm_compute_graph SHARED IMPORTED) set_target_properties(arm_compute_graph PROPERTIES IMPORTED_NO_SONAME 1 @@ -1278,7 +1250,7 @@ if (onnxruntime_USE_ARMNN) IMPORTED_LOCATION "${onnxruntime_ARMNN_LIBS}/libarmnn.so") endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_core arm_compute_graph) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_graph) endif() if (onnxruntime_USE_DNNL) diff --git a/cmake/CMakePresets.json b/cmake/CMakePresets.json new file mode 100644 index 0000000000000..1b7aa11975a3e --- /dev/null +++ b/cmake/CMakePresets.json @@ -0,0 +1,192 @@ +{ + "version": 6, + "cmakeMinimumRequired": { + "major": 3, + "minor": 25, + "patch": 0 + }, + "configurePresets": [ + { + "name": "vcpkg-manifest", + "hidden": true, + "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake", + "cacheVariables": { + "VCPKG_INSTALLED_DIR": "${sourceParentDir}/.build" + }, + "environment": { + "VCPKG_FEATURE_FLGAS": "manifests,versions" + } + }, + { + "name": "msvc-static-runtime", + "hidden": true, + "cacheVariables": { + "ONNX_USE_MSVC_STATIC_RUNTIME": true, + "protobuf_MSVC_STATIC_RUNTIME": true, + "gtest_force_shared_crt": false, + "CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded$<$:Debug>" + } + }, + { + "name": "unit-test", + "hidden": true, + "cacheVariables": { + "onnxruntime_RUN_ONNX_TESTS": true, + "onnxruntime_BUILD_BENCHMARKS": true, + "onnxruntime_BUILD_UNIT_TESTS": true, + "onnxruntime_GENERATE_TEST_REPORTS": true + } + }, + { + "name": "x64-windows", + "inherits": [ + "msvc-static-runtime", + "unit-test" + ], + "generator": "Visual Studio 17 2022", + "architecture": "x64", + "binaryDir": "${sourceParentDir}/cmake_build/x64-windows", + "installDir": "${sourceParentDir}/cmake_build/out", + "cacheVariables": { + "onnxruntime_USE_XNNPACK": true, + "onnxruntime_USE_DML": true, + "onnxruntime_BUILD_SHARED_LIB": true, + "CMAKE_CONFIGURATION_TYPES": "Debug;Release" + }, + "vendor": { + "microsoft.com/VisualStudioSettings/CMake/1.0": { + "intelliSenseMode": "windows-msvc-x64" + } + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "x64-windows-vcpkg", + "inherits": [ + "unit-test", + "vcpkg-manifest" + ], + "generator": "Visual Studio 17 2022", + "architecture": "x64", + "binaryDir": 
"${sourceParentDir}/cmake_build/x64-windows", + "installDir": "${sourceParentDir}/cmake_build/out", + "cacheVariables": { + "onnxruntime_USE_VCPKG": true, + "onnxruntime_USE_XNNPACK": false, + "onnxruntime_USE_DML": false, + "onnxruntime_BUILD_SHARED_LIB": true, + "CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded$<$:Debug>DLL", + "CMAKE_CONFIGURATION_TYPES": "Debug;Release", + "VCPKG_INSTALL_OPTIONS": "--x-feature=tests", + "VCPKG_TARGET_TRIPLET": "x64-windows" + } + }, + { + "name": "x64-osx", + "inherits": [ + "unit-test" + ], + "generator": "Xcode", + "binaryDir": "${sourceParentDir}/cmake_build/x64-osx", + "installDir": "${sourceParentDir}/cmake_build/out", + "cacheVariables": { + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "onnxruntime_BUILD_SHARED_LIB": true, + "onnxruntime_USE_XNNPACK": false, + "onnxruntime_USE_COREML": true, + "onnxruntime_BUILD_OBJC": true, + "onnxruntime_BUILD_APPLE_FRAMEWORK": true, + "CMAKE_CONFIGURATION_TYPES": "Debug;Release" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } + }, + { + "name": "x64-osx-vcpkg", + "inherits": [ + "x64-osx", + "vcpkg-manifest" + ], + "cacheVariables": { + "onnxruntime_USE_VCPKG": true, + "onnxruntime_USE_XNNPACK": false, + "onnxruntime_USE_COREML": false, + "onnxruntime_BUILD_OBJC": false, + "onnxruntime_BUILD_APPLE_FRAMEWORK": false, + "VCPKG_INSTALL_OPTIONS": "--x-feature=tests", + "VCPKG_TARGET_TRIPLET": "x64-osx" + } + } + ], + "buildPresets": [ + { + "name": "x64-windows-debug", + "configurePreset": "x64-windows", + "configuration": "Debug" + }, + { + "name": "x64-windows-vcpkg-debug", + "configurePreset": "x64-windows-vcpkg", + "configuration": "Debug" + }, + { + "name": "x64-osx-debug", + "configurePreset": "x64-osx", + "configuration": "Debug" + }, + { + "name": "x64-osx-vcpkg-debug", + "configurePreset": "x64-osx-vcpkg", + "configuration": "Debug" + } + ], + "testPresets": [ + { + "name": "x64-windows-debug", + "configurePreset": "x64-windows", + "configuration": "Debug", + "output": { + "verbosity": "default", + "outputJUnitFile": "TEST-x64-windows-debug.xml", + "outputLogFile": "TEST-x64-windows-debug.log", + "outputOnFailure": true + }, + "execution": { + "noTestsAction": "error", + "stopOnFailure": false + } + }, + { + "name": "x64-windows-vcpkg-debug", + "inherits": "x64-windows-debug", + "configurePreset": "x64-windows-vcpkg" + }, + { + "name": "x64-osx-debug", + "configurePreset": "x64-osx", + "configuration": "Debug", + "output": { + "verbosity": "default", + "outputJUnitFile": "TEST-x64-osx-debug.xml", + "outputLogFile": "TEST-x64-osx-debug.log", + "outputOnFailure": true + }, + "execution": { + "noTestsAction": "error", + "stopOnFailure": false + } + }, + { + "name": "x64-osx-vcpkg-debug", + "inherits": "x64-osx-debug", + "configurePreset": "x64-osx-vcpkg" + } + ] +} diff --git a/cmake/deps.txt b/cmake/deps.txt index 597c051b5f477..342184bda2f0e 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -29,7 +29,8 @@ fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 
-googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 +#xnnpack 2024.09.04 +googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59872ce131f814dfb6.zip;39FA5259EAEACE0547284B63D5CEDC4F05553F5A json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 @@ -60,3 +61,4 @@ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/arch directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99 +kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/v0.2.0/kleidiai-v0.2.0.zip;B1E3173992FD91F20DB904AB77D6E901778C2681 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3223724693a49..dda7c5ff19ba4 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -27,7 +27,7 @@ FetchContent_Declare( URL ${DEP_URL_abseil_cpp} URL_HASH SHA1=${DEP_SHA1_abseil_cpp} PATCH_COMMAND ${ABSL_PATCH_COMMAND} - FIND_PACKAGE_ARGS 20240116 NAMES absl + FIND_PACKAGE_ARGS NAMES absl ) onnxruntime_fetchcontent_makeavailable(abseil_cpp) @@ -142,4 +142,4 @@ absl::throw_delegate absl::memory absl::charset absl::endian -absl::config) \ No newline at end of file +absl::config) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 4e52707474052..43f18abbe9522 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -1,4 +1,4 @@ -message("Loading Dependencies URLs ...") +message(STATUS "Loading Dependencies URLs ...") include(external/helper_functions.cmake) @@ -27,7 +27,9 @@ foreach(ONNXRUNTIME_DEP IN LISTS ONNXRUNTIME_DEPS_LIST) endif() endforeach() -message("Loading Dependencies ...") +message(STATUS "Loading Dependencies ...") +include(FetchContent) + # ABSL should be included before protobuf because protobuf may use absl include(external/abseil-cpp.cmake) @@ -39,6 +41,7 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_re2} FIND_PACKAGE_ARGS NAMES re2 ) +onnxruntime_fetchcontent_makeavailable(re2) if (onnxruntime_BUILD_UNIT_TESTS) # WebAssembly threading support in Node.js is still an experimental feature and @@ -65,6 +68,7 @@ if (onnxruntime_BUILD_UNIT_TESTS) URL_HASH SHA1=${DEP_SHA1_googletest} FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) + FetchContent_MakeAvailable(googletest) endif() if (onnxruntime_BUILD_BENCHMARKS) @@ -77,50 +81,41 @@ if (onnxruntime_BUILD_BENCHMARKS) google_benchmark URL ${DEP_URL_google_benchmark} URL_HASH SHA1=${DEP_SHA1_google_benchmark} + FIND_PACKAGE_ARGS NAMES benchmark ) + onnxruntime_fetchcontent_makeavailable(google_benchmark) endif() if (NOT WIN32) - FetchContent_Declare( + FetchContent_Declare( google_nsync URL ${DEP_URL_google_nsync} URL_HASH SHA1=${DEP_SHA1_google_nsync} - FIND_PACKAGE_ARGS NAMES nsync - ) -endif() -list(APPEND CMAKE_MODULE_PATH 
${PROJECT_SOURCE_DIR}/external) - -FetchContent_Declare( - mimalloc - URL ${DEP_URL_mimalloc} - URL_HASH SHA1=${DEP_SHA1_mimalloc} -) - + FIND_PACKAGE_ARGS NAMES nsync unofficial-nsync + ) + #nsync tests failed on Mac Build + set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + onnxruntime_fetchcontent_makeavailable(google_nsync) -# Flatbuffers -# We do not need to build flatc for iOS or Android Cross Compile -if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) -endif() -set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "FLATBUFFERS_BUILD_TESTS" FORCE) -set(FLATBUFFERS_INSTALL OFF CACHE BOOL "FLATBUFFERS_INSTALL" FORCE) -set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATHASH" FORCE) -set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "FLATBUFFERS_BUILD_FLATLIB" FORCE) -if(Patch_FOUND) - set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/flatbuffers/flatbuffers.patch) -else() - set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND "") + if (google_nsync_SOURCE_DIR) + add_library(nsync::nsync_cpp ALIAS nsync_cpp) + target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public) + endif() + if(TARGET unofficial::nsync::nsync_cpp AND NOT TARGET nsync::nsync_cpp) + message(STATUS "Aliasing unofficial::nsync::nsync_cpp to nsync::nsync_cpp") + add_library(nsync::nsync_cpp ALIAS unofficial::nsync::nsync_cpp) + endif() endif() -#flatbuffers 1.11.0 does not have flatbuffers::IsOutRange, therefore we require 1.12.0+ -FetchContent_Declare( - flatbuffers - URL ${DEP_URL_flatbuffers} - URL_HASH SHA1=${DEP_SHA1_flatbuffers} - PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND} - FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers -) - +if(onnxruntime_USE_MIMALLOC) + FetchContent_Declare( + mimalloc + URL ${DEP_URL_mimalloc} + URL_HASH SHA1=${DEP_SHA1_mimalloc} + FIND_PACKAGE_ARGS NAMES mimalloc + ) + FetchContent_MakeAvailable(mimalloc) +endif() #Protobuf depends on utf8_range FetchContent_Declare( @@ -133,6 +128,10 @@ FetchContent_Declare( set(utf8_range_ENABLE_TESTS OFF CACHE BOOL "Build test suite" FORCE) set(utf8_range_ENABLE_INSTALL OFF CACHE BOOL "Configure installation" FORCE) +# The next line will generate an error message "fatal: not a git repository", but it is ok. 
It is from flatbuffers +onnxruntime_fetchcontent_makeavailable(utf8_range) +# protobuf's cmake/utf8_range.cmake has the following line +include_directories(${utf8_range_SOURCE_DIR}) # Download a protoc binary from Internet if needed if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) @@ -146,12 +145,12 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") + message(STATUS "Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif (CMAKE_CROSSCOMPILING) - message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") + message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) @@ -162,7 +161,7 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) endif() if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") + message(STATUS "Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() @@ -179,7 +178,7 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) endif() if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") + message(STATUS "Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() @@ -217,7 +216,7 @@ FetchContent_Declare( URL ${DEP_URL_protobuf} URL_HASH SHA1=${DEP_SHA1_protobuf} PATCH_COMMAND ${ONNXRUNTIME_PROTOBUF_PATCH_COMMAND} - FIND_PACKAGE_ARGS 3.21.12 NAMES Protobuf + FIND_PACKAGE_ARGS NAMES Protobuf protobuf ) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) @@ -239,6 +238,51 @@ endif() include(protobuf_function) #protobuf end +onnxruntime_fetchcontent_makeavailable(Protobuf) +if(Protobuf_FOUND) + message(STATUS "Protobuf version: ${Protobuf_VERSION}") +else() + # Adjust warning flags + if (TARGET libprotoc) + if (NOT MSVC) + target_compile_options(libprotoc PRIVATE "-w") + endif() + endif() + if (TARGET protoc) + add_executable(protobuf::protoc ALIAS protoc) + if (UNIX AND onnxruntime_ENABLE_LTO) + #https://github.com/protocolbuffers/protobuf/issues/5923 + target_link_options(protoc PRIVATE "-Wl,--no-as-needed") + endif() + if (NOT MSVC) + target_compile_options(protoc PRIVATE "-w") + endif() + get_target_property(PROTOC_OSX_ARCH protoc OSX_ARCHITECTURES) + if (PROTOC_OSX_ARCH) + if (${CMAKE_HOST_SYSTEM_PROCESSOR} IN_LIST PROTOC_OSX_ARCH) + message(STATUS "protoc can run") + else() + list(APPEND PROTOC_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + set_target_properties(protoc PROPERTIES OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}") + set_target_properties(libprotoc PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") + set_target_properties(libprotobuf PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") + endif() + endif() + endif() + if (TARGET libprotobuf AND NOT MSVC) + target_compile_options(libprotobuf PRIVATE "-w") + endif() + if (TARGET libprotobuf-lite AND NOT MSVC) + target_compile_options(libprotobuf-lite PRIVATE "-w") + endif() +endif() +if (onnxruntime_USE_FULL_PROTOBUF) + set(PROTOBUF_LIB protobuf::libprotobuf) +else() + set(PROTOBUF_LIB 
protobuf::libprotobuf-lite) +endif() + +# date set(ENABLE_DATE_TESTING OFF CACHE BOOL "" FORCE) set(USE_SYSTEM_TZ_DB ON CACHE BOOL "" FORCE) @@ -254,7 +298,16 @@ FetchContent_Declare( mp11 URL ${DEP_URL_mp11} URL_HASH SHA1=${DEP_SHA1_mp11} + FIND_PACKAGE_ARGS NAMES Boost ) +onnxruntime_fetchcontent_makeavailable(mp11) +if(NOT TARGET Boost::mp11) + if(onnxruntime_USE_VCPKG) + find_package(Boost REQUIRED) + endif() + message(STATUS "Aliasing Boost::headers to Boost::mp11") + add_library(Boost::mp11 ALIAS Boost::headers) +endif() set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install OFF CACHE INTERNAL "") @@ -265,6 +318,7 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_json} FIND_PACKAGE_ARGS 3.10 NAMES nlohmann_json ) +onnxruntime_fetchcontent_makeavailable(nlohmann_json) #TODO: include clog first if (onnxruntime_ENABLE_CPUINFO) @@ -301,20 +355,6 @@ else() set(CPUINFO_SUPPORTED FALSE) endif() -# xnnpack depends on clog -# Android build should use the system's log library instead of clog -if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) - set(CLOG_BUILD_TESTS OFF CACHE BOOL "" FORCE) - FetchContent_Declare( - pytorch_clog - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - SOURCE_SUBDIR deps/clog - ) - set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) - set(ONNXRUNTIME_CLOG_TARGET_NAME clog) -endif() - if (CPUINFO_SUPPORTED) if (CMAKE_SYSTEM_NAME STREQUAL "iOS") set(IOS ON CACHE INTERNAL "") @@ -333,7 +373,7 @@ if (CPUINFO_SUPPORTED) set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if(onnxruntime_target_platform STREQUAL "ARM64EC") - message("Applying a patch for Windows ARM64EC in cpuinfo") + message(STATUS "Applying a patch for Windows ARM64EC in cpuinfo") FetchContent_Declare( pytorch_cpuinfo URL ${DEP_URL_pytorch_cpuinfo} @@ -350,20 +390,33 @@ if (CPUINFO_SUPPORTED) ) endif() set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo) + onnxruntime_fetchcontent_makeavailable(${ONNXRUNTIME_CPUINFO_PROJ}) + if(TARGET cpuinfo::cpuinfo AND NOT TARGET cpuinfo) + message(STATUS "Aliasing cpuinfo::cpuinfo to cpuinfo") + add_library(cpuinfo ALIAS cpuinfo::cpuinfo) + endif() endif() - -if (onnxruntime_BUILD_BENCHMARKS) - onnxruntime_fetchcontent_makeavailable(google_benchmark) -endif() - -if (NOT WIN32) - #nsync tests failed on Mac Build - set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE) - onnxruntime_fetchcontent_makeavailable(google_nsync) - if (google_nsync_SOURCE_DIR) - add_library(nsync::nsync_cpp ALIAS nsync_cpp) - target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public) +# xnnpack depends on clog +# Android build should use the system's log library instead of clog +if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) + set(CLOG_BUILD_TESTS OFF CACHE BOOL "" FORCE) + FetchContent_Declare( + pytorch_clog + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + SOURCE_SUBDIR deps/clog + FIND_PACKAGE_ARGS NAMES cpuinfo + ) + set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) + onnxruntime_fetchcontent_makeavailable(${ONNXRUNTIME_CLOG_PROJ}) + set(ONNXRUNTIME_CLOG_TARGET_NAME clog) + # if cpuinfo is from find_package, use it with imported name + if(TARGET cpuinfo::clog) + set(ONNXRUNTIME_CLOG_TARGET_NAME cpuinfo::clog) + elseif(onnxruntime_USE_VCPKG) + # however, later cpuinfo versions may not contain clog. 
use cpuinfo + set(ONNXRUNTIME_CLOG_TARGET_NAME cpuinfo::cpuinfo) endif() endif() @@ -383,21 +436,51 @@ else() FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL ) endif() +set(GSL_TARGET "Microsoft.GSL::GSL") +set(GSL_INCLUDE_DIR "$") +onnxruntime_fetchcontent_makeavailable(GSL) +find_path(safeint_SOURCE_DIR NAMES "SafeInt.hpp") +if(NOT safeint_SOURCE_DIR) + unset(safeint_SOURCE_DIR) + FetchContent_Declare( + safeint + URL ${DEP_URL_safeint} + URL_HASH SHA1=${DEP_SHA1_safeint} + ) + + # use fetch content rather than makeavailable because safeint only includes unconditional test targets + FetchContent_Populate(safeint) +endif() +add_library(safeint_interface INTERFACE) +target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) + + +# Flatbuffers +# We do not need to build flatc for iOS or Android Cross Compile +if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) +endif() +set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "FLATBUFFERS_BUILD_TESTS" FORCE) +set(FLATBUFFERS_INSTALL OFF CACHE BOOL "FLATBUFFERS_INSTALL" FORCE) +set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATHASH" FORCE) +set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "FLATBUFFERS_BUILD_FLATLIB" FORCE) +if(Patch_FOUND) + set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/flatbuffers/flatbuffers.patch) +else() + set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND "") +endif() + +#flatbuffers 1.11.0 does not have flatbuffers::IsOutRange, therefore we require 1.12.0+ FetchContent_Declare( - safeint - URL ${DEP_URL_safeint} - URL_HASH SHA1=${DEP_SHA1_safeint} + flatbuffers + URL ${DEP_URL_flatbuffers} + URL_HASH SHA1=${DEP_SHA1_flatbuffers} + PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND} + FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers flatbuffers ) -# use fetch content rather than makeavailable because safeint only includes unconditional test targets -FetchContent_Populate(safeint) -# The next line will generate an error message "fatal: not a git repository", but it is ok. 
It is from flatbuffers -onnxruntime_fetchcontent_makeavailable(utf8_range) -# protobuf's cmake/utf8_range.cmake has the following line -include_directories(${utf8_range_SOURCE_DIR}) - -onnxruntime_fetchcontent_makeavailable(Protobuf nlohmann_json mp11 re2 GSL flatbuffers ${ONNXRUNTIME_CPUINFO_PROJ} ${ONNXRUNTIME_CLOG_PROJ}) +onnxruntime_fetchcontent_makeavailable(flatbuffers) if(NOT flatbuffers_FOUND) if(NOT TARGET flatbuffers::flatbuffers) add_library(flatbuffers::flatbuffers ALIAS flatbuffers) @@ -424,54 +507,6 @@ namespace std { using ::getenv; } endif() endif() -if (onnxruntime_BUILD_UNIT_TESTS) - onnxruntime_fetchcontent_makeavailable(googletest) -endif() - -if(Protobuf_FOUND) - message("Protobuf version: ${Protobuf_VERSION}") -else() - # Adjust warning flags - if (TARGET libprotoc) - if (NOT MSVC) - target_compile_options(libprotoc PRIVATE "-w") - endif() - endif() - if (TARGET protoc) - add_executable(protobuf::protoc ALIAS protoc) - if (UNIX AND onnxruntime_ENABLE_LTO) - #https://github.com/protocolbuffers/protobuf/issues/5923 - target_link_options(protoc PRIVATE "-Wl,--no-as-needed") - endif() - if (NOT MSVC) - target_compile_options(protoc PRIVATE "-w") - endif() - get_target_property(PROTOC_OSX_ARCH protoc OSX_ARCHITECTURES) - if (PROTOC_OSX_ARCH) - if (${CMAKE_HOST_SYSTEM_PROCESSOR} IN_LIST PROTOC_OSX_ARCH) - message("protoc can run") - else() - list(APPEND PROTOC_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - set_target_properties(protoc PROPERTIES OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}") - set_target_properties(libprotoc PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") - set_target_properties(libprotobuf PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") - endif() - endif() - endif() - if (TARGET libprotobuf AND NOT MSVC) - target_compile_options(libprotobuf PRIVATE "-w") - endif() - if (TARGET libprotobuf-lite AND NOT MSVC) - target_compile_options(libprotobuf-lite PRIVATE "-w") - endif() -endif() -if (onnxruntime_USE_FULL_PROTOBUF) - set(PROTOBUF_LIB protobuf::libprotobuf) -else() - set(PROTOBUF_LIB protobuf::libprotobuf-lite) -endif() - - # ONNX if (NOT onnxruntime_USE_FULL_PROTOBUF) set(ONNX_USE_LITE_PROTO ON CACHE BOOL "" FORCE) @@ -490,27 +525,36 @@ FetchContent_Declare( URL ${DEP_URL_onnx} URL_HASH SHA1=${DEP_SHA1_onnx} PATCH_COMMAND ${ONNXRUNTIME_ONNX_PATCH_COMMAND} + FIND_PACKAGE_ARGS NAMES ONNX onnx ) - - - - - - -include(eigen) -include(wil) - if (NOT onnxruntime_MINIMAL_BUILD) - onnxruntime_fetchcontent_makeavailable(onnx) + onnxruntime_fetchcontent_makeavailable(onnx) else() include(onnx_minimal) endif() -set(GSL_TARGET "Microsoft.GSL::GSL") -set(GSL_INCLUDE_DIR "$") +if(TARGET ONNX::onnx AND NOT TARGET onnx) + message(STATUS "Aliasing ONNX::onnx to onnx") + add_library(onnx ALIAS ONNX::onnx) +endif() +if(TARGET ONNX::onnx_proto AND NOT TARGET onnx_proto) + message(STATUS "Aliasing ONNX::onnx_proto to onnx_proto") + add_library(onnx_proto ALIAS ONNX::onnx_proto) +endif() -add_library(safeint_interface INTERFACE) -target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) +find_package(Eigen3 CONFIG) +if(Eigen3_FOUND) + get_target_property(eigen_INCLUDE_DIRS Eigen3::Eigen INTERFACE_INCLUDE_DIRECTORIES) +else() + include(eigen) # FetchContent +endif() + +if(onnxruntime_USE_VCPKG) + find_package(wil CONFIG REQUIRED) + set(WIL_TARGET "WIL::WIL") +else() + include(wil) # FetchContent +endif() # XNNPACK EP if (onnxruntime_USE_XNNPACK) @@ -539,9 +583,11 @@ set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${W # The 
other libs do not have the problem. All the sources are already there. We can compile them in any order. set(onnxruntime_EXTERNAL_DEPENDENCIES onnx_proto flatbuffers::flatbuffers) -target_compile_definitions(onnx PUBLIC $ PRIVATE "__ONNX_DISABLE_STATIC_REGISTRATION") -if (NOT onnxruntime_USE_FULL_PROTOBUF) - target_compile_definitions(onnx PUBLIC "__ONNX_NO_DOC_STRINGS") +if(NOT (onnx_FOUND OR ONNX_FOUND)) # building ONNX from source + target_compile_definitions(onnx PUBLIC $ PRIVATE "__ONNX_DISABLE_STATIC_REGISTRATION") + if (NOT onnxruntime_USE_FULL_PROTOBUF) + target_compile_definitions(onnx PUBLIC "__ONNX_NO_DOC_STRINGS") + endif() endif() if (onnxruntime_RUN_ONNX_TESTS) @@ -550,11 +596,12 @@ endif() if(onnxruntime_ENABLE_ATEN) - message("Aten fallback is enabled.") + message(STATUS "Aten fallback is enabled.") FetchContent_Declare( dlpack URL ${DEP_URL_dlpack} URL_HASH SHA1=${DEP_SHA1_dlpack} + FIND_PACKAGE_ARGS NAMES dlpack ) # We can't use onnxruntime_fetchcontent_makeavailable since some part of the the dlpack code is Linux only. # For example, dlpackcpp.h uses posix_memalign. @@ -568,6 +615,7 @@ if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxrunt cxxopts URL ${DEP_URL_cxxopts} URL_HASH SHA1=${DEP_SHA1_cxxopts} + FIND_PACKAGE_ARGS NAMES cxxopts ) set(CXXOPTS_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(CXXOPTS_BUILD_TESTS OFF CACHE BOOL "" FORCE) @@ -585,7 +633,7 @@ if (onnxruntime_USE_COREML) FetchContent_Populate(coremltools) endif() -message("Finished fetching external dependencies") +message(STATUS "Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index 41f02ce6f22bc..9519e4e6a7796 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -5,6 +5,8 @@ set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +set(KLEIDIAI_BUILD_TESTS OFF CACHE INTERNAL "") +set(KLEIDIAI_BUILD_BENCHMARK OFF CACHE INTERNAL "") if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") set(XNNPACK_USE_SYSTEM_LIBS OFF) @@ -30,6 +32,60 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR}) FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool}) onnxruntime_fetchcontent_makeavailable(pthreadpool) +# --- Determine target processor +# Why ORT_TARGET_PROCESSOR is only for XNNPACK +# So far, only Onnxruntime + XNNPack only allow one target processor. +# And we support Mac universal package, so, +# CMAKE_OSX_ARCHITECTURES_COUNT greater than 1 is allowed in other places. +IF(CMAKE_OSX_ARCHITECTURES) + LIST(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_COUNT) + IF(CMAKE_OSX_ARCHITECTURES_COUNT GREATER 1) + MESSAGE(STATUS "Building ONNX Runtime with XNNPACK and multiple OSX architectures is not supported. Got:(${CMAKE_OSX_ARCHITECTURES}). " + "Please specify a single architecture in CMAKE_OSX_ARCHITECTURES and re-configure. 
") + ENDIF() + IF(NOT CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64|arm64e|arm64_32)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_OSX_ARCHITECTURES value \"${CMAKE_OSX_ARCHITECTURES}\"") + ENDIF() + SET(ORT_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") + ADD_COMPILE_OPTIONS("-Wno-shorten-64-to-32") +ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_GENERATOR_PLATFORM) + IF(CMAKE_GENERATOR_PLATFORM MATCHES "^Win32") + SET(ORT_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^x64") + SET(ORT_TARGET_PROCESSOR "x86_64") + ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^ARM64") + SET(ORT_TARGET_PROCESSOR "arm64") + ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^ARM64EC") + SET(ORT_TARGET_PROCESSOR "arm64") + ELSE() + MESSAGE(FATAL_ERROR "Unsupported Visual Studio architecture \"${CMAKE_GENERATOR_PLATFORM}\"") + ENDIF() +ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^i[3-7]86$") + SET(ORT_TARGET_PROCESSOR "x86") +ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + SET(ORT_TARGET_PROCESSOR "x86_64") +ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") + SET(ORT_TARGET_PROCESSOR "arm") +ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + SET(ORT_TARGET_PROCESSOR "arm64") +ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le") + SET(ORT_TARGET_PROCESSOR "ppc64") +ELSEIF(NOT ORT_TARGET_PROCESSOR MATCHES "^(x86(_64)?|arm64|riscv(32|64|128)|Hexagon|ppc64)$") + SET(ORT_TARGET_PROCESSOR "${CMAKE_SYSTEM_PROCESSOR}") +ELSE() + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_PROCESSOR value \"${CMAKE_SYSTEM_PROCESSOR}\"") +ENDIF() +MESSAGE(STATUS "Building for ORT_TARGET_PROCESSOR: ${ORT_TARGET_PROCESSOR}") + +# KleidiAI is only used in Arm64 platform and not supported by MSVC, the details can be seen in +# https://github.com/google/XNNPACK/blob/3b3f7b8a6668f6ab3b6ce33b9f1d1fce971549d1/CMakeLists.txt#L206C82-L206C117 +if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") + FetchContent_Declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai}) + onnxruntime_fetchcontent_makeavailable(kleidiai) + set(KLEIDIAI_SOURCE_DIR ${kleidiai_SOURCE_DIR}) +endif() + + FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch ) @@ -37,8 +93,10 @@ onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) -set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool) - +set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK microkernels-prod pthreadpool) +if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") + list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai) +endif() # the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 927b4ac84b037..7e992fb33077c 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -332,6 +332,9 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # If it's an onnxruntime library, extract .o files from the original cmake build path to a separate directory for # each library to avoid any clashes with filenames (e.g. utils.o) foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ) + if(NOT TARGET ${_LIB}) # if we didn't build from source. 
it may not a target + continue() + endif() GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$) @@ -362,6 +365,9 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # for external libraries we create a symlink to the .a file foreach(_LIB ${onnxruntime_EXTERNAL_LIBRARIES}) + if(NOT TARGET ${_LIB}) # if we didn't build from source. it may not a target + continue() + endif() GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") add_custom_command(TARGET onnxruntime POST_BUILD diff --git a/cmake/onnxruntime_fuzz_test.cmake b/cmake/onnxruntime_fuzz_test.cmake index 26d41e98687d4..eea411d938176 100644 --- a/cmake/onnxruntime_fuzz_test.cmake +++ b/cmake/onnxruntime_fuzz_test.cmake @@ -4,23 +4,24 @@ # Check that the options are properly set for # the fuzzing project if (onnxruntime_FUZZ_ENABLED) - message(STATUS "Building dependency protobuf-mutator and libfuzzer") - - # set the options used to control the protobuf-mutator build - set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB}) - set(LIB_PROTO_MUTATOR_TESTING OFF) - - # include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes - # with google test - add_subdirectory("external/libprotobuf-mutator/src") - - # add the appropriate include directory and compilation flags - # needed by the protobuf-mutator target and the libfuzzer - set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator") - onnxruntime_add_include_to_target(protobuf-mutator ${PROTOBUF_LIB}) - onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) - target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) - target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + message(STATUS "Building dependency protobuf-mutator and libfuzzer") + + # set the options used to control the protobuf-mutator build + set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB}) + set(LIB_PROTO_MUTATOR_TESTING OFF) + + # include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes + # with google test + add_subdirectory("external/libprotobuf-mutator/src") + + # add the appropriate include directory and compilation flags + # needed by the protobuf-mutator target and the libfuzzer + set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator") + onnxruntime_add_include_to_target(protobuf-mutator ${PROTOBUF_LIB}) + onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # MSVC-specific compiler options target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456") @@ -44,42 +45,96 @@ if (onnxruntime_FUZZ_ENABLED) ) endif() - # add Fuzzing Engine Build Configuration - message(STATUS "Building Fuzzing engine") + # add Fuzzing Engine Build Configuration + message(STATUS "Building Fuzzing engine") + + # set Fuzz root directory + set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing) + + # Security fuzzing engine src file reference + set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp" + "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/src/test.cpp") + 
+ # compile the executables + onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC}) + + # compile with c++17 + target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17) - # set Fuzz root directory - set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime) - # Security fuzzing engine src file reference - set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp" - "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" - "${SEC_FUZZ_ROOT}/src/testlog.cpp" - "${SEC_FUZZ_ROOT}/src/test.cpp") + # Assign all include to one variable + set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include") + set(INCLUDE_FILES ${SEC_FUZ_INC} "$") - # compile the executables - onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC}) + # add all these include directory to the Fuzzing engine + target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES}) - # compile with c++17 - target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17) + # add link libraries to the project + target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) - # Security fuzzing engine header file reference - onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime) + # add the dependencies + add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) - # Assign all include to one variable - set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include") - set(INCLUDE_FILES ${SEC_FUZ_INC} "$") + # copy the shared libraries (DLLs on Windows, SOs on Linux) to the execution directory + add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) - # add all these include directory to the Fuzzing engine - target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES}) + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # Add a second fuzzer that uses libFuzzer in fuzzer/libfuzzer + message(STATUS "Building libProtoBufFuzzer-based fuzzer") - # add link libraries the project - target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) + # Set source files for the libFuzzer + set(LIBFUZZER_SRC "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/ort_libfuzzer/OrtProtoLibfuzzer.cpp") - # add the dependencies - add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) + # Compile the libFuzzer-based fuzzer + onnxruntime_add_executable(onnxruntime_proto_libfuzzer ${LIBFUZZER_SRC}) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_proto_libfuzzer onnx onnxruntime) + # Set include directories for libFuzzer + target_include_directories(onnxruntime_proto_libfuzzer PRIVATE ${INCLUDE_FILES}) - # copy the dlls to the execution directory - add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + # Add link libraries for libFuzzer + target_link_libraries(onnxruntime_proto_libfuzzer onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer -fsanitize=fuzzer,address ${PROTOBUF_LIB}) + + # Add the dependencies for libFuzzer + add_dependencies(onnxruntime_proto_libfuzzer onnx_proto onnxruntime protobuf-mutator 
protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + + # Copy shared libraries for libFuzzer + add_custom_command(TARGET onnxruntime_proto_libfuzzer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + # Add a second fuzzer that uses libFuzzer in fuzzer/libfuzzer + message(STATUS "Building libBufFuzzer-based fuzzer") + + # Set source files for the libFuzzer + set(LIBFUZZER_SRC "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/ort_libfuzzer/OrtLibfuzzer.cpp") + + # Compile the libFuzzer-based fuzzer + onnxruntime_add_executable(onnxruntime_libfuzzer_fuzz ${LIBFUZZER_SRC}) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_libfuzzer_fuzz onnx onnxruntime) + # Set include directories for libFuzzer + target_compile_definitions(onnxruntime_libfuzzer_fuzz PRIVATE GOOGLE_PROTOBUF_NO_LOGGING=1) + target_include_directories(onnxruntime_libfuzzer_fuzz PRIVATE ${INCLUDE_FILES}) + + # Add link libraries for libFuzzer + target_link_libraries(onnxruntime_libfuzzer_fuzz onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer -fsanitize=fuzzer,address ${PROTOBUF_LIB}) + + # Add the dependencies for libFuzzer + add_dependencies(onnxruntime_libfuzzer_fuzz onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + + # Copy shared libraries for libFuzzer + add_custom_command(TARGET onnxruntime_libfuzzer_fuzz POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + endif() endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index cf23416943c1f..e35c83ba45952 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -578,12 +578,12 @@ else() message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") else() message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index e559583fae8f5..2eb3611bae902 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -21,6 +21,10 @@ message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. 
Please, use latest OpenVINO release") endif() + if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) + add_definitions(-DUSE_OVEP_NPU_MEMORY=1) + endif() + if (WIN32) unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 58dd08f15f4e2..4b880c4437dfd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -877,6 +877,7 @@ AddTest( DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) +target_include_directories(onnxruntime_test_all PRIVATE ${ONNXRUNTIME_ROOT}/core/flatbuffers/schema) # ort.fbs.h if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. @@ -982,6 +983,9 @@ target_compile_definitions(onnx_test_data_proto PRIVATE "-DONNX_API=") onnxruntime_add_include_to_target(onnx_test_data_proto onnx_proto) target_include_directories(onnx_test_data_proto PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) set_target_properties(onnx_test_data_proto PROPERTIES FOLDER "ONNXRuntimeTest") +if(NOT DEFINED onnx_SOURCE_DIR) + find_path(onnx_SOURCE_DIR NAMES "onnx/onnx-ml.proto3" "onnx/onnx-ml.proto" REQUIRED) +endif() onnxruntime_protobuf_generate(APPEND_PATH IMPORT_DIRS ${onnx_SOURCE_DIR} TARGET onnx_test_data_proto) # diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 736fffb1e384c..3abf2d3afec42 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index dba9b4687..a4345898d 100755 +index 1ff85b538..c3ef2183f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -122,7 +122,7 @@ ENDIF() +@@ -253,7 +253,7 @@ ENDIF() # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") @@ -11,29 +11,27 @@ index dba9b4687..a4345898d 100755 MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() IF(CMAKE_SYSTEM_NAME MATCHES "Windows") -@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging) - TARGET_LINK_LIBRARIES(post-operation PRIVATE logging) - TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run) -- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) +@@ -763,7 +763,12 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(operator-run PRIVATE xnnpack-base logging) + TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging) + TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph) + IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") -+ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph) ++ # omit microkernels-prod as the list is manually 
created by ORT in cmake/external/xnnpack.cmake ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing subgraph) + ELSE() -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) -+ ENDIF() ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph) ++ ENDIF() + TARGET_LINK_LIBRARIES(XNNPACK PUBLIC xnnpack-base) SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() - IF(NOT MSVC) -@@ -543,8 +548,9 @@ ENDIF() +@@ -772,7 +777,8 @@ IF(NOT MSVC) + ENDIF() IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") - SET_PROPERTY(SOURCE ${PROD_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") - SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") -- SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") + # set this to armv7-a to workaround build issue. we don't target armv6 so it shouldn't matter -+ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") -+ SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") ++ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") SET_PROPERTY(SOURCE ${ALL_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") - SET_PROPERTY(SOURCE ${PROD_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") SET_PROPERTY(SOURCE ${ALL_NEONFP16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon-fp16 ") + # GCC requires -mfp16-format=ieee to define __fp16 type, but Clang doesn't support this option at all. 
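For readers skimming the XNNPACK patch above, the substantive change is one pattern: `microkernels-prod` is linked into `XNNPACK` only when the target system is not Emscripten, because ORT assembles that microkernel list itself in `cmake/external/xnnpack.cmake`. Below is a minimal CMake sketch of that conditional-link pattern using placeholder `demo_*` targets; it is illustrative only and does not reproduce XNNPACK's real targets or link visibility.

```cmake
# Sketch only: the Emscripten-conditional link pattern from the patch above,
# using placeholder INTERFACE targets so the snippet configures on its own.
cmake_minimum_required(VERSION 3.13)
project(demo_conditional_link LANGUAGES NONE)

add_library(demo_common INTERFACE)            # stands in for allocator/logging/operators/...
add_library(demo_microkernels_prod INTERFACE) # stands in for microkernels-prod
add_library(demo_xnnpack INTERFACE)           # stands in for XNNPACK

if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
  # Omit microkernels-prod: ORT builds that list manually in cmake/external/xnnpack.cmake.
  target_link_libraries(demo_xnnpack INTERFACE demo_common)
else()
  target_link_libraries(demo_xnnpack INTERFACE demo_common demo_microkernels_prod)
endif()
```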
diff --git a/cmake/vcpkg-configuration.json b/cmake/vcpkg-configuration.json new file mode 100644 index 0000000000000..f3525977c7bb9 --- /dev/null +++ b/cmake/vcpkg-configuration.json @@ -0,0 +1,8 @@ +{ + "default-registry": { + "kind": "git", + "repository": "https://github.com/Microsoft/vcpkg", + "baseline": "3508985146f1b1d248c67ead13f8f54be5b4f5da" + }, + "registries": [] +} diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json new file mode 100644 index 0000000000000..159b8654c1cb1 --- /dev/null +++ b/cmake/vcpkg.json @@ -0,0 +1,78 @@ +{ + "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", + "name": "onnxruntime", + "version-date": "2024-09-10", + "description": "ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator", + "homepage": "https://onnxruntime.ai/", + "license": "MIT", + "dependencies": [ + "abseil", + { + "name": "boost-config", + "version>=": "1.82.0" + }, + { + "name": "boost-mp11", + "version>=": "1.82.0" + }, + "cpuinfo", + "cxxopts", + "date", + "dlpack", + { + "name": "flatbuffers", + "host": true, + "version>=": "23.5.26" + }, + { + "name": "flatbuffers", + "version>=": "23.5.26" + }, + "ms-gsl", + "nlohmann-json", + { + "name": "nsync", + "platform": "!windows" + }, + { + "name": "nsync", + "platform": "!windows", + "version>=": "1.26.0" + }, + "optional-lite", + { + "name": "protobuf", + "version>=": "3.21.12" + }, + { + "name": "protobuf", + "host": true, + "version>=": "3.21.12" + }, + "re2", + "safeint", + "utf8-range", + { + "name": "vcpkg-cmake", + "host": true + }, + { + "name": "vcpkg-cmake-config", + "host": true + }, + "wil", + { + "name": "zlib", + "platform": "windows" + } + ], + "features": { + "tests": { + "description": "Build ONNXRuntime unit tests", + "dependencies": [ + "benchmark", + "gtest" + ] + } + } +} diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index aadf4ebe2f488..09a7e47fc9913 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2521,6 +2521,8 @@ This version of the operator has been available since version 1 of the 'com.micr Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. + Supports continuous decoding for batch_size == 1 for CPU and CUDA. + #### Version @@ -2561,9 +2563,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_value (optional) : T
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
seqlens_k : M
- 1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
+ 1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).
total_sequence_length : M
- Scalar tensor of total sequence length (past + new).
+ Scalar tensor equal to the maximum total sequence length (past + new) across the batch. Used for validating inputs and for distinguishing the prompt case from the token-generation case.
cos_cache (optional) : T
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache (optional) : T
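To make the revised seqlens_k and total_sequence_length descriptions above concrete, here is a small NumPy sketch of how a caller might derive both inputs from per-sequence totals (past + new tokens). It is illustrative only; the variable names are not taken from the ONNX Runtime sources, and the int32 dtype is an assumption matching the integer 'M' tensor type.

```python
import numpy as np

# Illustrative only: derive GroupQueryAttention's seqlens_k and total_sequence_length
# inputs from hypothetical per-sequence totals (past + new tokens), as described above.
total_lengths = np.array([5, 9, 7], dtype=np.int32)   # assumed past+new length per sequence

# seqlens_k: 1D tensor of shape (batch_size), equivalent to (total_sequence_lengths - 1).
seqlens_k = total_lengths - 1                          # -> [4, 8, 6]

# total_sequence_length: scalar tensor holding the batch maximum (past + new); used to
# validate inputs and to distinguish the prompt case from token generation.
total_sequence_length = np.array(total_lengths.max(), dtype=np.int32)  # -> 9

print(seqlens_k, total_sequence_length)
```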
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d57394b3e7b97..121240e6e18f9 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -488,7 +488,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 097873c5e3653..abab118efd04f 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -50,6 +50,8 @@ constexpr const char* HIP = "Hip"; constexpr const char* HIP_PINNED = "HipPinned"; constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; +constexpr const char* OpenVINO_RT = "OpenVINO_RT"; +constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; constexpr size_t kAllocAlignment = 256; diff --git a/include/onnxruntime/core/providers/acl/acl_provider_factory.h b/include/onnxruntime/core/providers/acl/acl_provider_factory.h index 0dc0ec27ff345..8875a83a39f54 100644 --- a/include/onnxruntime/core/providers/acl/acl_provider_factory.h +++ b/include/onnxruntime/core/providers/acl/acl_provider_factory.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "onnxruntime_c_api.h" @@ -10,7 +11,8 @@ extern "C" { /** * \param use_arena zero: false. non-zero: true. */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, + bool enable_fast_math) ORT_ALL_ARGS_NONNULL; #ifdef __cplusplus diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 26137e88478b5..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -7,6 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Objects; import java.util.logging.Logger; @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption return new OrtSession(this, modelPath, allocator, options); } + /** + * Create a session using the specified {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options) + throws OrtException { + return createSession(modelBuffer, defaultAllocator, options); + } + + /** + * Create a session using the default {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. 
+ */ + public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException { + return createSession(modelBuffer, new OrtSession.SessionOptions()); + } + + /** + * Create a session using the specified {@link SessionOptions} and model buffer. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param allocator The memory allocator to use. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + Objects.requireNonNull(modelBuffer, "model array must not be null"); + if (modelBuffer.remaining() == 0) { + throw new OrtException("Invalid model buffer, no elements remaining."); + } else if (!modelBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + return new OrtSession(this, modelBuffer, allocator, options); + } + /** * Create a session using the specified {@link SessionOptions}, model and the default memory * allocator. diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8ab4a1cb26bb1..6d146d5857d3c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ package ai.onnxruntime; @@ -10,6 +11,7 @@ import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -93,6 +95,31 @@ public class OrtSession implements AutoCloseable { allocator); } + /** + * Creates a session reading the model from the supplied byte buffer. + * + *

Must be a direct byte buffer. + * + * @param env The environment. + * @param modelBuffer The model protobuf as a byte buffer. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. + */ + OrtSession( + OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + this( + createSession( + OnnxRuntime.ortApiHandle, + env.getNativeHandle(), + modelBuffer, + modelBuffer.position(), + modelBuffer.remaining(), + options.getNativeHandle()), + allocator); + } + /** * Private constructor to build the Java object wrapped around a native session. * @@ -513,6 +540,15 @@ private static native long createSession( private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession( + long apiHandle, + long envHandle, + ByteBuffer modelBuffer, + int bufferPos, + int bufferSize, + long optsHandle) + throws OrtException; + private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) @@ -906,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } + /** + * Set whether to use deterministic compute. + * + *

Default is false. If set to true, this will enable deterministic compute for GPU kernels + * where possible. Note that this most likely will have a performance cost. + * + * @param value Should the compute be deterministic? + * @throws OrtException If there was an error in native code. + */ + public void setDeterministicCompute(boolean value) throws OrtException { + checkClosed(); + setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value); + } + /** * Disables the per session thread pools. Must be used in conjunction with an environment * containing global thread pools. @@ -1181,12 +1231,12 @@ public void addDirectML(int deviceId) throws OrtException { /** * Adds the ARM Compute Library as an execution backend. * - * @param useArena If true use the arena memory allocator. + * @param enableFastMath Enable fast math mode in ACL. * @throws OrtException If there was an error in native code. */ - public void addACL(boolean useArena) throws OrtException { + public void addACL(boolean enableFastMath) throws OrtException { checkClosed(); - addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); + addACL(OnnxRuntime.ortApiHandle, nativeHandle, enableFastMath); } /** @@ -1291,6 +1341,9 @@ private native void registerCustomOpsUsingFunction( private native void closeOptions(long apiHandle, long nativeHandle); + private native void setDeterministicCompute( + long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; + private native void addFreeDimensionOverrideByName( long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) throws OrtException; @@ -1354,7 +1407,8 @@ private native void addTvm(long apiHandle, long nativeHandle, String settings) private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) throws OrtException; - private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException; + private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) + throws OrtException; private native void addArmNN(long apiHandle, long nativeHandle, int useArena) throws OrtException; diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 5b2e9b2efac4c..4f3dee3c00b91 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -483,9 +483,12 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) { throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer"); } + // This buffer could be a ByteBuffer which is being used to carry data of another type, if so, + // it's type.size should be 1 to compute the correct buffer size and offset. + int elementSize = data instanceof ByteBuffer ? 
1 : type.size; int bufferPos; - long bufferSizeLong = data.remaining() * (long) type.size; - if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) { + long bufferSizeLong = data.remaining() * (long) elementSize; + if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) { // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending // on the JVM, so we check for something 8 elements below the maximum size which // should be allocatable (assuming there is enough memory) on all 64-bit JVMs. @@ -496,11 +499,11 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { + type); } // Now we know we're in range - int bufferSize = data.remaining() * type.size; + int bufferSize = data.remaining() * elementSize; Buffer tmp; if (data.isDirect()) { tmp = data; - bufferPos = data.position() * type.size; + bufferPos = data.position() * elementSize; } else { // Copy the data to a new direct buffer, then restore the state of the input. int origPosition = data.position(); diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index f4d5ab080cd31..ee8cdee659296 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la return (jlong)session; } +/* + * Class: ai_onnxruntime_OrtSession + * Method: createSession + * Signature: (JJLjava/nio/ByteBuffer;IIJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*)envHandle; + OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle; + OrtSession* session = NULL; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + + // Create the session + checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session)); + + return (jlong)session; +} + /* * Class: ai_onnxruntime_OrtSession * Method: createSession diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 337f4c1921c6e..ff6b7fa703e6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -1,5 +1,6 @@ /* * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. 
*/ #include @@ -258,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel)); } +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setDeterministicCompute + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic)); +} + /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: registerCustomOpLibrary @@ -644,12 +658,13 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDir * Signature: (JJI)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addACL - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint useArena) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jboolean enableFastMath) { (void)jobj; #ifdef USE_ACL - checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle,useArena)); + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle, + OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle, enableFastMath)); #else - (void)apiHandle;(void)handle;(void)useArena; // Parameters used when ACL is defined. + (void)apiHandle;(void)handle;(void)enableFastMath; // Parameters used when ACL is defined. 
throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with ACL support."); #endif } diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 3340a2e5e9f3a..11141a3a65a3e 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -20,10 +20,14 @@ import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; import java.io.File; import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.LongBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException { } } + @Test + public void createSessionFromByteBuffer() throws IOException, OrtException { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r"); + FileChannel channel = file.getChannel()) { + MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size()); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelBuffer, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); + } + } + } + } + @Test public void createSessionFromByteArray() throws IOException, OrtException { Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); @@ -1232,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException { options.setLoggerId("monkeys"); options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); options.setSessionLogVerbosityLevel(5); + options.setDeterministicCompute(true); Map configEntries = options.getConfigEntries(); assertTrue(configEntries.isEmpty()); options.addConfigEntry("key", "value"); diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index c060cf73ecf14..ea210d96c1507 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. 
*/ package ai.onnxruntime; @@ -218,6 +218,30 @@ public void testUint8Creation() throws OrtException { } } + @Test + public void testByteBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder()); + FloatBuffer floatBuf = byteBuf.asFloatBuffer(); + floatBuf.put(1.0f); + floatBuf.put(2.0f); + floatBuf.put(3.0f); + floatBuf.put(4.0f); + floatBuf.put(5.0f); + floatBuf.position(1); + float[] expected = new float[floatBuf.remaining()]; + floatBuf.get(expected); + floatBuf.position(1); + byteBuf.position(4); + try (OnnxTensor t = + OnnxTensor.createTensor( + env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) { + Assertions.assertNotNull(t); + float[] actual = (float[]) t.getValue(); + Assertions.assertArrayEquals(expected, actual); + } + } + @Test public void testEmptyTensor() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 4e0ef821dde57..342f5e3a467eb 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -51,13 +51,16 @@ export class Tensor implements TensorInterface { */ constructor( type: TensorType, - data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[], + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[], dims?: readonly number[], ); /** * Construct a new CPU tensor object from the given data and dims. Type is inferred from data. */ - constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]); + constructor( + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[], + dims?: readonly number[], + ); /** * Construct a new tensor object from the pinned CPU data with the given type and dims. 
* @@ -90,12 +93,13 @@ export class Tensor implements TensorInterface { arg0: | TensorType | TensorDataType + | Uint8ClampedArray | readonly string[] | readonly boolean[] | CpuPinnedConstructorParameters | TextureConstructorParameters | GpuBufferConstructorParameters, - arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], + arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { // perform one-time check for BigInt/Float16Array support @@ -216,6 +220,12 @@ export class Tensor implements TensorInterface { } } else if (arg1 instanceof typedArrayConstructor) { data = arg1; + } else if (arg1 instanceof Uint8ClampedArray) { + if (arg0 === 'uint8') { + data = Uint8Array.from(arg1); + } else { + throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); + } } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } @@ -243,6 +253,9 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`Invalid element type of data array: ${firstElementType}.`); } + } else if (arg0 instanceof Uint8ClampedArray) { + type = 'uint8'; + data = Uint8Array.from(arg0); } else { // get tensor type from TypedArray const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get( diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 70396bbe1e9a3..8a1197994393b 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -192,6 +192,15 @@ export interface TensorConstructor extends TensorFactory { dims?: readonly number[], ): TypedTensor<'bool'>; + /** + * Construct a new uint8 tensor object from a Uint8ClampedArray, data and dims. + * + * @param type - Specify the element type. + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (type: 'uint8', data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new 64-bit integer typed tensor object from the given type, data and dims. * @@ -245,6 +254,14 @@ export interface TensorConstructor extends TensorFactory { */ new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; + /** + * Construct a new uint8 tensor object from the given data and dims. + * + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new uint16 tensor object from the given data and dims. * diff --git a/js/common/test/type-tests/tensor/create-new-uint8.ts b/js/common/test/type-tests/tensor/create-new-uint8.ts new file mode 100644 index 0000000000000..46438f97ca2e7 --- /dev/null +++ b/js/common/test/type-tests/tensor/create-new-uint8.ts @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import * as ort from 'onnxruntime-common'; + +// construct from Uint8Array +// +// {type-tests}|pass +new ort.Tensor(new Uint8Array(1)); + +// construct from Uint8ClampedArray +// +// {type-tests}|pass +new ort.Tensor(new Uint8ClampedArray(1)); + +// construct from type (bool), data (Uint8ClampedArray) and shape (number array) +// +// {type-tests}|fail|1|2769 +new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index def711684d7f5..02390800e8611 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => { assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); + it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => { + const uint8ClampedArray = new Uint8ClampedArray(2); + uint8ClampedArray[0] = 0; + uint8ClampedArray[1] = 256; // clamped + const tensor = new Tensor('uint8', uint8ClampedArray, [2]); + assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'"); + }); + it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => { const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]); assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); diff --git a/js/react_native/e2e/.detoxrc.js b/js/react_native/e2e/.detoxrc.js index e24833a1d09c9..0792c3d528585 100644 --- a/js/react_native/e2e/.detoxrc.js +++ b/js/react_native/e2e/.detoxrc.js @@ -38,7 +38,8 @@ module.exports = { simulator: { type: 'ios.simulator', device: { - type: 'iPhone 13', + type: 'iPhone 14', + os: 'iOS 16.4', }, }, attached: { diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj index f2e4a791eae57..b7c38547d67e1 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj +++ b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 46; + objectVersion = 54; objects = { /* Begin PBXBuildFile section */ @@ -322,6 +322,7 @@ DEVELOPMENT_TEAM = ""; ENABLE_BITCODE = NO; INFOPLIST_FILE = OnnxruntimeModuleExample/Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_LDFLAGS = ( "$(inherited)", @@ -344,6 +345,7 @@ CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_TEAM = ""; INFOPLIST_FILE = OnnxruntimeModuleExample/Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_LDFLAGS = ( "$(inherited)", @@ -405,7 +407,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; "HEADER_SEARCH_PATHS[arch=*]" = ""; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "/usr/lib/swift $(inherited)"; LIBRARY_SEARCH_PATHS = ( "\"$(TOOLCHAIN_DIR)/usr/lib/swift/$(PLATFORM_NAME)\"", @@ -458,7 +460,7 @@ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "/usr/lib/swift $(inherited)"; LIBRARY_SEARCH_PATHS = ( "\"$(TOOLCHAIN_DIR)/usr/lib/swift/$(PLATFORM_NAME)\"", @@ 
-486,7 +488,7 @@ DEBUG_INFORMATION_FORMAT = dwarf; GCC_C_LANGUAGE_STANDARD = gnu11; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 15.2; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; MARKETING_VERSION = 1.0; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; @@ -514,7 +516,7 @@ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; GCC_C_LANGUAGE_STANDARD = gnu11; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 15.2; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; MARKETING_VERSION = 1.0; MTL_FAST_MATH = YES; PRODUCT_BUNDLE_IDENTIFIER = ai.onnxruntime.reactnative.OnnxruntimeModuleExampleUITests; diff --git a/js/react_native/e2e/ios/Podfile b/js/react_native/e2e/ios/Podfile index d31a6f50221fb..0272c23092838 100644 --- a/js/react_native/e2e/ios/Podfile +++ b/js/react_native/e2e/ios/Podfile @@ -3,6 +3,15 @@ require_relative '../node_modules/@react-native-community/cli-platform-ios/nativ platform :ios, '13.0' +pre_install do |installer| + # Custom pre-install script or commands + puts "Running pre-install script..." + + # Recommended fix for https://github.com/facebook/react-native/issues/32483 + # from https://github.com/facebook/react-native/issues/32483#issuecomment-966784501 + system("sed -i '' 's/typedef uint8_t clockid_t;//' \"./Pods/RCT-Folly/folly/portability/Time.h\"") +end + target 'OnnxruntimeModuleExample' do config = use_native_modules! @@ -19,3 +28,12 @@ target 'OnnxruntimeModuleExample' do inherit! :search_paths end +post_install do |installer| + installer.generated_projects.each do |project| + project.targets.each do |target| + target.build_configurations.each do |config| + config.build_settings['IPHONEOS_DEPLOYMENT_TARGET'] = '13.0' + end + end + end +end \ No newline at end of file diff --git a/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj b/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj index 2a093b2b89c95..f11bc73687098 100644 --- a/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj +++ b/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 46; + objectVersion = 54; objects = { /* Begin PBXBuildFile section */ @@ -189,7 +189,8 @@ 58B511D31A9E6C8500147676 /* Project object */ = { isa = PBXProject; attributes = { - LastUpgradeCheck = 0920; + BuildIndependentTargetsInParallel = YES; + LastUpgradeCheck = 1540; ORGANIZATIONNAME = Facebook; TargetAttributes = { 58B511DA1A9E6C8500147676 = { @@ -420,7 +421,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; "HEADER_SEARCH_PATHS[arch=*]" = ""; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ""; MTL_ENABLE_DEBUG_INFO = YES; ONLY_ACTIVE_ARCH = YES; @@ -465,7 +466,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; "HEADER_SEARCH_PATHS[arch=*]" = ""; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ""; MTL_ENABLE_DEBUG_INFO = NO; SDKROOT = iphoneos; @@ -484,6 +485,7 @@ "$(SRCROOT)/../../react-native/React/**", ); "HEADER_SEARCH_PATHS[arch=*]" = "\"$(PODS_ROOT)/onnxruntime/onnxruntime.framework/Headers\""; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ( "$(inherited)", "$(PROJECT_DIR)", @@ -508,6 +510,7 @@ "$(SRCROOT)/../../react-native/React/**", ); "HEADER_SEARCH_PATHS[arch=*]" = "\"$(PODS_ROOT)/onnxruntime/onnxruntime.framework/Headers\""; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ( "$(inherited)", "$(PROJECT_DIR)", @@ -582,7 +585,7 @@ 
"$(PROJECT_DIR)", ); INFOPLIST_FILE = OnnxruntimeModuleTest/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks @loader_path/Frameworks"; LIBRARY_SEARCH_PATHS = "$(inherited)"; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; @@ -656,7 +659,7 @@ "$(PROJECT_DIR)", ); INFOPLIST_FILE = OnnxruntimeModuleTest/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks @loader_path/Frameworks"; LIBRARY_SEARCH_PATHS = "$(inherited)"; MTL_FAST_MATH = YES; diff --git a/js/react_native/ios/Podfile b/js/react_native/ios/Podfile index 53c83948672a3..ad8a5ce3b2d5f 100644 --- a/js/react_native/ios/Podfile +++ b/js/react_native/ios/Podfile @@ -3,6 +3,15 @@ require_relative '../node_modules/@react-native-community/cli-platform-ios/nativ platform :ios, '13.0' +pre_install do |installer| + # Custom pre-install script or commands + puts "Running pre-install script..." + + # Recommended fix for https://github.com/facebook/react-native/issues/32483 + # from https://github.com/facebook/react-native/issues/32483#issuecomment-966784501 + system("sed -i '' 's/typedef uint8_t clockid_t;//' \"./Pods/RCT-Folly/folly/portability/Time.h\"") +end + def shared config = use_native_modules! @@ -29,3 +38,13 @@ end target 'OnnxruntimeModuleTest' do shared end + +post_install do |installer| + installer.generated_projects.each do |project| + project.targets.each do |target| + target.build_configurations.each do |config| + config.build_settings['IPHONEOS_DEPLOYMENT_TARGET'] = '13.0' + end + end + end +end \ No newline at end of file diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 48b06b780dfc7..6fd4f9af20432 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -25,6 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | +| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | | | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | @@ -41,6 +42,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | GlobalLpPool| ai.onnx(7+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 'p' value is 2 | | Greater | ai.onnx(7-8, 9-12, 13+) | greater | ✓ | ✓ | | | GreaterOrEqual | ai.onnx(12-15, 16+) | greaterOrEqual | ✓ | ✓ | | +| GRU | ai.onnx(7-13, 14-21, 22+) | gru | ✓ | ✓ | Only supports 'layout' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 
'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | HardSigmoid | ai.onnx(7+) | hardSigmoid | ✓ | ✓ | | | HardSwish | ai.onnx(14+) | hardSwish | ✓ | ✓ | | | Identity | ai.onnx(7-13, 14-15, 16-18, 19-20, 21+) | identity | ✓ | ✓ | | @@ -61,6 +63,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported | | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | | | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) | +| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | ✗ | ✓ | | | Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | | | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant | diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts deleted file mode 100644 index 7dfe7ee05a1d3..0000000000000 --- a/js/web/lib/backend-wasm-inference.ts +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts deleted file mode 100644 index 7332b3f97eba0..0000000000000 --- a/js/web/lib/backend-wasm-training.ts +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common'; - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training'; - -class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); - await handler.createTrainingSession( - checkpointStateUriOrBuffer, - trainModelUriOrBuffer, - evalModelUriOrBuffer, - optimizerModelUriOrBuffer, - options, - ); - return Promise.resolve(handler); - } -} - -export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 7bef538b26063..766937dc4c4cf 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } + +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 321394466b365..776c0d026bc97 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.DISABLE_TRAINING - ? 
require('./backend-wasm-inference').wasmBackend - : require('./backend-wasm-training').wasmBackend; + const wasmBackend = require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { registerBackend('webgpu', wasmBackend, 5); registerBackend('webnn', wasmBackend, 5); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index ece2e1b7c7dcd..236f1b09a6c93 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -44,10 +44,8 @@ const calculateOutputShapeAndPads = ( ) => { const spatialRank = inputShape.length - 2; const updateOutputShape = outputShape.length === 0; - if (outputPadding.length === 0) { - for (let i = 0; i < spatialRank; ++i) { - outputPadding.push(0); - } + if (outputPadding.length < spatialRank) { + outputPadding.push(...Array(spatialRank - outputPadding.length).fill(0)); } const batchSize = inputShape[0]; const outChannels = kernelShape[isChannelLast ? 3 : 1] * group; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index fe37163a0cd08..de9f7bc8885ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -103,7 +103,10 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); - // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims + // if kernelShape is not well specified in the attributes, infer it from the weight tensor dims + if (kernelShape.length < inputs[1].dims.length - 2) { + kernelShape.push(...Array(inputs[1].dims.length - 2 - kernelShape.length).fill(0)); + } for (let i = 2; i < inputs[1].dims.length; ++i) { if (kernelShape[i - 2] === 0) { kernelShape[i - 2] = inputs[1].dims[i]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index ee877f8f0c3f2..1fd99d085e0ed 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -26,14 +26,12 @@ const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm)); const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, output: IndicesHelper): string => { - const reverseFunc = []; - reverseFunc.push(`fn perm(i: ${output.type.indices}) -> ${input.type.indices} { - var a: ${input.type.indices};`); + let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} { + var a: ${input.type.indices};`; for (let i = 0; i < rank; ++i) { - reverseFunc.push(input.indicesSet('a', perm[i], `i[${i}]`)); + reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`); } - reverseFunc.push('return a;}'); - return reverseFunc.join('\n'); + return (reverseFunc += 'return a;}'); }; const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => { diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts deleted file mode 100644 index 8bbfb9cf06668..0000000000000 --- a/js/web/lib/wasm/session-handler-training.ts +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference'; -import { copyFromExternalBuffer } from './wasm-core-impl'; -import { - createCheckpointHandle, - createTrainingSessionHandle, - getContiguousParameters, - getModelInputOutputNames, - getParametersSize, - lazyResetGrad, - loadParametersBuffer, - releaseTrainingSessionAndCheckpoint, - runEvalStep, - runOptimizerStep, - runTrainStep, -} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - evalInputNames: string[] = []; - evalOutputNames: string[] = []; - - async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return copyFromExternalBuffer(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ) { - const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableInternalBuffer = [0, 0]; - let optimizerModelData: SerializableInternalBuffer = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - this.sessionId = createTrainingSessionHandle( - this.checkpointId, - trainModelData, - evalModelData, - optimizerModelData, - options, - ); - [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); - if (evalModelUriOrBuffer !== '') { - [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); - } - } - - /** - * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the - * corresponding name as a number referring to the index in the list of names provided. - * - * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType - * @param names either inputNames or outputNames - * @returns a tuple of a list of values and a list of indices. 
- */ - convertMapIntoValuesArrayAndIndicesArray( - feeds: { [name: string]: T }, - names: string[], - mapFunc: (val: T, index: number) => U, - ): [T[], number[], U[]] { - const values: T[] = []; - const indices: number[] = []; - Object.entries(feeds).forEach((kvp) => { - const name = kvp[0]; - const tensor = kvp[1]; - const index = names.indexOf(name); - if (index === -1) { - throw new Error(`invalid input '${name}`); - } - values.push(tensor); - indices.push(index); - }); - - const uList = values.map(mapFunc); - return [values, indices, uList]; - } - - /** - * Helper method that converts the TensorMetadata that the wasm-core functions return to the - * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the - * corresponding result. - * - * @param results used to populate the resultMap if there is no value for that outputName already - * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results - * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. - * @returns a map of output names and OnnxValues. - */ - convertTensorMetadataToReturnType( - results: TensorMetadata[], - outputArray: Array, - outputIndices: number[], - ): SessionHandler.ReturnType { - const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results.length; i++) { - resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); - } - return resultMap; - } - - async lazyResetGrad(): Promise { - await lazyResetGrad(this.sessionId); - } - - async runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.inputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.outputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async runOptimizerStep(options: InferenceSession.RunOptions): Promise { - await runOptimizerStep(this.sessionId, options); - } - - async runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.evalInputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null => - t ? 
encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async getParametersSize(trainableOnly: boolean): Promise { - return getParametersSize(this.sessionId, trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { - await loadParametersBuffer(this.sessionId, array, trainableOnly); - } - async getContiguousParameters(trainableOnly: boolean): Promise { - const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); - return decodeTensorMetadata(tensorResult); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); - } -} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..ed001cfa90f59 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file'; * Refer to web/lib/index.ts for the backend registration. * * 2. WebAssembly artifact initialization. - * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or - * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings: + * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is + * called). In this step, onnxruntime-web does the followings: * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated * JavaScript code to initialize the WebAssembly runtime. @@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file'; * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. * * 4. Session initialization. - * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 - * steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the - * followings: + * This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once), + * this step will be done for each session. In this step, onnxruntime-web does the followings: * If the parameter is a URL: * - download the model data from the URL. * - copy the model data to the WASM heap. (proxy: 'copy-from') diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts deleted file mode 100644 index 22cd6ec30732c..0000000000000 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, Tensor } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { setRunOptions } from './run-options'; -import { setSessionOptions } from './session-options'; -import { - dataLocationStringToEnum, - tensorDataTypeEnumToString, - tensorDataTypeStringToEnum, - tensorTypeToTypedArrayConstructor, -} from './wasm-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; -import { getInstance } from './wasm-factory'; -import { checkLastError } from './wasm-utils'; - -const NO_TRAIN_FUNCS_MSG = - "Built without training API's enabled. Use the onnxruntime-web/training import for training " + - 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + - 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; - -/** - * Runs the checkLastError function which will throw an error, if the provided error code matches the specified - * pattern for an error code. - * @param errCode number to evaluated for if it's an error - * @param message message to pass into checkLastError - * @param checkNeqZero when true, treats not equal to zero as an error. - * When false, treats equal to zero as an error. - */ -const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { - if (checkNeqZero && errCode !== 0) { - checkLastError(message); - } else if (!checkNeqZero && errCode === 0) { - checkLastError(message); - } -}; - -export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { - const wasm = getInstance(); - - const [checkpointDataOffset, checkpointDataLength] = checkpointData; - let checkpointHandle = 0; - - try { - if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); - return checkpointHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { - wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); - } - throw e; - } finally { - // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); - } -}; - -const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetModelInputOutputCount) { - const errorCode = wasm._OrtTrainingGetModelInputOutputCount( - trainingSessionId, - dataOffset, - dataOffset + 4, - isEvalModel, - ); - ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getModelInputOutputNamesLoop = ( - trainingSessionId: number, - count: number, - isInput: boolean, - isEvalModel: boolean, -): string[] => { - const names = []; - const wasm = getInstance(); - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetModelInputOutputName) { - const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - - 
names.push(wasm.UTF8ToString(name)); - wasm._free(name); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return names; -}; - -export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { - let inputNames: string[] = []; - let outputNames: string[] = []; - - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - - inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); - outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - - return [inputNames, outputNames]; -}; - -export const createTrainingSessionHandle = ( - checkpointHandle: number, - trainModelData: SerializableInternalBuffer, - evalModelData: SerializableInternalBuffer, - optimizerModelData: SerializableInternalBuffer, - options: InferenceSession.SessionOptions, -): number => { - const wasm = getInstance(); - - let trainingSessionHandle = 0; - let sessionOptionsHandle = 0; - let allocs: number[] = []; - - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); - if (wasm._OrtTrainingCreateSession) { - trainingSessionHandle = wasm._OrtTrainingCreateSession( - sessionOptionsHandle, - checkpointHandle, - trainModelData[0], - trainModelData[1], - evalModelData[0], - evalModelData[1], - optimizerModelData[0], - optimizerModelData[1], - ); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - return trainingSessionHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { - wasm._OrtTrainingReleaseSession(trainingSessionHandle); - } - throw e; - } finally { - wasm._free(trainModelData[0]); - wasm._free(evalModelData[0]); - wasm._free(optimizerModelData[0]); - - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach((alloc) => wasm._free(alloc)); - } -}; - -/** - * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the - * WASM tensors. - * - * @param trainingSessionId - * @param indices for each tensor, the index of the input or output name that the tensor corresponds with - * @param tensors list of TensorMetaData - * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting - * handles of the allocated tensors on the heap - * @param inputOutputAllocs modified in-place by this method - * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor - */ -const createAndAllocateTensors = ( - trainingSessionId: number, - indices: number[], - tensors: Array, - tensorHandles: number[], - inputOutputAllocs: number[], - indexAdd: number, -) => { - const count = indices.length; - - // creates the tensors - for (let i = 0; i < count; i++) { - prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); - } - - // moves to heap - const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; - for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; - } - - return valuesOffset; -}; - -/** - * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information - * associated with the tensor handle. 
- * - * @param outputValuesOffset - * @param outputCount - * @returns list of TensorMetadata retrieved from the output handles. - */ -const moveOutputToTensorMetadataArr = ( - outputValuesOffset: number, - outputCount: number, - outputTensorHandles: number[], - outputTensors: Array, -) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - if (tensor === outputTensorHandles[i]) { - // output tensor is pre-allocated. no need to copy data. - output.push(outputTensors[i]!); - continue; - } - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type | undefined, - dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, - tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, - ); - ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), - ); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); - } - } - - return output; -}; - -export const lazyResetGrad = async (trainingSessionId: number): Promise => { - const wasm = getInstance(); - - if (wasm._OrtTrainingLazyResetGrad) { - const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); - ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } -}; - -export const runTrainStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't 
want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingRunTrainStep) { - const errorCode = wasm._OrtTrainingRunTrainStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runOptimizerStep = async ( - trainingSessionId: number, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - try { - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - if (wasm._OrtTrainingOptimizerStep) { - const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); - ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runEvalStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingEvalStep) { - const errorCode = wasm._OrtTrainingEvalStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - - 
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - try { - const sizeOffset = wasm.stackAlloc(4); - if (wasm._OrtTrainingGetParametersSize) { - const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); - ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); - - return wasm.HEAP32[sizeOffset / 4]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -export const getContiguousParameters = async ( - trainingSessionId: number, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - const parametersSize = getParametersSize(trainingSessionId, trainableOnly); - let tensor = 0; - - // allocates a buffer of the correct size on the WASM heap - const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm._malloc(paramsByteLength); - - // handles the dimensions-related createTensor parameters - const dims = [parametersSize]; - - const dimsOffset = wasm.stackAlloc(4); - const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; - - try { - // wraps allocated array in a tensor - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - paramsOffset, - paramsByteLength, - dimsOffset, - dims.length, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError( - tensor, - `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, - false, - ); - - if (wasm._OrtTrainingCopyParametersToBuffer) { - const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); - const data = new typedArrayConstructor(parametersSize); - const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength), - ); - output.push([tensorTypeAsString, dims, data, locationAsString]); - if (output.length !== 1) { - throw new Error(`something unexpected happened in the getContiguousParameters function. 
Expected output length of - one, got ${output.length}`); - } else { - return output[0]; - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm._free(paramsOffset); - wasm._free(dimsOffset); - wasm.stackRestore(stack); - } -}; - -export const loadParametersBuffer = async ( - trainingSessionId: number, - buffer: Uint8Array, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - // allocates & copies JavaScript buffer to WASM heap - const bufferByteLength = buffer.length; - const bufferCount = bufferByteLength / 4; - const bufferOffset = wasm._malloc(bufferByteLength); - wasm.HEAPU8.set(buffer, bufferOffset); - - // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; - const dimsLength = 1; - let tensor = 0; - - try { - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - bufferOffset, - bufferByteLength, - dimsOffset, - dimsLength, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); - - if (wasm._OrtTrainingCopyParametersFromBuffer) { - const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm.stackRestore(stack); - wasm._free(bufferOffset); - wasm._free(dimsOffset); - } -}; - -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { - const wasm = getInstance(); - - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } -}; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..828cd3cfd94fa 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -213,84 +213,10 @@ export interface OrtInferenceAPIs { _OrtEndProfiling(sessionHandle: number): number; } -export interface OrtTrainingAPIs { - _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - - _OrtTrainingCreateSession( - sessionOptionsHandle: number, - checkpointHandle: number, - trainOffset: number, - trainLength: number, - evalOffset: number, - evalLength: number, - optimizerOffset: number, - optimizerLength: number, - ): number; - - _OrtTrainingLazyResetGrad(trainingHandle: number): number; - - _OrtTrainingRunTrainStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - - _OrtTrainingEvalStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - - _OrtTrainingCopyParametersToBuffer( - trainingHandle: number, - parametersBuffer: number, 
- parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingCopyParametersFromBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingGetModelInputOutputCount( - trainingHandle: number, - inputCount: number, - outputCount: number, - isEvalModel: boolean, - ): number; - _OrtTrainingGetModelInputOutputName( - trainingHandle: number, - index: number, - isInput: boolean, - isEvalModel: boolean, - ): number; - - _OrtTrainingReleaseSession(trainingHandle: number): void; -} - /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule - extends EmscriptenModule, - OrtInferenceAPIs, - Partial, - Partial { +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 008b9b41b1592..bd9e0ce083ef0 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT ? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires require( - !BUILD_DEFS.DISABLE_TRAINING - ? '../../dist/ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + !BUILD_DEFS.DISABLE_JSEP + ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -163,11 +161,9 @@ export const importWasmModule = async ( if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) { return [undefined, embeddedWasmModule!]; } else { - const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING - ? 'ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? 'ort-wasm-simd-threaded.jsep.mjs' - : 'ort-wasm-simd-threaded.mjs'; + const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP + ? 'ort-wasm-simd-threaded.jsep.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. 
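Note: with the training-specific `ort-training-wasm-simd-threaded` variant removed above, artifact selection reduces to the JSEP/non-JSEP pair, and WebAssembly initialization is now triggered only by `ort.InferenceSession.create()`. A minimal usage sketch of that remaining path, mirroring the inference test kept in the package (the model path, feed names `a`/`b`, and shapes are assumptions for illustration):

```ts
import * as ort from 'onnxruntime-web';

async function runOnce(): Promise<void> {
  // The first create() call performs backend registration, wasm artifact
  // loading and runtime initialization; later calls reuse the initialized runtime.
  const session = await ort.InferenceSession.create('model.onnx', {
    executionProviders: ['wasm'],
  });

  // Hypothetical inputs; names and shapes must match the model being loaded.
  const a = new ort.Tensor('float32', Float32Array.from({ length: 12 }, (_, i) => i + 1), [3, 4]);
  const b = new ort.Tensor('float32', Float32Array.from({ length: 12 }, (_, i) => (i + 1) * 10), [4, 3]);

  // run() resolves to a map of output name -> Tensor.
  const results = await session.run({ a, b });
  console.log(results.c.dims, results.c.data);
}

runOnce().catch((e) => console.error(e));
```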
diff --git a/js/web/package-lock.json b/js/web/package-lock.json index d37cf6bd90887..9db48f74a94a4 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -456,9 +456,9 @@ } }, "node_modules/body-parser": { - "version": "1.20.2", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", - "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", + "version": "1.20.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, "dependencies": { "bytes": "3.1.2", @@ -469,7 +469,7 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "on-finished": "2.4.1", - "qs": "6.11.0", + "qs": "6.13.0", "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" @@ -603,13 +603,19 @@ } }, "node_modules/call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -959,6 +965,23 @@ "node": ">=10" } }, + "node_modules/define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/define-properties": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.0.tgz", @@ -1137,6 +1160,27 @@ "node": ">=6" } }, + "node_modules/es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.2.4" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/es6-error": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", @@ -1447,14 +1491,19 @@ } }, "node_modules/get-intrinsic": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.0.tgz", - "integrity": "sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q==", + "version": "1.2.4", + 
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "has": "^1.0.3", - "has-symbols": "^1.0.3" + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "has-proto": "^1.0.1", + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1598,6 +1647,18 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.1.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/got": { "version": "11.8.6", "resolved": "https://registry.npmjs.org/got/-/got-11.8.6.tgz", @@ -1640,18 +1701,6 @@ "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==" }, - "node_modules/has": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", - "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", - "dev": true, - "dependencies": { - "function-bind": "^1.1.1" - }, - "engines": { - "node": ">= 0.4.0" - } - }, "node_modules/has-flag": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", @@ -1662,13 +1711,24 @@ } }, "node_modules/has-property-descriptors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.0.tgz", - "integrity": "sha512-62DVLZGoiEBDHQyqG4w9xCuZ7eJEwNmJRWw2VY84Oedb7WFcA27fiEVe8oUQx9hAUJ4ekurquucTGwsyO1XGdQ==", + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", "dev": true, - "optional": true, "dependencies": { - "get-intrinsic": "^1.1.1" + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-proto": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", + "dev": true, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1686,6 +1746,18 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/http-cache-semantics": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", @@ -2502,10 +2574,13 @@ } }, "node_modules/object-inspect": { - "version": "1.12.3", - "resolved": 
"https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", + "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", "dev": true, + "engines": { + "node": ">= 0.4" + }, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -2717,12 +2792,12 @@ } }, "node_modules/qs": { - "version": "6.11.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", - "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "version": "6.13.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, "dependencies": { - "side-channel": "^1.0.4" + "side-channel": "^1.0.6" }, "engines": { "node": ">=0.6" @@ -2967,6 +3042,23 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -2995,14 +3087,18 @@ } }, "node_modules/side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "dependencies": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -3876,9 +3972,9 @@ "dev": true }, "body-parser": { - "version": "1.20.2", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", - "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", + "version": "1.20.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, "requires": { "bytes": "3.1.2", @@ -3889,7 +3985,7 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "on-finished": "2.4.1", - "qs": "6.11.0", + "qs": "6.13.0", "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" @@ -4005,13 +4101,16 @@ } }, "call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": 
"sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" } }, "chai": { @@ -4286,6 +4385,17 @@ "integrity": "sha512-4tvttepXG1VaYGrRibk5EwJd1t4udunSOVMdLSAL6mId1ix438oPwPZMALY41FCijukO1L0twNcGsdzS7dHgDg==", "dev": true }, + "define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "requires": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + } + }, "define-properties": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.0.tgz", @@ -4429,6 +4539,21 @@ "integrity": "sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==", "dev": true }, + "es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "requires": { + "get-intrinsic": "^1.2.4" + } + }, + "es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true + }, "es6-error": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", @@ -4678,14 +4803,16 @@ "dev": true }, "get-intrinsic": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.0.tgz", - "integrity": "sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "has": "^1.0.3", - "has-symbols": "^1.0.3" + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "has-proto": "^1.0.1", + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" } }, "get-stream": { @@ -4791,6 +4918,15 @@ "slash": "^4.0.0" } }, + "gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "requires": { + "get-intrinsic": "^1.1.3" + } + }, "got": { "version": "11.8.6", "resolved": "https://registry.npmjs.org/got/-/got-11.8.6.tgz", @@ -4827,15 +4963,6 @@ "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==" }, - "has": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", - 
"integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", - "dev": true, - "requires": { - "function-bind": "^1.1.1" - } - }, "has-flag": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", @@ -4843,21 +4970,35 @@ "dev": true }, "has-property-descriptors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.0.tgz", - "integrity": "sha512-62DVLZGoiEBDHQyqG4w9xCuZ7eJEwNmJRWw2VY84Oedb7WFcA27fiEVe8oUQx9hAUJ4ekurquucTGwsyO1XGdQ==", + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", "dev": true, - "optional": true, "requires": { - "get-intrinsic": "^1.1.1" + "es-define-property": "^1.0.0" } }, + "has-proto": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", + "dev": true + }, "has-symbols": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", "dev": true }, + "hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "requires": { + "function-bind": "^1.1.2" + } + }, "http-cache-semantics": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", @@ -5505,9 +5646,9 @@ "dev": true }, "object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", + "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", "dev": true }, "object-keys": { @@ -5666,12 +5807,12 @@ "dev": true }, "qs": { - "version": "6.11.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", - "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "version": "6.13.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, "requires": { - "side-channel": "^1.0.4" + "side-channel": "^1.0.6" } }, "queue-microtask": { @@ -5832,6 +5973,20 @@ } } }, + "set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "requires": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + } + }, "setprototypeof": { "version": "1.2.0", "resolved": 
"https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -5854,14 +6009,15 @@ "dev": true }, "side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "requires": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" } }, "signal-exit": { diff --git a/js/web/package.json b/js/web/package.json index 94dd047915b05..d770499adada4 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -23,7 +23,6 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", - "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", @@ -101,12 +100,6 @@ "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" - }, - "./training": { - "node": null, - "import": "./dist/ort.training.wasm.min.mjs", - "require": "./dist/ort.training.wasm.min.js", - "types": "./types.d.ts" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6d1b3bdb65068..408f9e00a5cbd 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -56,7 +56,6 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_JSEP': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', - 'BUILD_DEFS.DISABLE_TRAINING': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false', 'BUILD_DEFS.IS_ESM': 'false', @@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort[-training]-wasm-simd-threaded[.jsep].mjs + * - ort-wasm-simd-threaded[.jsep].mjs */ async function buildOrt({ isProduction = false, @@ -630,16 +629,6 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, }); - // ort.training.wasm[.min].[m]js - await addAllWebBuildTasks({ - outputName: 'ort.training.wasm', - define: { - ...DEFAULT_DEFINE, - 'BUILD_DEFS.DISABLE_TRAINING': 'false', - 'BUILD_DEFS.DISABLE_JSEP': 'true', - 'BUILD_DEFS.DISABLE_WEBGL': 'true', - }, - }); } if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index b1b2fa26b2351..5b8b0d27c88db 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -149,11 +149,9 @@ downloadJson( void jszip.loadAsync(buffer).then((zip) => { extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', 
folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); }); }); }, diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 7f0c1cc3e420c..5c1e2e27a6eff 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1812,9 +1812,9 @@ // // "test_gridsample_zeros_padding", // // "test_gridsample", // // "test_gru_batchwise", - // // "test_gru_defaults", - // // "test_gru_seq_length", - // // "test_gru_with_initial_bias", + "test_gru_defaults", + "test_gru_seq_length", + "test_gru_with_initial_bias", // // "test_hammingwindow_expanded", // // "test_hammingwindow_symmetric_expanded", // // "test_hammingwindow_symmetric", diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js deleted file mode 100644 index 05750ed149303..0000000000000 --- a/js/web/test/training/e2e/browser-test-wasm.js +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -describe('Browser E2E testing for training package', function () { - it('Check that training package encompasses inference', async function () { - ort.env.wasm.numThreads = 1; - await testInferenceFunction(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, all options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, minimum options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] }); - }); -}); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js deleted file mode 100644 index 0574ae85aabd1..0000000000000 --- a/js/web/test/training/e2e/common.js +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -'use strict'; - -const DATA_FOLDER = 'data/'; -const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; -const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; -const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; -const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; - -const trainingSessionAllOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, - evalModel: TRAININGDATA_EVAL_MODEL, - optimizerModel: TRAININGDATA_OPTIMIZER_MODEL, -}; - -const trainingSessionMinOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, -}; - -// ASSERT METHODS - -function assert(cond) { - if (!cond) throw new Error(); -} - -function assertStrictEquals(actual, expected) { - if (actual !== expected) { - let strRep = actual; - if (typeof actual === 'object') { - strRep = JSON.stringify(actual); - } - throw new Error(`expected: ${expected}; got: ${strRep}`); - } -} - -function assertTwoListsUnequal(list1, list2) { - if (list1.length !== list2.length) { - return; - } - for (let i = 0; i < list1.length; i++) { - if (list1[i] !== list2[i]) { - return; - } - } - throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); -} - -// HELPER METHODS FOR TESTS - -function generateGaussianRandom(mean = 0, scale = 1) { - const u = 1 - Math.random(); - const v = Math.random(); - const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); - return z * scale + mean; -} - -function generateGaussianFloatArray(length) { - const array = new Float32Array(length); - - for (let i = 0; i < length; i++) { - array[i] = generateGaussianRandom(); - } - - return array; -} - -/** - * creates the TrainingSession and verifies that the input and output names of the training model loaded into the - * training session are correct. - * @param {} ort - * @param {*} createOptions - * @param {*} options - * @returns - */ -async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { - const trainingSession = await ort.TrainingSession.create(createOptions, options); - - assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); - assertStrictEquals(trainingSession.trainingInputNames.length, 2); - assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.trainingOutputNames.length, 1); - return trainingSession; -} - -/** - * verifies that the eval input and output names associated with the eval model loaded into the given training session - * are correct. 
- */ -function checkEvalModel(trainingSession) { - assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); - assertStrictEquals(trainingSession.evalInputNames.length, 2); - assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.evalOutputNames.length, 1); -} - -/** - * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if - * accessed - * @param {} trainingSession - */ -function checkNoEvalModel(trainingSession) { - try { - assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } - try { - assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } -} - -/** - * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length - * of 1 for the loss. - * @param {} trainingSession - * @param {*} feeds - * @returns - */ -var runTrainStepAndCheck = async function (trainingSession, feeds) { - const results = await trainingSession.runTrainStep(feeds); - assertStrictEquals(Object.keys(results).length, 1); - assertStrictEquals(results['onnx::loss::21273'].data.length, 1); - assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); - return results; -}; - -var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) { - // make a float32 array that is filled with the constant - const newParams = new Float32Array(paramsLength); - for (let i = 0; i < paramsLength; i++) { - newParams[i] = constant; - } - - const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); - - await trainingSession.loadParametersBuffer(newParamsUint8); - const paramsAfterLoad = await trainingSession.getContiguousParameters(); - - // check that the parameters have changed - assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); - assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); - - // check that the parameters have changed to what they should be - for (let i = 0; i < paramsLength; i++) { - // round to the same number of digits (4 decimal places) - assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); - } - - return paramsAfterLoad; -}; - -// TESTS - -var testInferenceFunction = async function (ort, options) { - const session = await ort.InferenceSession.create('data/model.onnx', options || {}); - - const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); - - const fetches = await session.run({ - a: new ort.Tensor('float32', dataA, [3, 4]), - b: new ort.Tensor('float32', dataB, [4, 3]), - }); - - const c = fetches.c; - - assert(c instanceof ort.Tensor); - assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); - assert(c.data[0] === 700); - assert(c.data[1] === 800); - assert(c.data[2] === 900); - assert(c.data[3] === 1580); - assert(c.data[4] === 1840); - assert(c.data[5] === 2100); - assert(c.data[6] === 2460); - assert(c.data[7] === 2880); - assert(c.data[8] === 3300); -}; - -var testTrainingFunctionMin = async function (ort, options) { - const 
trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); - checkNoEvalModel(trainingSession); - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - const feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - await runTrainStepAndCheck(trainingSession, feeds); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); -}; - -var testTrainingFunctionAll = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); - checkEvalModel(trainingSession); - - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - let feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - const results = await runTrainStepAndCheck(trainingSession, feeds); - - await trainingSession.runOptimizerStep(feeds); - feeds = { 'input-0': input0, labels: labels }; - // check getContiguousParameters after optimizerStep -- that the parameters have been updated - const optimizedParams = await trainingSession.getContiguousParameters(); - assertTwoListsUnequal(originalParams.data, optimizedParams.data); - - const results2 = await runTrainStepAndCheck(trainingSession, feeds); - - // check that loss decreased after optimizer step and training again - assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); -}; - -if (typeof module === 'object') { - module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; -} diff --git a/js/web/test/training/e2e/data/model.onnx b/js/web/test/training/e2e/data/model.onnx deleted file mode 100644 index 088124bd48624..0000000000000 --- a/js/web/test/training/e2e/data/model.onnx +++ /dev/null @@ -1,16 +0,0 @@ - backend-test:b - -a -bc"MatMultest_matmul_2dZ -a -  - -Z -b -  - -b -c -  - -B \ No newline at end of file diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js deleted file mode 100644 index 74662b67676f7..0000000000000 --- a/js/web/test/training/e2e/karma.conf.js +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -'use strict'; - -const args = require('minimist')(process.argv.slice(2)); -const SELF_HOST = !!args['self-host']; -const ORT_MAIN = args['ort-main']; -const TEST_MAIN = args['test-main']; -if (typeof TEST_MAIN !== 'string') { - throw new Error('flag --test-main= is required'); -} -const USER_DATA = args['user-data']; -if (typeof USER_DATA !== 'string') { - throw new Error('flag --user-data= is required'); -} - -module.exports = function (config) { - const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; - config.set({ - frameworks: ['mocha'], - files: [ - { pattern: distPrefix + ORT_MAIN }, - { pattern: './common.js' }, - { pattern: TEST_MAIN }, - { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }, - { pattern: './data/*', included: false }, - ], - plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], - proxies: { - '/model.onnx': '/base/model.onnx', - '/data/': '/base/data/', - }, - client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } }, - reporters: ['mocha'], - captureTimeout: 120000, - reportSlowerThan: 100, - browserDisconnectTimeout: 600000, - browserNoActivityTimeout: 300000, - browserDisconnectTolerance: 0, - browserSocketTimeout: 60000, - hostname: 'localhost', - browsers: [], - customLaunchers: { - Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA }, - Chrome_no_threads: { - base: 'ChromeHeadless', - chromeDataDir: USER_DATA, - // TODO: no-thread flags - }, - Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, - }, - }); -}; diff --git a/js/web/test/training/e2e/package.json b/js/web/test/training/e2e/package.json deleted file mode 100644 index 5f11a27de6dfc..0000000000000 --- a/js/web/test/training/e2e/package.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "devDependencies": { - "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", - "fs-extra": "^11.1.0", - "globby": "^13.1.3", - "karma": "^6.4.1", - "karma-chrome-launcher": "^3.1.1", - "karma-mocha": "^2.0.1", - "karma-mocha-reporter": "^2.2.5", - "light-server": "^2.9.1", - "minimist": "^1.2.7", - "mocha": "^10.2.0" - } -} diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js deleted file mode 100644 index d12bcc7aa66ed..0000000000000 --- a/js/web/test/training/e2e/run.js +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const path = require('path'); -const fs = require('fs-extra'); -const { spawn } = require('child_process'); -const startServer = require('./simple-http-server'); -const minimist = require('minimist'); - -// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file -// exists in its parent folder. 
-// here we use /build/js/e2e-training/ for the test - -const TEST_E2E_SRC_FOLDER = __dirname; -const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); -const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); -const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); -const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); -fs.emptyDirSync(TEST_E2E_RUN_FOLDER); -fs.emptyDirSync(NPM_CACHE_FOLDER); -fs.emptyDirSync(CHROME_USER_DATA_FOLDER); -fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); - -// training data to copy -const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); -const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); -const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); - -// always use a new folder as user-data-dir -let nextUserDataDirId = 0; -function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); - nextUserDataDirId++; - fs.emptyDirSync(dir); - return dir; -} - -// commandline arguments -const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; - -async function main() { - // find packed package - const { globbySync } = await import('globby'); - - const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); - - const PACKAGES_TO_INSTALL = []; - - if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { - PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); - } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { - throw new Error('multiple packages found for onnxruntime-common.'); - } - - const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); - if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { - throw new Error('cannot find exactly single package for onnxruntime-web.'); - } - PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); - - // we start here: - - // install dev dependencies - await runInShell(`npm install`); - - // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); - - // prepare training data - prepareTrainingDataByCopying(); - - console.log('==============================================================='); - console.log('Running self-hosted tests'); - console.log('==============================================================='); - // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({ hostInKarma: true }); - - console.log('==============================================================='); - console.log('Running not self-hosted tests'); - console.log('==============================================================='); - // test cases without self-host (ort hosted in cross origin) - const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081); - try { - await testAllBrowserCases({ hostInKarma: false }); - } finally { - // close the server after all tests - await server.close(); - } -} - -async function testAllBrowserCases({ hostInKarma }) { - await runKarma({ hostInKarma, main: 
'./browser-test-wasm.js' }); -} - -async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) { - console.log('==============================================================='); - console.log(`Running karma with the following binary: ${ortMain}`); - console.log('==============================================================='); - const selfHostFlag = hostInKarma ? '--self-host' : ''; - await runInShell( - `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ - ortMain - } --test-main=${main} --user-data=${getNextUserDataDir()}`, - ); -} - -async function runInShell(cmd) { - console.log('==============================================================='); - console.log(' Running command in shell:'); - console.log(' > ' + cmd); - console.log('==============================================================='); - let complete = false; - const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); - childProcess.on('close', function (code) { - if (code !== 0) { - process.exit(code); - } else { - complete = true; - } - }); - while (!complete) { - await delay(100); - } -} - -async function delay(ms) { - return new Promise(function (resolve) { - setTimeout(function () { - resolve(); - }, ms); - }); -} - -function prepareTrainingDataByCopying() { - fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); - console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); -} - -main(); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js deleted file mode 100644 index ef9cced681cc8..0000000000000 --- a/js/web/test/training/e2e/simple-http-server.js +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -// this is a simple HTTP server that enables CORS. 
-// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework - -const http = require('http'); -const fs = require('fs'); -const path = require('path'); - -const getRequestData = (url, dir) => { - const pathname = new URL(url, 'http://localhost').pathname; - - let filepath; - let mimeType; - if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/')) { - filepath = path.resolve(dir, pathname.substring(1)); - } else { - return null; - } - - if (filepath.endsWith('.wasm')) { - mimeType = 'application/wasm'; - } else if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) { - mimeType = 'text/javascript'; - } else { - return null; - } - - return [filepath, mimeType]; -}; - -module.exports = function (dir, port) { - const server = http - .createServer(function (request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); - - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function (error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, { 'Content-Type': contentType }); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); - console.log(`Server running at http://localhost:${port}/`); - return server; -}; diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 735b6a89a2a86..b82248c0c83b8 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } - -declare module 'onnxruntime-web/training' { - export * from 'onnxruntime-web'; -} diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 45acb90ba68b0..e0fa581c8071d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -114,7 +114,8 @@ struct GroupQueryAttentionParameters { int local_window_size; bool kv_share_buffer; bool is_packed_qkv; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt; // indicates whether this is first decoding step bool do_rotary; bool rotary_interleaved; bool use_smooth_softmax; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index e6c948acb0d6c..4d435f71cc195 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -236,19 +236,16 @@ T* ConcatStateChunkGQA(const T* past, size_t past_buff_chunk_length, size_t past_chunk_length, size_t new_chunk_length, - bool is_prompt, bool past_present_share_buffer, std::ptrdiff_t i) { T* start = present + i * present_buff_chunk_length; T* p = start; - if (!is_prompt) { - if (!past_present_share_buffer) { - const T* src_past = past + i * past_buff_chunk_length; - memcpy(p, src_past, past_chunk_length * sizeof(T)); - } - p += past_chunk_length; + if (!past_present_share_buffer && past_chunk_length > 0) { + 
const T* src_past = past + i * past_buff_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); } + p += past_chunk_length; memcpy(p, chunk, new_chunk_length * sizeof(T)); return start; diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 2bf0aa0915c2d..bfec9aef56727 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -59,6 +59,7 @@ class GQAAttentionBase { GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int head_size = parameters.head_size; @@ -88,14 +89,14 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - tp); + is_prompt, tp); return Status::OK(); } @@ -105,35 +106,35 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int past_buffer_sequence_length, // sequence length of past state - int present_buffer_sequence_length, // sequence length of present state - int head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed - ThreadPool* tp) const { // thread pool - const bool is_prompt = sequence_length != 1; + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. 
Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp) const { // thread pool const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H - const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = sequence_length * head_size; // S x H + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } - const int loop_len = batch_size * num_heads_; + const size_t loop_len = batch_size * num_heads_; const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; @@ -156,12 +157,11 @@ class GQAAttentionBase { ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i) / num_heads_; - const int head_index = static_cast(i) % num_heads_; - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 
0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; T* output = attention_probs + output_offset; @@ -174,7 +174,7 @@ class GQAAttentionBase { } if (nullptr != present_key) { k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); } @@ -189,16 +189,17 @@ class GQAAttentionBase { } else { q = Q + q_input_chunk_length * i; } + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length, - nullptr); + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, + static_cast(present_buffer_sequence_length), nullptr); // compute Softmax T* output_softmax = output; - for (int seq = 0; seq < sequence_length; seq++) { - int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1; - if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) { - for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { + for (size_t seq = 0; seq < sequence_length; seq++) { + size_t seq_causal_length = past_seqlen + seq + 1; + if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { + for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } if (softcap_ > 0.f) { @@ -214,17 +215,17 @@ class GQAAttentionBase { } } else { if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax, seq_causal_length, softcap_); + ComputeAttentionSoftcapInplace(output_softmax, static_cast(seq_causal_length), softcap_); } if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeSmoothSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } } // set causal [seq_causal_length, total_seqlen) to 0.f - for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { + for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } @@ -235,34 +236,36 @@ class GQAAttentionBase { } template - void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT - const T* V, // V value with size BxN_kvxSxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size - int sequence_length, // sequence length - int past_buffer_sequence_length, // sequence length in past state - int present_buffer_sequence_length, // sequence length in past state - int head_size, // head size of Q, K, V - int hidden_size, // hidden size of Output - const T* past_value, // past value only - T* present_value, // present value only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed + void ComputeVxAttentionScore(T* 
output, // buffer for the result with size BxSxNxH + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxN_kvxSxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size + const size_t sequence_length, // sequence length + const size_t past_buffer_sequence_length, // sequence length in past state + const size_t present_buffer_sequence_length, // sequence length in past state + const size_t head_size, // head size of Q, K, V + const size_t hidden_size, // hidden size of Output + const T* past_value, // past value only + T* present_value, // present value only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt ThreadPool* tp) const { - const bool is_prompt = sequence_length != 1; const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const int kv_input_chunk_length = sequence_length * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } + const size_t loop_len = batch_size * num_heads_; + // The cost of Gemm TensorOpCost unit_cost; unit_cost.compute_cycles = @@ -282,37 +285,35 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor( - tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i / num_heads_); - const int head_index = static_cast(i % num_heads_); - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; - - const T* v; - if (packed_qkv) { - v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); - } else { - v = V + kv_input_chunk_length * (i / kv_num_heads_factor); - } - if (nullptr != present_value) { - v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, - i / kv_num_heads_factor); - } + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 
0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_value) { + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, + i / kv_num_heads_factor); + } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, - 1.f, /*alpha*/ - attention_probs + attention_probs_offset, present_buffer_sequence_length, v, - head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); - } - }); + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ + attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, static_cast(head_size), + 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + } + }); } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 87675255f5ba4..2a38e4a1ac636 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -45,7 +45,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); - const Tensor* total_seqlen = context->Input(6); + const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); @@ -61,7 +61,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { num_heads_, kv_num_heads_, seqlens_k, - total_seqlen, + total_seqlen_tensor, scale_, softcap_)); @@ -103,6 +103,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } if (do_rotary_) { + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; rotary_params.sequence_length = sequence_length; @@ -114,17 +115,29 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { rotary_params.seq_stride = head_size; rotary_params.head_stride = sequence_length * rotary_params.seq_stride; rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride; - rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.position_ids_format = !parameters.is_first_prompt ? 1 : 0; rotary_params.transposed = true; auto* tp = context->GetOperatorThreadPool(); - std::vector pos_ids(sequence_length == 1 ? batch_size : 1); - if (sequence_length == 1) { + // Generate position ids + const int pos_ids_size = parameters.is_first_prompt ? 
1 : batch_size * sequence_length; + std::vector pos_ids(pos_ids_size); + if (parameters.is_first_prompt) { + pos_ids[0] = static_cast(0); + } else { + // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { - pos_ids[b] = static_cast(seqlens_k->Data()[b]); + const int total_seqlen = seqlens_k->Data()[b] + 1; + const int past_seqlen = total_seqlen - sequence_length; + for (int s = 0; s < sequence_length; s++) { + if (past_seqlen + s < total_seqlen) { + pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; + } else { + pos_ids[b * sequence_length + s] = static_cast(1); + } + } } - } else { - pos_ids[0] = static_cast(0); } + // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; T* q_rotary; @@ -149,6 +162,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Q = RotaryQ; K = RotaryK; } + // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), q_rotary, rotary_interleaved_)); @@ -161,6 +175,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), k_rotary, rotary_interleaved_)); + // Pack V into rotary QKV buffer if (packed_qkv) { const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 3342052260ff9..0bdee151d2173 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -168,14 +168,13 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); + if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "seqlens_k must be shape (batch_size)."); } - // Set present sequence length and kv_share_buffer from input total_seqlen tensor + // Set present sequence length from input total_seqlen tensor if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "total_sequence_length tensor must be of one element."); @@ -195,11 +194,11 @@ Status CheckInputs(const Tensor* query, } if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); + "cos_cache dimension 0 shall not be less than total_sequence_length."); } if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); + "sin_cache dimension 0 shall not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -219,7 +218,26 @@ Status CheckInputs(const Tensor* query, "Input 
'cos_cache' and 'sin_cache' shall be both present or both absent."); } - bool is_prompt = sequence_length != 1; + bool is_subsequent_prompt = false; + if (sequence_length > 1 && sequence_length != total_sequence_length) { + if (batch_size != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "batch_size must be 1 when sequence_length > 1 and past context is given."); + } + is_subsequent_prompt = true; + } + + bool is_first_prompt; + if (is_subsequent_prompt) { + is_first_prompt = false; // irrelevant for interactive decoding + } else { + // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt + is_first_prompt = (sequence_length == total_sequence_length); + if (!is_first_prompt && sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sequence_length shall be 1 when it is not prompt."); + } + } if (parameters != nullptr) { GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); @@ -227,6 +245,7 @@ Status CheckInputs(const Tensor* query, output_parameters->sequence_length = sequence_length; // sequence length of Q output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->total_sequence_length = total_sequence_length; // total sequence length output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; output_parameters->head_size = head_size; @@ -235,7 +254,8 @@ Status CheckInputs(const Tensor* query, output_parameters->rotary_dim = rotary_dim; output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; - output_parameters->is_prompt = is_prompt; + output_parameters->is_subsequent_prompt = is_subsequent_prompt; + output_parameters->is_first_prompt = is_first_prompt; output_parameters->scale = scale; output_parameters->softcap = softcap; output_parameters->qkv_format = qkv_format; diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index bf43aca73ef3a..ccb779721d006 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -146,8 +146,15 @@ class MatMulNBits final : public OpKernel { bool all_constant_{false}; #endif // defined(ORT_NEURAL_SPEED) + + template + Status ComputeTyped(OpKernelContext* ctx) const; }; +bool IsATypeFloat16(const Tensor& tensor) { + return tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +} + Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { @@ -211,10 +218,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) ORT_UNUSED_PARAMETER(prepacked_weights); const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } if (input_idx == InputIndex::B) { - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { - return Status::OK(); - } packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); if (packed_b_size_ == 0) { return Status::OK(); @@ -226,8 +233,15 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat 
} else if (compute_type == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + if (IsATypeFloat16(tensor)) { + auto sptr = tensor.Data(); + std::vector scales_v(static_cast(tensor.Shape().Size())); + MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), &scales_v[0], has_zp_input_, nullptr, nullptr); + } else { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + } is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); @@ -274,9 +288,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } Status MatMulNBits::Compute(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(InputIndex::A); + + if (IsATypeFloat16(*a)) { + return ComputeTyped(ctx); + } else { + return ComputeTyped(ctx); + } +} + +template +Status MatMulNBits::ComputeTyped(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); - const auto* a_data = a->Data(); + const auto* a_data = a->Data(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -289,7 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { return Status::OK(); } - auto* y_data = y->MutableData(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -297,9 +322,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), - helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + // clang-format off + const bool has_single_b_matrix = std::all_of( + helper.RightOffsets().begin(), + helper.RightOffsets().end(), + [](size_t offset) { return offset == 0; }); + // clang-format on #if defined(ORT_NEURAL_SPEED) @@ -336,9 +364,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* bias = ctx->Input(InputIndex::bias); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + const auto* bias_data = bias == nullptr ? 
nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( @@ -349,26 +377,64 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; + if constexpr (std::is_same::value) { + InlinedVector data(batch_count); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + + auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); + MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); + + std::vector bias_data_v; + if (bias_data != nullptr) { + bias_data_v.resize((const unsigned int)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + } + std::vector C_v((const unsigned int)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type == CompInt8) { - data[i].QuantBDataWorkspace = packed_b_.get(); + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data != nullptr ? 
&bias_data_v[0] : nullptr; + data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].ldc = N; } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + return Status::OK(); + } else { + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } #endif - data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + return Status::OK(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); } } @@ -380,7 +446,17 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); + const float* scales_data_; + std::vector scales_data_v; + if constexpr (std::is_same::value) { + scales_data_v.resize((const unsigned int)scales->Shape().Size()); + MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); + scales_data_ = &scales_data_v[0]; + } else { + scales_data_ = scales_data; + } + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); @@ -391,12 +467,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -406,12 +482,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } else { ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! 
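For reference, the blockwise 4-bit dequantization this fallback path relies on reduces to unpacking two quantized values per byte and computing scale * (q - zero_point), with 8 as the default zero point when none is supplied. The following is a minimal standalone sketch of that per-block math only; the function name is illustrative, and the column-wise block layout and reorder indices handled by the actual kernels are omitted.

#include <cstddef>
#include <cstdint>

// Dequantize one block of `block_size` 4-bit weights. Two values are packed
// per byte (low nibble first); each output is scale * (q - zero_point).
void DequantizeBlock4Bit(const uint8_t* quant, float scale, uint8_t zero_point,
                         std::size_t block_size, float* out) {
  for (std::size_t i = 0; i < block_size; ++i) {
    const uint8_t byte = quant[i / 2];
    const uint8_t q = (i & 1) ? static_cast<uint8_t>(byte >> 4)
                              : static_cast<uint8_t>(byte & 0x0f);
    out[i] = scale * (static_cast<float>(q) - static_cast<float>(zero_point));
  }
}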
- if ((zero_points && zero_points->IsDataType())) { - DequantizeBlockwise( + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points + scales_data_, // quantization scales + static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -422,7 +498,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -436,40 +512,80 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); #endif + if constexpr (std::is_same::value) { + std::vector data(batch_count); - std::vector data(batch_count); - for (size_t i = 0; i < batch_count; i++) { - data[i].BIsPacked = false; - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = 1.f; - data[i].beta = 0.0f; - } + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - // if there is a bias input, copy bias values into C and set beta to 1.0f - if (const Tensor* bias = ctx->Input(InputIndex::bias); - bias != nullptr) { - gsl::span bias_span = bias->DataAsSpan(); - for (size_t i = 0; i < batch_count; ++i) { - float* C_row = data[i].C; - const size_t ldc = data[i].ldc; - for (size_t m = 0; m < M; ++m) { - memcpy(C_row, bias_span.data(), bias_span.size_bytes()); - C_row += ldc; + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = tmp_c_ptr.get() + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias->Data(), tmp_bias_data_ptr.get(), static_cast(bias->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + C_row += ldc; + } + data[i].beta = 1.0f; } + } - data[i].beta = 1.0f; + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + return Status::OK(); + } else 
{ + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; } - } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + gsl::span bias_span = bias->DataAsSpan(); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + memcpy(C_row, bias_span.data(), bias_span.size_bytes()); + C_row += ldc; + } - return Status::OK(); + data[i].beta = 1.0f; + } + } + + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + + return Status::OK(); + } } ONNX_OPERATOR_KERNEL_EX( @@ -478,9 +594,9 @@ ONNX_OPERATOR_KERNEL_EX( 1, kCpuExecutionProvider, KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index b28f3758f89b5..6a19a741c3028 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -54,12 +54,12 @@ void Dequantize4BitsKernelReOrder( T scale = *(scale_data + n_idx * scales_shape_x + rid); float zp_f = 8; if (zero_points) { - if constexpr (std::is_same_v) { - zp_f = *(zero_points + n_idx * scales_shape_x + rid); - } else { + if constexpr (std::is_same_v) { uint8_t zp = 8; zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } else { + zp_f = *(zero_points + static_cast(n_idx) * static_cast(scales_shape_x) + static_cast(rid)); } } @@ -112,5 +112,10 @@ template void DequantizeBlockwise( const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index cf66bd8407126..37172074e5d86 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -184,7 +184,7 @@ class SparseAttentionBase { // Concatenate past_k + k -> present_k // TODO: avoid copying mutiple times for a group. 
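The call below reflects the updated ConcatStateChunkGQA contract from attention_helper.h: the is_prompt flag is removed, and callers now pass a past_chunk_length of zero for the prompt case. A minimal standalone sketch of the new helper's behavior, simplified from the templated version (float-only, illustrative name):

#include <cstddef>
#include <cstring>

// Copy past KV (when present and not shared) followed by the new chunk into
// the present buffer; returns the start of the written chunk.
float* ConcatStateChunkSketch(const float* past, const float* chunk, float* present,
                              std::size_t present_buff_chunk_length,
                              std::size_t past_buff_chunk_length,
                              std::size_t past_chunk_length,  // 0 for the first prompt
                              std::size_t new_chunk_length,
                              bool past_present_share_buffer,
                              std::ptrdiff_t i) {
  float* start = present + i * present_buff_chunk_length;
  float* p = start;
  if (!past_present_share_buffer && past_chunk_length > 0) {
    const float* src_past = past + i * past_buff_chunk_length;
    std::memcpy(p, src_past, past_chunk_length * sizeof(float));
  }
  p += past_chunk_length;
  std::memcpy(p, chunk, new_chunk_length * sizeof(float));
  return start;
}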
k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); // Compute Q*K' + AttentionMask @@ -365,7 +365,7 @@ class SparseAttentionBase { // Concatenate past_v + v -> present_v v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a10d2548fa7b8..7f1c3786858c8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -42,7 +42,6 @@ struct RightPaddingBatchHook { auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; - // Advance to current batch - in case of different sequence lengths if (p.seqlen_k_ptr) { p.num_keys = p.seqlen_k_ptr[batch_id]; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index d0ae812bb4fa2..6eff584cec5da 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -5,7 +5,7 @@ #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" -#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" @@ -95,7 +95,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { kv_num_heads_, seqlens_k, total_seqlen, - is_past_bsnh_, scale_, softcap_, device_prop.maxThreadsPerBlock)); @@ -253,7 +252,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } if (seqlens_k_buffer != nullptr) { - data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + data.seqlens_k_buff = reinterpret_cast(seqlens_k_buffer.get()); } // Memory Efficient Buffers if (k_buffer != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h deleted file mode 100644 index e65827e4ccdd5..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
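The CUDA-specific helper deleted below still derived a single is_prompt flag from sequence_length; the shared CPU helper above replaces it with the first-prompt / subsequent-prompt split. A minimal standalone restatement of that classification follows; the function and struct names are illustrative, and exceptions stand in for the ORT_MAKE_STATUS error returns.

#include <stdexcept>

struct PromptKind {
  bool is_first_prompt = false;       // fresh prompt, no past context
  bool is_subsequent_prompt = false;  // interactive decoding: past context and seqlen > 1
};

// Mirrors the classification performed by the shared CPU CheckInputs.
PromptKind ClassifyPrompt(int batch_size, int sequence_length, int total_sequence_length) {
  PromptKind kind;
  if (sequence_length > 1 && sequence_length != total_sequence_length) {
    if (batch_size != 1) {
      throw std::invalid_argument(
          "batch_size must be 1 when sequence_length > 1 and past context is given.");
    }
    kind.is_subsequent_prompt = true;
  } else {
    kind.is_first_prompt = (sequence_length == total_sequence_length);
    if (!kind.is_first_prompt && sequence_length != 1) {
      throw std::invalid_argument("sequence_length shall be 1 when it is not prompt.");
    }
  }
  return kind;
}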
- -#pragma once - -#include "core/common/common.h" -#include "core/providers/common.h" -#include "contrib_ops/cpu/bert/attention_common.h" - -namespace onnxruntime { -namespace contrib { -namespace group_query_attention_helper { - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap) { - // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // no packing for q/k/v: - // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) - // key (K) : (B, S, D_kv) or nullptr - // value (V) : (B, S, D_kv) or nullptr - AttentionQkvFormat qkv_format = Q_K_V_BSNH; - AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - const bool is_packed_qkv = key == nullptr; - const auto& query_dims = query->Shape().GetDims(); - - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - - int batch_size = static_cast(query_dims[0]); - int sequence_length = static_cast(query_dims[1]); - int q_hidden_size = static_cast(query_dims[2]); - int head_size = 0; - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - int kv_hidden_size = 0; - // Check key and value when not packed - if (!is_packed_qkv) { - head_size = static_cast(q_hidden_size) / num_heads; - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } else if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != key_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 1 (sequence length)"); - } - kv_hidden_size = static_cast(key_dims[2]); - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } else if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 1 (sequence length)"); - } else if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { - // Check packed qkv - head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - q_hidden_size = head_size * num_heads; - kv_hidden_size = head_size * kv_num_heads; - } - - // Check past-present KV - int32_t past_sequence_length = 0; - if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - // BNSH - if (!is_past_bsnh) { - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[2]); - // BSNH - } else { - if (past_key_dims[1] != past_value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[1]); - } - - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - } else if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be both present or both absent."); - } - - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { - return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "seqlens_k must be shape (batch_size)."); - } - - // Set present sequence length and kv_share_buffer from input total_seqlen tensor - if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "total_sequence_length tensor must be of one element."); - } - int total_sequence_length = *((*total_seqlen).template Data()); - int present_sequence_length = std::max(total_sequence_length, past_sequence_length); - - int rotary_dim = 0; - if (cos_cache != nullptr && sin_cache != nullptr) { - const auto& cos_dims = cos_cache->Shape().GetDims(); - const auto& sin_dims = sin_cache->Shape().GetDims(); - - if (head_size % 16 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size shall be a multiple of 16. Got head_size % 16 == ", - head_size % 16); - } - if (cos_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); - } - if (sin_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); - } - if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (cos_dims[1] != sin_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache dimension 1 must be the same."); - } - rotary_dim = static_cast(cos_dims[1] * 2); - } else if (cos_cache != nullptr || sin_cache != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); - } - - bool is_prompt = (sequence_length == total_sequence_length); - if (!is_prompt && sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sequence_length shall be 1 when it is not prompt."); - } - - if (parameters != nullptr) { - GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = batch_size; - output_parameters->sequence_length = sequence_length; // sequence length of Q - output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors - output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors - output_parameters->total_sequence_length = total_sequence_length; // total sequence length - output_parameters->hidden_size = q_hidden_size; - output_parameters->num_heads = num_heads; - output_parameters->head_size = head_size; - output_parameters->kv_hidden_size = kv_hidden_size; - output_parameters->kv_num_heads = kv_num_heads; - output_parameters->rotary_dim = rotary_dim; - output_parameters->is_packed_qkv = is_packed_qkv; - output_parameters->is_prompt = is_prompt; - output_parameters->scale = scale; - output_parameters->softcap = softcap; - output_parameters->qkv_format = qkv_format; - output_parameters->past_kv_format = past_kv_format; - } - - return Status::OK(); -} - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* 
past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap, - int max_threads_per_block) { - if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); - } - - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale, softcap); -} - -} // namespace group_query_attention_helper -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index be94f26ec298f..8bf9848245ec7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -71,6 +71,8 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, + // const int* seqlens_q, const bool is_bsnh) { // refers to past; otherwise bnsh const int h = threadIdx.x; const int n = threadIdx.y; @@ -88,7 +90,9 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -96,7 +100,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -116,6 +120,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, const bool is_bsnh) { int i = threadIdx.x + (blockDim.x * blockIdx.x); if (i < H * num_heads) { @@ -132,7 +137,9 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -140,7 +147,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int past_head_stride = is_bsnh ? 
H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; const int new_head_stride = H; @@ -160,13 +167,12 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter const int max_threads_per_block, const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); - + const int* seqlens_k = parameters.is_first_prompt ? nullptr : reinterpret_cast(data.seqlens_k); AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -180,6 +186,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, @@ -187,6 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { int steps = (H * kv_num_heads + 255) / 256; @@ -200,6 +208,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKVLarge<<>>(kv_sequence_length, past_sequence_length, @@ -209,6 +218,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } return CUDA_CALL(cudaGetLastError()); @@ -219,7 +229,7 @@ template __global__ void ConcatKVInPlace(const int max_seqlen, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { @@ -234,7 +244,7 @@ __global__ void ConcatKVInPlace(const int max_seqlen, const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? 
INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -253,7 +263,7 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int kv_num_heads, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh @@ -264,9 +274,10 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int s = blockIdx.y; const int b = blockIdx.z; const int new_seqlen = gridDim.y; + const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -286,15 +297,15 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const T* new_key, const T* new_value, T* present_key, T* present_value, - bool is_past_kv_bnsh_format, - bool is_new_kv_bnsh_format, + const bool is_past_kv_bnsh_format, + const bool is_new_kv_bnsh_format, cudaStream_t stream, const int max_threads_per_block) { static_assert(sizeof(T) == 2); @@ -307,14 +318,14 @@ Status LaunchConcatKVInPlace(int batch_size, ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -327,7 +338,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -336,7 +347,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -354,7 +365,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, cudaStream_t stream, const int max_threads_per_block) { const int max_sequence_length = parameters.seqlen_present_kv_cache; - const int* past_seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + const int* seqlens_k = (parameters.is_first_prompt && !parameters.is_subsequent_prompt) ? 
nullptr + : reinterpret_cast(data.seqlens_k); assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -364,8 +376,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, parameters.kv_num_heads, parameters.head_size, max_sequence_length, - past_seqlens_k, - nullptr, // total_seqlens_k is not available + seqlens_k, + nullptr, // total_seqlens_k would be wrong to use here parameters.sequence_length, reinterpret_cast(new_key), reinterpret_cast(new_value), @@ -495,23 +507,33 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; } -// Convert Past to Total sequence length tensor -Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, - int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int /*threads_per_block*/) { - if (parameters.is_prompt) { - return Status::OK(); - } - const int batch_size = parameters.batch_size; - const int add_seqlen = is_total ? parameters.sequence_length : 0; - +// Calculate total sequence length from seqlens_k +Status LaunchGetSeqlensTotal(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int batch_size, cudaStream_t stream, + const int /*threads_per_block*/) { const dim3 grid(1, 1, 1); // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads const dim3 block(batch_size, 1, 1); + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, 1); + return CUDA_CALL(cudaGetLastError()); +} - // TODO(aciddelgado): small version - PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); +// Currently, interactive decoding only works for batch_size 1 +__global__ void GetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + seqlens_k_buff[tid] = seqlens_k[tid] + 1 - sequence_length; + } +} +// Calculate past sequence length for each batch entry for flash attention kernel +Status LaunchGetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length, cudaStream_t stream, + const int max_threads_per_block) { + const int threads = std::min(batch_size, max_threads_per_block); + const int blocks = (threads / max_threads_per_block) + 1; + GetSeqlensInteractive<<>>(seqlens_k, seqlens_k_buff, batch_size, + sequence_length); return CUDA_CALL(cudaGetLastError()); } @@ -576,7 +598,22 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsInteractive(const int32_t* seqlens_k, int64_t* position_ids, + const int seqlen, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + const int total_seqlen = seqlens_k[b] + 1; + const int past_seqlen = total_seqlen - seqlen; + if (past_seqlen + s < total_seqlen) { + position_ids[tid] = past_seqlen + s; + } else { + position_ids[tid] = 1; + } + } +} + __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -591,7 +628,6 @@ __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* positio } } -// Kernel to convert 
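// --- Illustrative note (not part of the patch's kernels): a minimal host-side
// sketch of how the updated seqlens_k convention is interpreted, assuming
// seqlens_k[b] == total_sequence_length[b] - 1 as documented for the GQA op.
// The helper names below (SeqlenInfo, DeriveSeqlens) are hypothetical and used
// only for illustration.
#include <cstdint>

struct SeqlenInfo {
  int32_t total_seqlen;  // past + new tokens for this batch entry
  int32_t past_seqlen;   // tokens already present in the KV cache
};

inline SeqlenInfo DeriveSeqlens(int32_t seqlen_k, int32_t new_seqlen) {
  SeqlenInfo info;
  info.total_seqlen = seqlen_k + 1;                   // what LaunchGetSeqlensTotal computes
  info.past_seqlen = info.total_seqlen - new_seqlen;  // what GetSeqlensInteractive computes
  return info;
}

// First prompt:      seqlen_k == new_seqlen - 1 (no padding) -> past_seqlen == 0
//                    (the launchers also pass seqlens_k == nullptr in this case)
// Token generation:  new_seqlen == 1                         -> past_seqlen == seqlen_k
// Subsequent prompt: 1 < new_seqlen < total_seqlen           -> past_seqlen == total_seqlen - new_seqlen
// Position ids for the new tokens then run from past_seqlen to total_seqlen - 1,
// which is what SeqlensToPosIdsInteractive (above) writes.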
seqlens_k to position_ids __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { @@ -601,12 +637,15 @@ __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position // Convert seqlens_k to position_ids Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + int64_t* position_ids, cudaStream_t stream, + const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + SeqlensToPosIdsInteractive<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -650,7 +689,12 @@ Status FlashAttention( } void* seqlens_k = reinterpret_cast(data.seqlens_k); - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensInteractive(reinterpret_cast(data.seqlens_k), + reinterpret_cast(data.seqlens_k_buff), batch_size, + sequence_length, stream, max_threads_per_block)); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); + } else if (parameters.is_first_prompt) { // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value // user should use seqlens_k to index into output to get new tokens if (batch_size <= parameters.zeros_count) { @@ -659,10 +703,12 @@ Status FlashAttention( // Launch kernel to create larger seqlen tensor when batch_size > 256 constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); - seqlens_k = data.seqlens_k_total; + repeat_seqlen<<>>(data.seqlens_k_buff, 0, batch_size); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); } - } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + } + + if (!parameters.kv_share_buffer || parameters.is_first_prompt) { // copy past kv to present kv ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, true)); } @@ -682,7 +728,7 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); - // if (parameters.left_padding && parameters.is_prompt) { + // if (parameters.left_padding && parameters.is_first_prompt) { // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); // } @@ -766,15 +812,16 @@ Status EfficientAttention( key = reinterpret_cast(k_buffer); } - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt || !parameters.is_first_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensTotal(data.seqlens_k, data.seqlens_k_buff, batch_size, stream, 256)); + } else { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, + repeat_seqlen<<>>(data.seqlens_k_buff, parameters.sequence_length, 
batch_size); - } else { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } + int* seqlens_k = data.seqlens_k_buff; if (parameters.kv_share_buffer) { // Share buffer case @@ -815,7 +862,7 @@ Status EfficientAttention( } DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + DUMP_TENSOR("seqlens_k", seqlens_k, batch_size, 1); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -823,14 +870,14 @@ Status EfficientAttention( p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; - p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.kv_sequence_length = present_sequence_length; // maybe remove p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; p.causal = true; p.scale = scale; p.softcap = parameters.softcap; - p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqlen_k_ptr = seqlens_k; // Note: seqlens_k is total sequence length for efficient p.seqstart_q_ptr = nullptr; p.seqstart_k_ptr = nullptr; p.query = query; @@ -912,7 +959,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const half* new_key, @@ -928,7 +975,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const BFloat16* new_key, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index e8dc69188b95f..8593ecede2bab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -27,7 +27,7 @@ struct GroupQueryAttentionData { T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; - int* seqlens_k_total = nullptr; + int* seqlens_k_buff = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; T* unpacked_qkv_buffer = nullptr; @@ -61,7 +61,7 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, // max sequence length of present_key or present_value. - const int* past_seqlens_k, // it is not used when total_seqlens_k is available. + const int* seqlens_k, // it is not used when total_seqlens_k is available. const int* total_seqlens_k, // optional, nullptr means it is not available. int new_seq_len, const T* new_key, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index af9e87eaf225d..ce6c07fbed2bc 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -289,7 +289,7 @@ bool TryMatMul4Bits( return false; } dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); - dim3 threads(kWarpSize, kColsPerThreadBlock); + dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + (zero_points != nullptr ? 
(blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 7a16eb38181aa..e644b7e903138 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -5,7 +5,7 @@ #include "core/providers/rocm/rocm_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/rocm/bert/rotary_embedding_impl.h" #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" @@ -115,7 +115,7 @@ Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -325,7 +325,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // build present kv cache auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { // copy prompt kv to present kv ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); @@ -383,7 +383,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { return ret; } - if (parameters.is_prompt && is_unidirectional_) { + if (parameters.is_first_prompt && is_unidirectional_) { return mask_info::decode("t", sequence_length, kv_sequence_length); } @@ -496,7 +496,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.head_size, parameters.head_size, // v head size GetCkFmhaDataTypeString(), - !parameters.is_prompt, // true, // is_group_mode + !parameters.is_first_prompt, // true, // is_group_mode true, // is_v_rowmajor ? 
dim is fastest : seq is fastest mask.type, bias_type, diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index c3e96e450c59b..5e66f2b99fded 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -145,6 +145,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.cc b/onnxruntime/core/framework/kernel_type_str_resolver.cc index d05e02eb3ab32..3142f94f289b3 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver.cc @@ -46,12 +46,8 @@ Status KernelTypeStrResolver::ResolveKernelTypeStr(const Node& node, std::string ORT_RETURN_IF(op_it == op_kernel_type_str_map_.end(), "Failed to find op_id: ", op_id); const auto& type_str_map = op_it->second; -#ifdef DISABLE_ABSEIL // TODO(edgchen1) maybe we can use transparent hash/eq to enable lookup with string_view const auto type_str_it = type_str_map.find(std::string(kernel_type_str)); -#else - const auto type_str_it = type_str_map.find(kernel_type_str); -#endif ORT_RETURN_IF(type_str_it == type_str_map.end(), "Failed to find args for kernel type string '", kernel_type_str, diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 1b5f6bcee9bd0..76e7e369514d4 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,11 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; -#ifdef DISABLE_ABSEIL auto it = map_.find(std::string(name)); -#else - auto it = map_.find(name); -#endif if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 5185205f1dde1..c706c6fc5ff5f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1049,6 +1049,8 @@ Supports different number of heads for q and kv for CPU and CUDA. Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. +Supports continuous decoding for batch_size == 1 for CPU and CUDA. + )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1110,12 +1112,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(5, "seqlens_k", - // For prompt, the value is number of tokens (excluding padding) - 1. - "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "1D Tensor of shape (batch_size). 
Equivalent to (total_sequence_lengths - 1).", "M") .Input(6, "total_sequence_length", - "Scalar tensor of total sequence length (past + new).", + "Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for " + "checking inputs and determining prompt vs token generation case.", "M") .Input(7, "cos_cache", diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 8b3156d77e57c..28ae64c4d5b3e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -20,6 +20,7 @@ Module Name: #include #include #include +#include // // Define the calling convention for Windows targets. @@ -1025,18 +1026,6 @@ MlasComputeTanh( size_t N ); -// -// Half-precision floating-point routines. -// - -void -MLASCALL -MlasConvertHalfToFloatBuffer( - const unsigned short* Source, - float* Destination, - size_t Count -); - // // Transpose routines. // @@ -1426,7 +1415,27 @@ using MLAS_FP16 = onnxruntime::MLFloat16; constexpr size_t FP16_SIZE = sizeof(uint16_t); -/** +// +// Half-precision floating-point routines. +// + +void +MLASCALL +MlasConvertHalfToFloatBuffer( + const MLAS_FP16* Source, + float* Destination, + size_t Count +); + +void +MLASCALL +MlasConvertFloatToHalfBuffer( +const float* Source, +MLAS_FP16* Destination, +size_t Count +); + + /** * @brief Whether current CPU supports FP16 acceleration. */ bool MLASCALL @@ -1787,6 +1796,7 @@ MlasTranspose( M, N); } + #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED /** * @brief Max Pooling for fp16 NHWC diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm index c7f6342c527bf..800863c77a230 100644 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm @@ -54,7 +54,7 @@ HIGH_SELECTOR equ 00110001b LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT - test r8, r8 ; Check if we have any elements to convert + test r8, r8 ; Check if we have any elements to convert jz ExitRoutine cmp r8, 8 jb ConvertMaskedVectors @@ -80,6 +80,8 @@ Convert256Vectors: jz ExitRoutine ; If we are done, exit cmp r8, 16 ; If the vector is big enough, we go again jae Convert256Vectors + cmp r8, 8 ; Check if we have enough elements to convert + jb ConvertMaskedVectors diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp index 24af4064bbd9b..a6138e29bd796 100644 --- a/onnxruntime/core/mlas/lib/cast.cpp +++ b/onnxruntime/core/mlas/lib/cast.cpp @@ -23,37 +23,35 @@ union fp32_bits { void MLASCALL MlasConvertHalfToFloatBuffer( - const unsigned short* Source, + const MLAS_FP16* Source, float* Destination, size_t Count ) { - if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { - // If there is no kernel use the reference implementation, adapted from mlas_float16.h. - constexpr fp32_bits magic = {113 << 23}; - constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + for (size_t i = 0; i < Count; ++i) { + Destination[i] = Source[i].ToFloat(); + } + } else { + // If the kernel is available, use it to perform the conversion. 
+ GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast(Source), Destination, Count); + } +} +void +MLASCALL +MlasConvertFloatToHalfBuffer( + const float* Source, + MLAS_FP16* Destination, + size_t Count +) +{ + if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) { for (size_t i = 0; i < Count; ++i) { - fp32_bits o; - o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits - uint32_t exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust - - // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // renormalize - } - - o.u |= (Source[i] & 0x8000) << 16; // sign bit - Destination[i] = o.f; + Destination[i] = MLAS_FP16(Source[i]); } - } else { // If the kernel is available, use it to perform the conversion. - GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count); + GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast(Destination), Count); } } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6f5db766b7def..8e8f46b8a102e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,13 +610,19 @@ void size_t N ); -typedef +typedef void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( const unsigned short* Source, float* Destination, size_t Count ); +typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( + const float* Source, + unsigned short* Destination, + size_t Count +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -880,6 +886,8 @@ extern "C" { #if defined(MLAS_TARGET_AMD64) MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif } @@ -1165,6 +1173,7 @@ struct MLAS_PLATFORM { const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; + MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4cd7faaa9e6ff..2b4d99800c546 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -245,6 +245,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; this->CastF16ToF32Kernel = nullptr; + this->CastF32ToF16Kernel = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) @@ -387,6 +388,9 @@ Return Value: this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + // // Check if the processor supports Hybrid core architecture. 
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 55d86bb9cc18e..baaa4ba1a3b1f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,6 +29,51 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +void +MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) +{ + size_t i = 0; + + // Process 16 elements at a time using AVX2 + for (; i + 15 < size; i += 16) { + // Load 16 FP16 values into an AVX2 register + __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast(src_fp16 + i)); + + // Convert FP16 values to FP32 + __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values)); + __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1)); + + // Store the converted FP32 values into the output vector + _mm256_storeu_ps(dst_fp32 + i, fp32_values1); + _mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2); + } + + // Process any remaining elements + const MLAS_FP16* fp16 = reinterpret_cast(src_fp16); + for (; i < size; ++i) { + dst_fp32[i] = fp16[i].ToFloat(); + } +} + +void +MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size) +{ + size_t i = 0; + + // Process 8 elements at a time using AVX2 + for (; i + 8 <= size; i += 8) { + __m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]); + __m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk); + } + + // Process any remaining elements + for (; i < size; ++i) { + MLAS_FP16 fp16(src_fp32[i]); + dst_fp16[i] = fp16.val; + } +} + MLAS_FORCEINLINE __m256 load_float_n_avx2(const float* data, int n) diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S index 1a70061460e50..a4d730fa513ab 100644 --- a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S @@ -51,8 +51,6 @@ FUNCTION_ENTRY MlasCastF16ToF32KernelAvx test rdx, rdx // Check if we have any elements to convert jz ExitRoutine - -AVX_NE_CONVERT: cmp rdx, 8 jb ConvertMaskedVectors cmp rdx, 16 @@ -75,6 +73,8 @@ Convert256Vectors: jz ExitRoutine // If we are done, exit cmp rdx, 16 // If the vector is big enough, we go again jae Convert256Vectors + cmp rdx, 8 // Check if we have enough elements to convert + jb ConvertMaskedVectors diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 0530ab771e0be..997d99441d36d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
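// --- Illustrative note: scalar model of the width dispatch performed by the
// MlasCastF16ToF32KernelAvx assembly above (sketch only; the real kernels use
// AVX/F16C vector instructions and a masked tail path). The added
// "cmp ..., 8 / jb ConvertMaskedVectors" instructions after the 16-wide loop
// correspond to the `count < 8` branch below: once fewer than 16 elements
// remain, remainders of 1..7 are routed to the masked path instead of falling
// through to the full 8-wide conversion.
static void CastF16ToF32DispatchModel(const unsigned short* src, float* dst, size_t count) {
  while (count >= 16) {  // main loop: two 8-wide conversions per iteration
    /* convert 16 elements */ src += 16; dst += 16; count -= 16;
  }
  if (count >= 8) {      // one full 8-wide conversion for remainders of 8..15
    /* convert 8 elements */ src += 8; dst += 8; count -= 8;
  }
  if (count > 0) {       // masked tail for the remaining 1..7 elements
    /* convert the tail with a masked load/store */
  }
}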
#include "core/optimizer/graph_transformer_utils.h" @@ -196,6 +197,8 @@ InlinedVector> GenerateTransformers( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; #ifndef DISABLE_CONTRIB_OPS const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; + const InlinedHashSet cpu_acl_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; AllocatorPtr cpu_allocator = std::make_shared(); @@ -285,6 +288,11 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_acl_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kDmlExecutionProvider}; const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kAclExecutionProvider, @@ -296,8 +304,9 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider}; - const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kDmlExecutionProvider, + onnxruntime::kAclExecutionProvider}; const int64_t qdq_matmulnbits_accuracy_level = ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, @@ -323,26 +332,26 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_dml_eps)); - transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_dml_acl_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_eps)); transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); // GeluApproximation has side effects which 
may change results. It needs to be manually enabled, // or alternatively the model can be updated offline using a model conversion script @@ -367,7 +376,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); #endif - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(dml_ep)); #ifdef MLAS_TARGET_AMD64_IX86 diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index e67557dcf9391..ee79fa620374e 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -183,7 +184,8 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, modified = false; for (std::unique_ptr& node : api_graph->Nodes()) { // If the node is not supported in the CPU EP, skip it - if (node->GetExecutionProviderType() != kCpuExecutionProvider) { + const auto ep = node->GetExecutionProviderType(); + if ((ep != kCpuExecutionProvider) && (ep != kAclExecutionProvider)) { continue; } diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index 3391e20cf0bb7..25afed52403c4 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -31,15 +31,15 @@ bool VerifyNotCastChild(const Node& child_node) { return false; } - // This pass currently assumed that this attribute already exists on the child node - if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) { - return false; - } - return true; } void UpdatePaddingAttribute(Node& child_node, const std::vector& pads_values, const uint32_t pads_size) { + if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) { + std::vector pads(pads_size - 4, 0); + child_node.AddAttribute("pads", pads); + } + auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints(); uint32_t child_pads_size = static_cast(child_pads->size()); @@ -162,4 +162,4 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef rule_effect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index adfa680878945..1c506bafd1d14 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -381,9 +382,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool, p_buffered_tensors), apply_context, - // this transformer is compatible with CPU, DML and CUDA EP. 
+ // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. - {kCpuExecutionProvider, kDmlExecutionProvider, kCudaExecutionProvider}} { + {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index df81367c5bbee..5d689a9d933e8 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -78,7 +78,7 @@ static std::unique_ptr MakeNode1Attr(api::GraphRef& graph, std::st std::string_view input, std::string_view attr_name, const std::vector& attr_val) { std::vector inputs{input}; - std::unique_ptr node = graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + std::unique_ptr node = graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); node->SetAttributeInts(attr_name, attr_val); return node; } @@ -102,7 +102,7 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: std::vector inputs{input, axes_initializer}; - return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + return graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); } ///

@@ -136,7 +136,7 @@ static std::unique_ptr MakeQuantizeOp(api::GraphRef& graph, std::s std::optional block_size, std::optional output_dtype, std::optional saturate) { - std::unique_ptr node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("QuantizeLinear", "QuantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -170,7 +170,7 @@ static std::unique_ptr MakeDequantizeOp(api::GraphRef& graph, std: std::vector inputs, std::optional axis, std::optional block_size) { - std::unique_ptr node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("DequantizeLinear", "DequantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -1724,7 +1724,7 @@ static bool HandleShape(HandlerArgs& args) { // X -> Shape -> Y, Gather std::vector gather_inputs{"", perm_const}; - auto gather_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; gather.SetAttributeInt("axis", 0); @@ -1767,7 +1767,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con // inputs that would never be quantized. std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); std::vector gather_inputs{input_name, gather_indices_const}; - auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0]; graph.CopyValueInfo(input_name, gather_output); @@ -2215,7 +2215,7 @@ static bool HandleTile(HandlerArgs& args) { // Case 2: Repeats is computed. Insert Gather node. std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv); std::vector gather_inputs{repeats_inp, perm_inv_const}; - auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_node_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather_node = *gather_node_ptr; std::string_view gather_output = gather_node.Outputs()[0]; args.ctx.graph.CopyValueInfo(repeats_inp, gather_output); @@ -2265,7 +2265,7 @@ static void RemoveCancelingTransposeNodes(HandlerArgs& args) { // Worst-case scenario: Both parent output and 2nd transpose/reshape output cannot be removed (both graph outputs) // despite computing the same value. Use an Identity op instead. std::vector single_empty_input{""}; - auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1); + auto identity_ptr = args.ctx.graph.AddNode("Identity", "Identity", single_empty_input, /*num_outputs*/ 1); api::NodeRef& identity = *identity_ptr; args.ctx.graph.MoveOutput(args.node, 0, identity, 0); identity.SetInput(0, transpose_input); @@ -2297,7 +2297,7 @@ static bool HandleTransposeImpl(HandlerArgs& args, const std::vector& n // replace Reshape with Transpose to simplify the logic. // use the same input as the 1st Transpose, move the output from the Reshape to the new Transpose node, // and remove the Reshape node. 
- new_node = args.ctx.graph.AddNode("Transpose", {args.transpose.Inputs()[0]}, 1); + new_node = args.ctx.graph.AddNode("Transpose", "Transpose", {args.transpose.Inputs()[0]}, 1); args.ctx.graph.MoveOutput(args.node, 0, *new_node, 0); args.ctx.graph.RemoveNode(args.node); } else { diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 211734f4bacc8..7122aec45e61a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -146,6 +146,9 @@ class ValueInfoRef { /// class NodeRef { public: + /// Node name + virtual std::string_view Name() const = 0; + /// Op computed by the node virtual std::string_view OpType() const = 0; @@ -361,6 +364,7 @@ class GraphRef { /// generated. Outputs of created node have unspecified shapes/dtypes. They will be populated afterwards using /// CopyValueInfo. /// + /// The new node's name /// The new node's op type /// Inputs for the node. "" for missing optional inputs. /// @@ -368,7 +372,7 @@ class GraphRef { /// /// The new node's domain. Empty string signifies default onnx domain. /// The new node - virtual std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + virtual std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain = /*kOnnxDomain*/ "") = 0; /// diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 33408474f92a6..f87df746234fa 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -80,6 +80,10 @@ class ApiNode final : public api::NodeRef { return node_; } + std::string_view Name() const override { + return node_.Name(); + } + std::string_view OpType() const override { return node_.OpType(); } @@ -134,7 +138,7 @@ class ApiGraph final : public api::GraphRef { std::unique_ptr GetNodeProducingOutput(std::string_view name) const override; void TransposeInitializer(std::string_view name, const std::vector& perm) override; void ReshapeInitializer(std::string_view name, const std::vector& shape) override; - std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs = 1, std::string_view domain = "") override; std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, @@ -621,11 +625,12 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectorSetShape(new_shape); } -static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type, +static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain, int since_version, std::string_view node_ep) { const std::string op_type_str(op_type); - std::string name = graph.GenerateNodeName(op_type_str); + const std::string op_name_str(op_name); + std::string name = graph.GenerateNodeName(op_name_str); std::vector input_args; std::vector output_args; @@ -731,11 +736,11 @@ static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view do return *since_version; } 
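// --- Illustrative note: call-site sketch for the updated api::GraphRef::AddNode,
// which now takes an explicit node name ahead of the op type. The name seeds
// Graph::GenerateNodeName in CreateNodeHelper above, so nodes created or copied
// by the transpose optimizer keep a name derived from the original node rather
// than from the op type alone. The function below is hypothetical and assumes
// the surrounding onnx_transpose_optimization translation unit (api::GraphRef,
// api::NodeRef).
static std::unique_ptr<api::NodeRef> AddNamedTranspose(api::GraphRef& graph,
                                                       std::string_view input,
                                                       const std::vector<int64_t>& perm) {
  std::vector<std::string_view> inputs{input};
  // name and op_type are passed separately; the existing callers in the hunks
  // above simply pass the op type for both.
  std::unique_ptr<api::NodeRef> node = graph.AddNode("Transpose", "Transpose", inputs, /*num_outputs*/ 1);
  node->SetAttributeInts("perm", perm);
  return node;
}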
-std::unique_ptr ApiGraph::AddNode(std::string_view op_type, +std::unique_ptr ApiGraph::AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain) { int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap()); - Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs, + Node& node = CreateNodeHelper(graph_, name, op_type, inputs, num_outputs, domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : ""); return std::make_unique(node, graph_); @@ -744,7 +749,7 @@ std::unique_ptr ApiGraph::AddNode(std::string_view op_type, std::unique_ptr ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain, std::optional since_version) { const int new_node_since_version = since_version.has_value() ? *since_version : source_node.SinceVersion(); - Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(), + Node& node = CreateNodeHelper(graph_, source_node.Name(), op_type, source_node.Inputs(), source_node.Outputs().size(), domain, new_node_since_version, source_node.GetExecutionProviderType()); diff --git a/onnxruntime/core/providers/acl/acl_common.cc b/onnxruntime/core/providers/acl/acl_common.cc index f1ab6682a8259..c8d878a81bd1a 100644 --- a/onnxruntime/core/providers/acl/acl_common.cc +++ b/onnxruntime/core/providers/acl/acl_common.cc @@ -1,5 +1,6 @@ // Copyright(C) 2018 Intel Corporation // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License #ifdef _WIN32 @@ -8,14 +9,45 @@ #include "core/providers/acl/acl_common.h" -#include "arm_compute/runtime/PoolManager.h" -#include "arm_compute/runtime/BlobLifetimeManager.h" - -#undef ACL_1902 - namespace onnxruntime { namespace acl { +void PopulateWorkspace(const arm_compute::experimental::MemoryRequirements& reqs, + Workspace& workspace, arm_compute::MemoryGroup& memory_group, + arm_compute::ITensorPack& run_pack, arm_compute::ITensorPack& prep_pack) { + for (const arm_compute::experimental::MemoryInfo& req : reqs) { + if (req.size == 0) { + continue; + } + + arm_compute::Tensor* aux_tensor; + if (req.lifetime == arm_compute::experimental::MemoryLifetime::Temporary) { + workspace.temporary_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.temporary_tensors.back().get(); + + memory_group.manage(aux_tensor); + } else if (req.lifetime == arm_compute::experimental::MemoryLifetime::Prepare) { + workspace.prepare_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.prepare_tensors.back().get(); + + prep_pack.add_tensor(req.slot, aux_tensor); + } else { + workspace.persistent_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.persistent_tensors.back().get(); + + prep_pack.add_tensor(req.slot, aux_tensor); + } + run_pack.add_tensor(req.slot, aux_tensor); + + const auto aux_info = arm_compute::TensorInfo{arm_compute::TensorShape(req.size), 1, arm_compute::DataType::U8}; + aux_tensor->allocator()->init(aux_info, req.alignment); + } + + for (const std::unique_ptr& tensor : workspace.temporary_tensors) { + tensor->allocator()->allocate(); + } +} + arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned int extDim) { arm_compute::TensorShape shape; unsigned int inDim = tensorShape.NumDimensions(); @@ -36,27 +68,112 @@ arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned return 
shape; } +Status GetArgShape(const NodeArg* tensor, TensorShape& outShape) { + const auto& inShape = tensor->Shape(); + TensorShapeVector shapeVec; + + for (int i = 0; i < inShape->dim_size(); i++) { + const auto& dim = inShape->dim(i); + ORT_RETURN_IF_NOT(dim.has_dim_value(), "ACL does not support unknown tensor shapes: ", tensor->Name()); + shapeVec.push_back(dim.dim_value()); + } + + outShape = TensorShape(shapeVec); + return Status::OK(); +} + void ACLPrintTensorShape(const char* s, arm_compute::Tensor& t) { for (unsigned int i = 0; i < t.info()->tensor_shape().num_dimensions(); i++) LOGS_DEFAULT(VERBOSE) << "ACL " << s << " " << t.info()->tensor_shape()[i]; LOGS_DEFAULT(VERBOSE) << std::endl; } -std::shared_ptr ACLCreateMemoryManager() { - auto lifetime_mgr = std::make_shared(); - auto pool_mgr = std::make_shared(); - auto mm = std::make_shared(lifetime_mgr, pool_mgr); +arm_compute::DataType ACLDataType(const std::string& dtype) { + if (dtype == "tensor(float)") { + return arm_compute::DataType::F32; + } + if (dtype == "tensor(float16)") { + return arm_compute::DataType::F16; + } + if (dtype == "tensor(bfloat16)") { + return arm_compute::DataType::BFLOAT16; + } + if (dtype == "tensor(uint8)") { + return arm_compute::DataType::QASYMM8; + } + if (dtype == "tensor(int8)") { + return arm_compute::DataType::QASYMM8_SIGNED; + } + if (dtype == "tensor(int32)") { + return arm_compute::DataType::S32; + } + ORT_THROW("ACL execution provider does not support data type ", dtype); +} - return mm; +int GetIntScalar(const Tensor* tensor) { + ORT_ENFORCE(tensor->Shape().Size() == 1, "Tensor is not a scalar"); + if (tensor->IsDataType()) { + return *tensor->Data(); + } + if (tensor->IsDataType()) { + return *tensor->Data(); + } + ORT_THROW("Unsupported int type: ", DataTypeImpl::ToString(tensor->DataType())); } -arm_compute::Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size) { -#ifdef ACL_1902 - return allocator->import_memory(memory, size); -#else +Status LoadQuantizationInfo(const OpKernelInfo& info, arm_compute::Tensor* tensor, + const int scaleIdx, const int zpIdx, bool flipZeroPoint) { + const Tensor* scaleTensor = nullptr; + ORT_RETURN_IF_NOT(info.TryGetConstantInput(scaleIdx, &scaleTensor), "Scale must be constant"); + + const Tensor* zeroPointTensor = nullptr; + ORT_RETURN_IF_NOT(info.TryGetConstantInput(zpIdx, &zeroPointTensor), "Zero point must be constant"); + + const float* scale = scaleTensor->Data(); + const int zeroPoint = GetIntScalar(zeroPointTensor); + tensor->info()->set_quantization_info(arm_compute::QuantizationInfo(*scale, flipZeroPoint ? -zeroPoint : zeroPoint)); + + return Status::OK(); +} + +void GetPackingInfo(std::vector>& state, size_t& packedSize, size_t& alignment) { + alignment = 0; + for (auto& tensor : state) { + alignment = std::max(alignment, tensor->allocator()->alignment()); + } + + packedSize = 0; + for (auto& tensor : state) { + const size_t size = tensor->info()->total_size(); + packedSize += ((size - 1) / alignment + 1) * alignment; + } +} + +Status LoadPackedTensors(std::vector>& state, void* packed, + const size_t packedSize, const size_t alignment) { + auto buffSize = packedSize + alignment; + uint8_t* alignedPtr = (uint8_t*)(alignment == 0 ? 
packed : std::align(alignment, packedSize, packed, buffSize)); + + uint8_t* currentPtr = alignedPtr; + for (auto& tensor : state) { + ORT_RETURN_IF_ERROR(ACLImportMemory(tensor->allocator(), currentPtr, 0)); + + const size_t size = tensor->info()->total_size(); + currentPtr += ((size - 1) / alignment + 1) * alignment; + } + + return Status::OK(); +} + +Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size) { ORT_UNUSED_PARAMETER(size); - return allocator->import_memory(memory); -#endif + arm_compute::Status status = allocator->import_memory(memory); + + if (status) { + return Status::OK(); + } else { + return Status(common::ONNXRUNTIME, common::FAIL, status.error_description()); + } } template @@ -71,12 +188,13 @@ void importDataToTensor(arm_compute::Tensor* tensor, const T* data) { arm_compute::execute_window_loop( aclInpuWindow, [&](const arm_compute::Coordinates& co) { - *reinterpret_cast(aclInputIt.ptr()) = data[index]; + *reinterpret_cast(aclInputIt.ptr()) = data[index]; index++; }, aclInputIt); } template void importDataToTensor(arm_compute::Tensor*, const float*); +template void importDataToTensor(arm_compute::Tensor*, const MLFloat16*); template void importDataFromTensor(arm_compute::Tensor* tensor, T* data) { @@ -89,12 +207,13 @@ void importDataFromTensor(arm_compute::Tensor* tensor, T* data) { arm_compute::execute_window_loop( aclInpuWindow, [&](const arm_compute::Coordinates& co) { - data[index] = *reinterpret_cast(aclInputIt.ptr()); + data[index] = *reinterpret_cast(aclInputIt.ptr()); index++; }, aclInputIt); } template void importDataFromTensor(arm_compute::Tensor*, float*); +template void importDataFromTensor(arm_compute::Tensor*, MLFloat16*); } // namespace acl } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_common.h b/onnxruntime/core/providers/acl/acl_common.h index 899736c477165..f2e89de15efd9 100644 --- a/onnxruntime/core/providers/acl/acl_common.h +++ b/onnxruntime/core/providers/acl/acl_common.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
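
For illustration (not part of the patch): the GetPackingInfo/LoadPackedTensors helpers added to acl_common.cc above size and walk a single packed buffer by rounding every persistent tensor up to the largest allocator alignment, using the expression ((size - 1) / alignment + 1) * alignment. A minimal standalone sketch of that arithmetic, assuming a nonzero alignment and hypothetical tensor sizes rather than real ACL tensors:

#include <cstddef>
#include <iostream>
#include <vector>

// Round `size` up to the next multiple of `alignment`, matching the expression
// used by GetPackingInfo/LoadPackedTensors in the patch.
static std::size_t RoundUp(std::size_t size, std::size_t alignment) {
  return ((size - 1) / alignment + 1) * alignment;
}

int main() {
  // Hypothetical byte sizes of three persistent workspace tensors.
  const std::vector<std::size_t> tensor_sizes = {100, 260, 64};
  const std::size_t alignment = 128;  // assume this is the largest allocator alignment

  std::size_t packed_size = 0;
  for (std::size_t size : tensor_sizes) {
    std::cout << "tensor of " << size << " bytes starts at offset " << packed_size << "\n";
    packed_size += RoundUp(size, alignment);
  }
  // The kernel allocates packed_size + alignment bytes so the start of the buffer
  // can itself be aligned with std::align before importing memory into ACL.
  std::cout << "packed buffer: " << packed_size << " bytes (+" << alignment << " slack)\n";
  return 0;
}

With alignment 128, the three hypothetical tensors land at offsets 0, 128 and 512, giving a 640-byte packed region plus the 128-byte slack used by std::align.
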
#pragma once @@ -7,17 +8,37 @@ #include "core/framework/op_kernel.h" // ACL +#include "arm_compute/core/experimental/Types.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" -#include "arm_compute/runtime/MemoryManagerOnDemand.h" namespace onnxruntime { namespace acl { +struct Workspace { + std::vector> temporary_tensors; + std::vector> prepare_tensors; + std::vector> persistent_tensors; +}; + +void PopulateWorkspace(const arm_compute::experimental::MemoryRequirements& reqs, + Workspace& workspace, arm_compute::MemoryGroup& memory_group, + arm_compute::ITensorPack& run_pack, arm_compute::ITensorPack& prep_pack); + arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned int extDim = 0); +Status GetArgShape(const NodeArg* tensor, TensorShape& outShape); void ACLPrintTensorShape(const char*, arm_compute::Tensor& t); -std::shared_ptr ACLCreateMemoryManager(); -arm_compute::Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size); +arm_compute::DataType ACLDataType(const std::string& dtype); + +int GetIntScalar(const Tensor* tensor); +Status LoadQuantizationInfo(const OpKernelInfo& info, arm_compute::Tensor* tensor, + const int scaleIdx, const int zpIdx, bool flipZeroPoint); + +void GetPackingInfo(std::vector>& state, size_t& packedSize, size_t& alignment); +Status LoadPackedTensors(std::vector>& state, void* packed, + const size_t packedSize, const size_t alignment); + +Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size); template void importDataToTensor(arm_compute::Tensor* tensor, const T* data); template diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index d19dc15e17f6d..8d34e36fe7cd6 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
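
For illustration (not part of the patch): the Workspace struct declared in acl_common.h above keeps ACL auxiliary tensors in three buckets according to arm_compute::experimental::MemoryLifetime. Temporary buffers are handed to the memory group and only live while run() executes, Prepare buffers are allocated around prepare() and then freed, and everything else is treated as persistent state that the kernels later pack via PrePack. A simplified, ACL-free sketch of that routing decision; the names below are illustrative stand-ins, not ACL API:

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for arm_compute::experimental::MemoryLifetime.
enum class Lifetime { Temporary, Prepare, Persistent };

struct MiniWorkspace {
  std::vector<std::string> temporary_tensors;   // managed by the memory group, valid only inside run()
  std::vector<std::string> prepare_tensors;     // allocated around prepare(), then freed
  std::vector<std::string> persistent_tensors;  // packed once (PrePack) and kept for the kernel's lifetime
};

// Mirrors the if/else chain in PopulateWorkspace.
void Route(const std::string& name, Lifetime lifetime, MiniWorkspace& ws) {
  switch (lifetime) {
    case Lifetime::Temporary:
      ws.temporary_tensors.push_back(name);
      break;
    case Lifetime::Prepare:
      ws.prepare_tensors.push_back(name);
      break;
    default:
      ws.persistent_tensors.push_back(name);
      break;
  }
}

int main() {
  MiniWorkspace ws;
  Route("scratch", Lifetime::Temporary, ws);          // hypothetical requirement
  Route("pretransposed_rhs", Lifetime::Persistent, ws);  // hypothetical requirement
  std::cout << ws.temporary_tensors.size() << " temporary, "
            << ws.prepare_tensors.size() << " prepare, "
            << ws.persistent_tensors.size() << " persistent\n";
  return 0;
}
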
#include "acl_execution_provider.h" @@ -7,13 +8,19 @@ #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/compute_capability.h" +#include "core/providers/acl/math/matmul.h" +#include "core/providers/acl/nn/conv.h" +#include "core/session/inference_session.h" #include "contrib_ops/cpu/cpu_contrib_kernels.h" #include "acl_fwd.h" +#include "scheduler.h" -namespace onnxruntime { +#include "arm_compute/runtime/Scheduler.h" +#include "arm_compute/runtime/PoolManager.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/Allocator.h" -constexpr const char* ACL = "Acl"; -constexpr const char* ACL_CPU = "AclCpu"; +namespace onnxruntime { namespace acl { @@ -22,7 +29,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 6, Rel class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Gemm); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 1, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 8, 11, float, MaxPool); @@ -39,6 +47,22 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Concat); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, float, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, NhwcConv); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 13, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, int8_t, MatMulIntegerToFloat); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, int8_t, QLinearConv); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, uint8_t, QLinearConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, int8_t, QLinearConv); Status RegisterACLKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -48,7 +72,8 @@ Status RegisterACLKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -67,6 +92,22 @@ Status 
RegisterACLKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -85,10 +126,22 @@ std::shared_ptr GetAclKernelRegistry() { return kernel_registry; } +std::shared_ptr ACLCreateMemoryManager() { + auto lifetime_mgr = std::make_shared(); + auto pool_mgr = std::make_shared(); + auto mm = std::make_shared(lifetime_mgr, pool_mgr); + + return mm; +} + } // namespace acl -ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo&) - : IExecutionProvider{onnxruntime::kAclExecutionProvider} {} +ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kAclExecutionProvider}, + info(info), + memory_manager(onnxruntime::acl::ACLCreateMemoryManager()) { + arm_compute::Scheduler::set(std::make_shared(this)); +} ACLExecutionProvider::~ACLExecutionProvider() {} @@ -97,4 +150,47 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const return kernel_registry; } +std::vector> +ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + std::vector> result; + for (const auto& node : graph.Nodes()) { + if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node); + kernel_create_info != nullptr) { + Status support_status = Status::OK(); + const std::string op_name = kernel_create_info->kernel_def->OpName(); + + if (op_name == "Conv" || op_name == "NhwcConv" || op_name == "QLinearConv") { + support_status = onnxruntime::acl::ValidateConv(node); + } + if (op_name == "MatMul" || op_name == "FusedMatMul" || op_name == "MatMulIntegerToFloat") { + support_status = onnxruntime::acl::ValidateMatMul(node); + } + + if (support_status.IsOK()) { + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node.Index()); + result.push_back(std::make_unique(std::move(sub_graph))); + } else { + LOGS_DEFAULT(WARNING) << "ACL supports operator " << op_name + << ", but not with these parameters. Using fallback for node: " << node.Name() + << " Reason: " << support_status.ErrorMessage(); + } + } + } + + return result; +} + +Status ACLExecutionProvider::OnRunStart(const onnxruntime::RunOptions&) { + arm_compute::Allocator alloc{}; + memory_manager->populate(alloc, 1); + return Status::OK(); +}; + +Status ACLExecutionProvider::OnRunEnd(bool, const onnxruntime::RunOptions&) { + memory_manager->clear(); + return Status::OK(); +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index 126656e0956bb..1c267d8713673 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -1,20 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
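
For illustration (not part of the patch): after this change the provider owns a single MemoryManagerOnDemand shared by every ACL kernel, populated in OnRunStart and cleared in OnRunEnd, while each kernel's temporary tensors are acquired through a MemoryGroup during Compute. A hedged sketch of that lifecycle as a standalone harness, assembled only from the ACL calls the patch itself uses; it requires Arm Compute Library to build, and the exact sequence should be checked against the ACL documentation:

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/Allocator.h"
#include "arm_compute/runtime/BlobLifetimeManager.h"
#include "arm_compute/runtime/IMemoryGroup.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/MemoryManagerOnDemand.h"
#include "arm_compute/runtime/PoolManager.h"
#include "arm_compute/runtime/Tensor.h"

#include <memory>

int main() {
  // 1. Session setup: one memory manager for all kernels (ACLCreateMemoryManager in the patch).
  auto lifetime_mgr = std::make_shared<arm_compute::BlobLifetimeManager>();
  auto pool_mgr = std::make_shared<arm_compute::PoolManager>();
  auto mm = std::make_shared<arm_compute::MemoryManagerOnDemand>(lifetime_mgr, pool_mgr);

  // 2. Kernel setup: temporary tensors are handed to a MemoryGroup tied to that manager,
  //    which is what PopulateWorkspace does for MemoryLifetime::Temporary requirements.
  arm_compute::MemoryGroup memory_group(mm);
  arm_compute::Tensor scratch;
  scratch.allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(1024U), 1, arm_compute::DataType::U8));
  memory_group.manage(&scratch);
  scratch.allocator()->allocate();

  // 3. OnRunStart: back the manager with real memory for this run.
  arm_compute::Allocator alloc{};
  mm->populate(alloc, 1);

  // 4. Compute: temporaries are only valid inside the resource scope, as in MatMul/Conv::Compute.
  {
    arm_compute::MemoryGroupResourceScope scope_mg(memory_group);
    // layer->run(run_pack) would execute here.
  }

  // 5. OnRunEnd: release the backing memory again.
  mm->clear();
  return 0;
}
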
#pragma once #include "core/framework/execution_provider.h" #include "core/graph/constants.h" +#include "core/platform/threadpool.h" + +#include "arm_compute/runtime/MemoryManagerOnDemand.h" namespace onnxruntime { // Information needed to construct ACL execution providers. struct ACLExecutionProviderInfo { - bool create_arena{true}; + bool enable_fast_math{false}; - explicit ACLExecutionProviderInfo(bool use_arena) - : create_arena(use_arena) {} + explicit ACLExecutionProviderInfo(bool enable_fast_math) + : enable_fast_math(enable_fast_math) {} ACLExecutionProviderInfo() = default; }; @@ -31,6 +35,26 @@ class ACLExecutionProvider : public IExecutionProvider { } std::shared_ptr GetKernelRegistry() const override; + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const override; + + Status OnRunStart(const onnxruntime::RunOptions&) override; + + Status OnRunEnd(bool, const onnxruntime::RunOptions&) override; + + void SetThreadPool(concurrency::ThreadPool* thread_pool) { + thread_pool_ = thread_pool; + } + + concurrency::ThreadPool* GetThreadPool() const { + return thread_pool_; + } + + const ACLExecutionProviderInfo info; + const std::shared_ptr memory_manager; + concurrency::ThreadPool* thread_pool_ = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_provider_factory.cc b/onnxruntime/core/providers/acl/acl_provider_factory.cc index 4eb11b222e576..26a41afeeee36 100755 --- a/onnxruntime/core/providers/acl/acl_provider_factory.cc +++ b/onnxruntime/core/providers/acl/acl_provider_factory.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
#include "core/providers/acl/acl_provider_factory.h" @@ -11,27 +12,28 @@ namespace onnxruntime { struct ACLProviderFactory : IExecutionProviderFactory { - ACLProviderFactory(bool create_arena) : create_arena_(create_arena) {} + ACLProviderFactory(bool enable_fast_math) : enable_fast_math_(enable_fast_math) {} ~ACLProviderFactory() override {} std::unique_ptr CreateProvider() override; private: - bool create_arena_; + bool enable_fast_math_; }; std::unique_ptr ACLProviderFactory::CreateProvider() { ACLExecutionProviderInfo info; - info.create_arena = create_arena_; + info.enable_fast_math = enable_fast_math_; return std::make_unique(info); } -std::shared_ptr ACLProviderFactoryCreator::Create(int use_arena) { - return std::make_shared(use_arena != 0); +std::shared_ptr ACLProviderFactoryCreator::Create(bool enable_fast_math) { + return std::make_shared(enable_fast_math); } } // namespace onnxruntime -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) { - options->provider_factories.push_back(onnxruntime::ACLProviderFactoryCreator::Create(use_arena)); +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, + bool enable_fast_math) { + options->provider_factories.push_back(onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math)); return nullptr; } diff --git a/onnxruntime/core/providers/acl/acl_provider_factory_creator.h b/onnxruntime/core/providers/acl/acl_provider_factory_creator.h index 2eee50ee710da..31a596f2d4bbc 100644 --- a/onnxruntime/core/providers/acl/acl_provider_factory_creator.h +++ b/onnxruntime/core/providers/acl/acl_provider_factory_creator.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -10,7 +11,7 @@ namespace onnxruntime { struct ACLProviderFactoryCreator { - static std::shared_ptr Create(int use_arena); + static std::shared_ptr Create(bool enable_fast_math); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/math/gemm.h b/onnxruntime/core/providers/acl/math/gemm.h index f5288d7f231b0..5db2372705184 100644 --- a/onnxruntime/core/providers/acl/math/gemm.h +++ b/onnxruntime/core/providers/acl/math/gemm.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
#pragma once @@ -28,7 +29,6 @@ namespace acl { typedef struct { std::shared_ptr layer; std::shared_ptr a, b, c, d; - std::shared_ptr mm_layer; } ACLNEGEMM; typedef std::map::iterator GEMMLayersIterator; @@ -37,6 +37,9 @@ template class Gemm : public onnxruntime::Gemm { public: Gemm(const OpKernelInfo& info) : onnxruntime::Gemm(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + int64_t temp; ORT_ENFORCE(info.GetAttr("transA", &temp).IsOK()); @@ -49,12 +52,11 @@ class Gemm : public onnxruntime::Gemm { } Status Compute(OpKernelContext* context) const override { -#ifdef ACL_2308 if (this->packed_b_) { // Prepacked RHS not supported, defaulting to cpu execution provider return onnxruntime::Gemm::Compute(context); } -#endif + const auto A = context->Input(0); const auto B = context->Input(1); const auto C = context->Input(2); @@ -96,19 +98,20 @@ class Gemm : public onnxruntime::Gemm { (cShape[1] == 1 && cShape[0] != (long unsigned int)N)) { return onnxruntime::Gemm::Compute(context); } -#ifdef ACL_2308 cShape = arm_compute::TensorShape(N); LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {" << N << "}"; -#else - cShape = arm_compute::TensorShape(1, N); - LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {1," << N << "}"; -#endif } int64_t K = helper.K(); - if (A) LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str(); - if (B) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); - if (C) LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str(); + if (A) { + LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str(); + } + if (B) { + LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); + } + if (C) { + LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str(); + } LOGS_DEFAULT(VERBOSE) << "D " << D->Shape().ToString().c_str(); LOGS_DEFAULT(VERBOSE) << "M " << (int)M << ", N " << (int)N << ", K " << (int)K; LOGS_DEFAULT(VERBOSE) << "Alfa " << alpha_ << ", Beta " << beta_; @@ -131,10 +134,8 @@ class Gemm : public onnxruntime::Gemm { // dimensions are stored in the opposite order to ACL's tGEMM.d->allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(N, M), arm_compute::Format::F32)); - tGEMM.mm_layer = ACLCreateMemoryManager(); - if (FC) { - auto layer = std::make_shared(tGEMM.mm_layer); + auto layer = std::make_shared(provider_->memory_manager); arm_compute::FullyConnectedLayerInfo fc_info; fc_info.transpose_weights = trans_B_ == CblasTrans; layer->configure(tGEMM.a.get(), tGEMM.b.get(), useC ? 
tGEMM.c.get() : nullptr, tGEMM.d.get(), fc_info); @@ -173,10 +174,7 @@ class Gemm : public onnxruntime::Gemm { ACLPrintTensorShape("c", *pGEMM->c); ACLPrintTensorShape("d", *pGEMM->d); - arm_compute::Allocator alloc_mm{}; - pGEMM->mm_layer->populate(alloc_mm, 1); pGEMM->layer->run(); - pGEMM->mm_layer->clear(); if (D->Shape().Size() != 0 && pGEMM->d->info()->has_padding()) { importDataFromTensor(pGEMM->d.get(), d_data); @@ -195,6 +193,7 @@ class Gemm : public onnxruntime::Gemm { } private: + ACLExecutionProvider* provider_; static thread_local std::map gemmLayers; CBLAS_TRANSPOSE trans_A_; diff --git a/onnxruntime/core/providers/acl/math/matmul.cc b/onnxruntime/core/providers/acl/math/matmul.cc new file mode 100644 index 0000000000000..468b394471c13 --- /dev/null +++ b/onnxruntime/core/providers/acl/math/matmul.cc @@ -0,0 +1,404 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include +#include +#include "core/common/status.h" +#include "core/framework/op_kernel_info.h" +#include "core/framework/op_node_proto_helper.h" +#include "core/framework/tensor_shape.h" +#ifdef _WIN32 +#pragma warning(disable : 4244) +#endif +#include +#include + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math.h" +#include "core/util/math_cpuonly.h" + +#include "core/providers/acl/math/matmul.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_fwd.h" + +// ACL +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/NEON/functions/NEMatMul.h" +#include "src/cpu/operators/CpuGemm.h" +#include "src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h" +#include "src/cpu/operators/CpuMatMul.h" + +namespace onnxruntime { + +namespace acl { + +TensorShape BroadcastInput(const TensorShape& shape, bool prependDim) { + const auto nd = shape.NumDimensions(); + if (nd == 0) { + ORT_THROW("MatMul by scalar not allowed"); + } + + int64_t batchSize = 1; + if (nd == 1) { + if (prependDim) { + return {1, 1, shape[0]}; + } else { + return {1, shape[0], 1}; + } + } + + for (size_t i = 0; i < nd - 2; i++) { + batchSize *= shape[i]; + } + + return {batchSize, shape[nd - 2], shape[nd - 1]}; +} + +struct MatMulConfig { + bool isQuantized; + float alpha; + bool transA; + bool transB; + TensorShape aShapeBroadcast; + TensorShape bShapeBroadcast; +}; + +Status ParseMatMul(const onnxruntime::Node& node, MatMulConfig& config) { + onnxruntime::ProtoHelperNodeContext ctx(node); + onnxruntime::OpNodeProtoHelper attrs(&ctx); + const auto inputDefs = node.InputDefs(); + + config.isQuantized = node.OpType() == "MatMulIntegerToFloat"; + + config.alpha = 1; + attrs.GetAttr("alpha", &config.alpha); + + int64_t transA = 0; + attrs.GetAttr("transA", &transA); + int64_t transB = 0; + attrs.GetAttr("transB", &transB); + + config.transA = transA; + config.transB = transB; + + const int64_t transBatchA = attrs.GetAttrOrDefault("transBatchA", 0); + const int64_t transBatchB = attrs.GetAttrOrDefault("transBatchB", 0); + + ORT_RETURN_IF(transBatchA, "transBatchA not supported by ACL"); + ORT_RETURN_IF(transBatchB, "transBatchB not supported by ACL"); + + ORT_RETURN_IF(config.isQuantized && inputDefs.size() >= 7, "ACL MatMulIntegerToFloat does not support bias"); + + TensorShape aShapeIn; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[0], aShapeIn)); + + TensorShape bShapeIn; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[1], bShapeIn)); + + config.aShapeBroadcast = BroadcastInput(aShapeIn, !config.transA); + 
config.bShapeBroadcast = BroadcastInput(bShapeIn, config.transB); + + ORT_RETURN_IF(!(config.bShapeBroadcast[0] == 1 || (config.aShapeBroadcast[0] == config.bShapeBroadcast[0])), + "ACL does not support broadcasting"); + + ORT_RETURN_IF(config.alpha != 1 && config.bShapeBroadcast[0] > 1, + "ACL does not support alpha scaling with batched B"); + + return Status::OK(); +} + +Status ValidateMatMul(const onnxruntime::Node& node) { + MatMulConfig config; + return ParseMatMul(node, config); +} + +MatMul::MatMul(const OpKernelInfo& info) : onnxruntime::OpKernel(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + + const auto inputDefs = OpKernel::Node().InputDefs(); + const auto outputDefs = OpKernel::Node().OutputDefs(); + + const Tensor* tmp = nullptr; + const bool aIsConst = info.TryGetConstantInput(0, &tmp); + const bool bIsConst = info.TryGetConstantInput(1, &tmp); + + MatMulConfig config; + ORT_THROW_IF_ERROR(ParseMatMul(OpKernel::Node(), config)); + + ORT_THROW_IF_ERROR(GetArgShape(outputDefs[0], outShape)); + if (outShape.Size() == 0) { + return; + } + + const TensorShape aShape{ + config.aShapeBroadcast[0], + config.aShapeBroadcast[config.transA ? 2 : 1], + config.aShapeBroadcast[config.transA ? 1 : 2]}; + + const TensorShape bShape{ + config.bShapeBroadcast[0], + config.bShapeBroadcast[config.transB ? 2 : 1], + config.bShapeBroadcast[config.transB ? 1 : 2]}; + + const TensorShape outShapeBroadcast{aShape[0], aShape[1], bShape[2]}; + + ORT_ENFORCE(outShape.Size() == outShapeBroadcast.Size(), "Output sizes do not match"); + + arm_compute::DataType aType = ACLDataType(*inputDefs[0]->Type()); + arm_compute::DataType bType = ACLDataType(*inputDefs[1]->Type()); + arm_compute::DataType outType = ACLDataType(*outputDefs[0]->Type()); + + arm_compute::GEMMInfo gemmInfo(false, false, bIsConst); + gemmInfo.set_fast_math(provider_->info.enable_fast_math); + + a = std::make_shared(); + b = std::make_shared(); + out = std::make_shared(); + + a->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(config.aShapeBroadcast), 1, aType)); + b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(config.bShapeBroadcast), 1, bType)); + out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(outShapeBroadcast), 1, outType)); + + if (config.isQuantized) { + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, a.get(), 2, 4, true)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, b.get(), 3, 5, true)); + } + + arm_compute::ITensor* a_to_use = a.get(); + if (config.transA) { + a_transposed = std::make_shared(); + a_transposed->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(aShape), 1, aType)); + a_to_use = a_transposed.get(); + + a_permute = std::make_shared(); + a_permute->configure(a.get(), a_transposed.get(), {1, 0, 2}); + } + + arm_compute::ITensor* b_to_use = b.get(); + if (config.transB) { + if (bIsConst) { + workspace.persistent_tensors.emplace_back(std::make_unique()); + b_transposed = workspace.persistent_tensors.back().get(); + } else { + workspace.temporary_tensors.emplace_back(std::make_unique()); + b_transposed = workspace.temporary_tensors.back().get(); + } + + b_transposed->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(bShape), 1, bType), 128); + b_to_use = b_transposed; + + b_permute = std::make_shared(); + b_permute->configure(b.get(), b_transposed, {1, 0, 2}); + } + + a_to_use->info()->set_are_values_constant(aIsConst); + b_to_use->info()->set_are_values_constant(bIsConst); + + if (config.bShapeBroadcast[0] > 1) { + 
arm_compute::CpuMatMulSettings settings; + settings.fast_math(provider_->info.enable_fast_math); + + a_to_use->info()->set_are_values_constant(false); + b_to_use->info()->set_are_values_constant(false); + + const auto matmul = std::make_shared(); + matmul->configure(a_to_use->info(), b_to_use->info(), out->info(), {}, settings, {}); + layer = std::move(matmul); + } else if (config.isQuantized) { + const auto gemm = std::make_shared(); + gemm->configure(a_to_use->info(), b_to_use->info(), nullptr, out->info(), gemmInfo); + layer = std::move(gemm); + } else { + const auto gemm = std::make_shared(); + gemm->configure(a_to_use->info(), b_to_use->info(), nullptr, out->info(), config.alpha, 0.f, gemmInfo); + layer = std::move(gemm); + } + + memory_group = arm_compute::MemoryGroup(provider_->memory_manager); + run_pack = {{arm_compute::ACL_SRC_0, a_to_use}, {arm_compute::ACL_SRC_1, b_to_use}, {arm_compute::ACL_DST, out.get()}}; + prep_pack = {{arm_compute::ACL_SRC_1, b_to_use}}; + + PopulateWorkspace(layer->workspace(), workspace, memory_group, run_pack, prep_pack); +} + +Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (input_idx != 1 || outShape.Size() == 0) { + return Status::OK(); + } + + const uint8_t* data = (uint8_t*)tensor.DataRaw(); + + ORT_RETURN_IF_ERROR(ACLImportMemory(b->allocator(), (void*)data, 0)); + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + auto buffSize = packedSize + alignment; + + pbRaw = IAllocator::MakeUniquePtr(alloc, buffSize, true); + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, pbRaw.get(), packedSize, alignment)); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(pbRaw)); + prepacked_weights->buffer_sizes_.push_back(buffSize); + } + + is_packed = true; + } + + if (b_transposed) { + b_permute->run(); + } + + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->allocate(); + } + + layer->prepare(prep_pack); + + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->free(); + } + + return Status::OK(); +} + +Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (input_idx != 1) { + return Status::OK(); + } + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, prepacked_buffers[0].get(), packedSize, alignment)); + + used_shared_buffers = true; + } + + return Status::OK(); +} + +Status MatMul::Compute(OpKernelContext* context) const { + provider_->SetThreadPool(context->GetOperatorThreadPool()); + + const Tensor* A = context->Input(0); + const Tensor* B = pbRaw ? nullptr : context->Input(1); + + Tensor* outOrt = context->Output(0, outShape); + + if (outShape.Size() == 0) { + return Status::OK(); + } + + const void* a_data = A->DataRaw(); + const void* b_data = B == nullptr ? 
nullptr : B->DataRaw(); + void* out_data = outOrt->MutableDataRaw(); + + ORT_RETURN_IF(A->Shape().Size() != 0 && a->info()->has_padding(), "Padded ACL input tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(a->allocator(), (void*)a_data, 0)); + + if (b_data != nullptr) { + ORT_RETURN_IF_ERROR(ACLImportMemory(b->allocator(), (void*)b_data, 0)); + } + + ORT_RETURN_IF(outOrt->Shape().Size() != 0 && out->info()->has_padding(), "Padded ACL output tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(out->allocator(), (void*)out_data, 0)); + + ORT_RETURN_IF(B != nullptr && workspace.persistent_tensors.size(), "Persistent state requires pre-packing"); + + if (a_transposed) { + a_transposed->allocator()->allocate(); + a_permute->run(); + } + + { + arm_compute::MemoryGroupResourceScope scope_mg(const_cast(memory_group)); + if (b_transposed && B) { + b_permute->run(); + } + + layer->run(const_cast(run_pack)); + } + + a->allocator()->free(); + if (B != nullptr) + b->allocator()->free(); + out->allocator()->free(); + + if (a_transposed) { + a_transposed->allocator()->free(); + } + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMul, + kOnnxDomain, + 13, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMul, + kOnnxDomain, + 13, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_KERNEL_EX( + FusedMatMul, + kMSDomain, + 1, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + FusedMatMul, + kMSDomain, + 1, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulIntegerToFloat, + kMSDomain, + 1, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulIntegerToFloat, + kMSDomain, + 1, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMul); + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/math/matmul.h b/onnxruntime/core/providers/acl/math/matmul.h new file mode 100644 index 0000000000000..b137e33833de9 --- /dev/null +++ b/onnxruntime/core/providers/acl/math/matmul.h @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once +#include "core/framework/op_kernel.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_execution_provider.h" + +// ACL +#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/runtime/TensorAllocator.h" +#include "arm_compute/runtime/Allocator.h" +#include "arm_compute/runtime/PoolManager.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/MemoryManagerOnDemand.h" + +// NEON +#include "arm_compute/runtime/NEON/functions/NEGEMM.h" 
+#include "arm_compute/runtime/NEON/functions/NEPermute.h" + +namespace onnxruntime { +namespace acl { + +Status ValidateMatMul(const onnxruntime::Node& node); + +class MatMul : public OpKernel { + public: + explicit MatMul(const OpKernelInfo& info); + + Status PrePack(const Tensor&, int, AllocatorPtr, + bool& is_packed, PrePackedWeights*) override; + + Status UseSharedPrePackedBuffers(std::vector&, + int, bool&) override; + + Status Compute(OpKernelContext* context) const override; + + protected: + ACLExecutionProvider* provider_; + std::shared_ptr a_permute; + std::shared_ptr b_permute; + std::shared_ptr layer; + + arm_compute::MemoryGroup memory_group; + arm_compute::ITensorPack run_pack; + arm_compute::ITensorPack prep_pack; + + Workspace workspace; + + std::shared_ptr a; + std::shared_ptr b; + std::shared_ptr a_transposed; + arm_compute::Tensor* b_transposed = nullptr; + std::shared_ptr out; + arm_compute::Tensor* pb; + + IAllocatorUniquePtr pbRaw; + TensorShape outShape; +}; +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.cc b/onnxruntime/core/providers/acl/nn/batch_norm.cc index be0e57c5c0543..192bc34556eef 100755 --- a/onnxruntime/core/providers/acl/nn/batch_norm.cc +++ b/onnxruntime/core/providers/acl/nn/batch_norm.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/common/common.h" @@ -80,7 +81,6 @@ Status BatchNorm::Compute(OpKernelContext* context) const { auto layer = std::make_shared(); -#ifdef ACL_2308 arm_compute::TensorShape in_x_shape; const TensorShape& x_shape = X->Shape(); const auto& dims_vec = x_shape.GetDims(); @@ -94,9 +94,6 @@ Status BatchNorm::Compute(OpKernelContext* context) const { in_x_shape.set(2, onnxruntime::narrow(dims_vec[1])); // C tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(in_x_shape, arm_compute::Format::F32)); -#else - tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32)); -#endif tbatch_norm.out->allocator()->init(arm_compute::TensorInfo(tbatch_norm.in->info()->tensor_shape(), arm_compute::Format::F32)); tbatch_norm.scale->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(S->Shape()), arm_compute::Format::F32)); diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index 85bd0cfe96279..a62158f1c26ee 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
#ifdef _WIN32 @@ -19,31 +20,81 @@ // ACL #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/ITensorPack.h" +#include "src/cpu/operators/CpuConv2d.h" // NEON -#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h" -#ifdef ACL_1902 -#include "arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h" -#endif -#if defined(ACL_1905) || defined(ACL_1908) -#include "arm_compute/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.h" -#endif - #define CONV_ACL #undef DEPTHWISE_CPU #define PREF_DIM 4 namespace onnxruntime { + namespace acl { -template -thread_local std::map Conv::convLayers; +struct ConvConfig { + bool isQuantized; + bool is_channels_last; + bool isDepthwise; + TensorShape inShapeIn; + TensorShape kShapeIn; + const std::string* inType; + const std::string* kType; +}; + +Status ParseConv(const onnxruntime::Node& node, ConvConfig& config) { + onnxruntime::ProtoHelperNodeContext ctx(node); + onnxruntime::OpNodeProtoHelper attrs(&ctx); + const auto inputDefs = node.InputDefs(); + + config.isQuantized = node.OpType() == "QLinearConv"; + + if (config.isQuantized) { + TensorShape scaleShape; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[4], scaleShape)); + ORT_RETURN_IF(scaleShape.Size() > 1, "ACL execution provider does not support per-channel quantization"); + } + + config.is_channels_last = node.OpType() == "NhwcConv"; + if (!config.is_channels_last) { + int64_t cl_ret = 0; + attrs.GetAttr("channels_last", &cl_ret); + config.is_channels_last = (bool)cl_ret; + } -template -arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const { + int64_t group = 1; + attrs.GetAttr("group", &group); + + const NodeArg* kDef = inputDefs[config.isQuantized ? 3 : 1]; + + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[0], config.inShapeIn)); + ORT_RETURN_IF_ERROR(GetArgShape(kDef, config.kShapeIn)); + + ORT_RETURN_IF(config.kShapeIn.NumDimensions() > 4, "ACL execution provider supports 1D and 2D Conv only"); + + config.inType = inputDefs[0]->Type(); + config.kType = kDef->Type(); + const bool mixedType = config.inType != config.kType; + + config.isDepthwise = group > 1; + if (config.isDepthwise) { + const size_t channels = config.inShapeIn[config.is_channels_last ? 
config.inShapeIn.NumDimensions() - 1 : 1]; + ORT_RETURN_IF(group != channels, "ACL does not support grouping unless group == channels"); + ORT_RETURN_IF(mixedType, "ACL does not support mixed input types for depthwise Conv"); + } + + return Status::OK(); +} + +Status ValidateConv(const onnxruntime::Node& node) { + ConvConfig config; + return ParseConv(node, config); +} + +arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const { arm_compute::TensorShape shape = arm_compute::TensorShape(kernel->info()->tensor_shape()); shape[2] = shape[2] * shape[3]; shape[3] = 1; @@ -51,43 +102,89 @@ arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor return shape; } -#ifdef CONV_ACL -template -Status Conv::Compute(OpKernelContext* context) const { +Conv::Conv(const OpKernelInfo& info) : onnxruntime::OpKernel(info), conv_attrs_(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + + ConvConfig config; + ORT_THROW_IF_ERROR(ParseConv(OpKernel::Node(), config)); + isQuantized = config.isQuantized; + is_channels_last = config.is_channels_last; + size_t num_inputs = OpKernel::Node().InputDefs().size(); + has_bias = isQuantized ? (num_inputs == 9) : (num_inputs == 3); - ACLNEConv* pConv; - ConvLayersIterator it = Conv::convLayers.find((OpKernel*)this); - if (it != Conv::convLayers.end()) { - pConv = &it->second; - if (pConv->isDepthwiseCPU == true) { - Status s = onnxruntime::Conv::Compute(context); - return s; - } + const Tensor* tmp = nullptr; + const bool kIsConst = info.TryGetConstantInput(1, &tmp); + ORT_ENFORCE(kIsConst, "ACL does not support Conv with mutable weights"); + + in = std::make_shared(); + k = std::make_shared(); + if (has_bias) + b = std::make_shared(); + out = std::make_shared(); + + const arm_compute::DataLayout data_layout = is_channels_last ? arm_compute::DataLayout::NHWC : arm_compute::DataLayout::NCHW; + + TensorShape inShape = config.inShapeIn; + if (is_channels_last && config.inShapeIn.NumDimensions() < 4) { + inShape = TensorShape({config.inShapeIn[0], config.inShapeIn[1], 1, config.inShapeIn[2]}); } - const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; + arm_compute::DataType inType = ACLDataType(*config.inType); + in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(inShape, PREF_DIM), 1, inType, data_layout)); + + arm_compute::DataType kType = ACLDataType(*config.kType); + + TensorShapeVector kShapeVec = config.kShapeIn.AsShapeVector(); + while (kShapeVec.size() < 4) { + kShapeVec.push_back(1); + } + + const TensorShape kShape = is_channels_last ? TensorShape({kShapeVec[0], kShapeVec[2], kShapeVec[3], kShapeVec[1]}) : TensorShape(kShapeVec); + + k->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(kShape), 1, kType, data_layout)); + + TensorShape bShape; + if (has_bias) { + const Tensor* bias = nullptr; + const bool biasIsConst = info.TryGetConstantInput(isQuantized ? 8 : 2, &bias); + ORT_ENFORCE(biasIsConst, "ACL does not support Conv with mutable bias"); + + const auto bDef = OpKernel::Node().InputDefs()[isQuantized ? 
8 : 2]; + ORT_THROW_IF_ERROR(GetArgShape(bDef, bShape)); + arm_compute::DataType bType = ACLDataType(*bDef->Type()); + b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(bShape), 1, bType, data_layout)); + + const void* b_data = bias->DataRaw(); + ORT_THROW_IF_ERROR(ACLImportMemory(b->allocator(), (void*)b_data, 0)); + } + + ORT_THROW_IF_ERROR(GetArgShape(OpKernel::Node().OutputDefs()[0], outShape)); + TensorShape outShapeACL = outShape; + if (is_channels_last && outShape.NumDimensions() < 4) { + outShapeACL = TensorShape({outShape[0], outShape[1], 1, outShape[2]}); + } + + out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(outShapeACL, PREF_DIM), 1, inType, data_layout)); - const int64_t N = X->Shape()[0]; - const int64_t M = W->Shape()[0]; + if (isQuantized) { + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, in.get(), 1, 2, false)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, k.get(), 4, 5, false)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, out.get(), 6, 7, false)); + } LOGS_DEFAULT(VERBOSE) << "Conv ACL:"; - LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str(); - LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str(); - if (B != nullptr) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); - - if (X->Shape().NumDimensions() != PREF_DIM) { - LOGS_DEFAULT(WARNING) << "ACL does not have support for tensors with 4 or more dimensions; defaulting to cpu implementation"; - Status s = onnxruntime::Conv::Compute(context); - return s; + LOGS_DEFAULT(VERBOSE) << "X " << inShape.ToString().c_str(); + LOGS_DEFAULT(VERBOSE) << "W " << config.kShapeIn.ToString().c_str(); + if (has_bias) { + LOGS_DEFAULT(VERBOSE) << "B " << bShape.ToString().c_str(); } - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); + ORT_THROW_IF_ERROR(conv_attrs_.ValidateInputShape(config.inShapeIn, config.kShapeIn, config.is_channels_last)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_THROW_IF_ERROR(conv_attrs_.ComputeKernelShape(config.kShapeIn, kernel_shape)); ConvAttributes::ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { @@ -102,16 +199,13 @@ Status Conv::Compute(OpKernelContext* context) const { strides.resize(kernel_shape.size(), 1); } - TensorShapeVector Y_dims; - Y_dims.insert(Y_dims.begin(), {N, M}); - TensorShape input_shape = X->Shape().Slice(2); -#ifdef ACL_2308 - ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); -#else - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); -#endif - Tensor* Y = context->Output(0, TensorShape(Y_dims)); - LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str(); + TensorShape input_shape = config.inShapeIn.Slice(2); + TensorShapeVector out_shape; + ORT_THROW_IF_ERROR(conv_attrs_.InferPadsAndOutputShape( + input_shape, kernel_shape, strides, dilations, + pads, out_shape)); + + LOGS_DEFAULT(VERBOSE) << "Y " << outShape.ToString().c_str(); arm_compute::ActivationLayerInfo::ActivationFunction acl_activ_func; bool acl_activ_enabled = false; @@ -136,243 +230,274 @@ Status Conv::Compute(OpKernelContext* context) const { ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_type); } - if (it == Conv::convLayers.end()) { - auto mm_layer = ACLCreateMemoryManager(); - - ACLNEConv tconv; - tconv.mm_layer = std::move(mm_layer); + const size_t idx_channel = 
arm_compute::get_data_layout_dimension_index(data_layout, arm_compute::DataLayoutDimension::CHANNEL); + isDepthwiseCPU = config.isDepthwise; - tconv.in = std::make_shared(); - tconv.k = std::make_shared(); - if (B != nullptr) - tconv.b = std::make_shared(); - tconv.out = std::make_shared(); + std::vector aclStrides(2); + aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; + aclStrides[1] = strides[0]; - tconv.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape(), PREF_DIM), arm_compute::Format::F32)); - tconv.k->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - if (B != nullptr) { - tconv.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(B->Shape()), arm_compute::Format::F32)); - } - tconv.out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(Y->Shape(), PREF_DIM), arm_compute::Format::F32)); - - const arm_compute::DataLayout data_layout = tconv.in->info()->data_layout(); - const int idx_channel = arm_compute::get_data_layout_dimension_index(data_layout, arm_compute::DataLayoutDimension::CHANNEL); - bool isDepthwise = (conv_attrs_.group > 1 && conv_attrs_.group == tconv.in->info()->tensor_shape()[idx_channel]); - tconv.isDepthwiseCPU = isDepthwise; - - std::vector aclStrides(2); - aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; - aclStrides[1] = strides[0]; - - std::vector aclPads(4); - // The pad order in acl is: pad_left, pad_right, pad_top, pad_bottom - if (pads.size() == 2) { - if (strides.size() == 1) { - aclPads[0] = 0; - aclPads[1] = 0; - aclPads[2] = pads[1]; - aclPads[3] = pads[0]; - } else { - aclPads[0] = pads[1]; - aclPads[1] = pads[0]; - aclPads[2] = pads[1]; - aclPads[3] = pads[0]; - } + std::vector aclPads(4); + // The pad order in acl is: pad_left, pad_right, pad_top, pad_bottom + if (pads.size() == 2) { + if (strides.size() == 1) { + aclPads[0] = 0; + aclPads[1] = 0; + aclPads[2] = pads[0]; + aclPads[3] = pads[1]; } else { aclPads[0] = pads[1]; - aclPads[1] = pads[3]; - aclPads[2] = pads[0]; - aclPads[3] = pads[2]; + aclPads[1] = pads[0]; + aclPads[2] = pads[1]; + aclPads[3] = pads[0]; } + } else { + aclPads[0] = pads[1]; + aclPads[1] = pads[3]; + aclPads[2] = pads[0]; + aclPads[3] = pads[2]; + } - arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo(aclStrides[0], aclStrides[1], - aclPads[0], aclPads[1], aclPads[2], aclPads[3], arm_compute::DimensionRoundingType::FLOOR); - unsigned int aclDilation0 = (dilations.size() == 2) ? 
dilations[1] : 1; - - LOGS_DEFAULT(VERBOSE) << "padding: {" << aclPads[0] << "," << aclPads[1] << "," << aclPads[2] << "," << aclPads[3] << "}"; - LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; - - if (isDepthwise) { - LOGS_DEFAULT(VERBOSE) << "Depthwise convolution"; -#ifdef DEPTHWISE_CPU - Status s = onnxruntime::Conv::Compute(context); - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - return s; -#else - tconv.k->info()->set_tensor_shape(ACLReshapeWeightsDepthwise(tconv.k.get())); - - // in the configure function for NEDepthwiseConvolutionLayer3x3, there is a separation based on the optimization -#ifdef ACL_1902 - bool optimizable = - arm_compute::NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(tconv.in->info()->tensor_shape(), - aclPadStride, - tconv.in->info()->data_type(), - 1 /* depth multiplier */, - tconv.in->info()->data_layout()); -#elif defined(ACL_1905) || defined(ACL_1908) - bool optimizable = - arm_compute::NEDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(tconv.in->info(), - tconv.k->info(), - aclPadStride, - 1 /* depth multiplier */, - arm_compute::Size2D(aclDilation0, dilations[0])); -#elif defined(ACL_2002) - bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(tconv.in->info(), - tconv.k->info(), - (B != nullptr) ? tconv.b->info() : nullptr, - tconv.out->info(), - aclPadStride, - 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0]))); -#elif defined(ACL_2308) - bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayer::validate(tconv.in->info(), - tconv.k->info(), - (B != nullptr) ? tconv.b->info() : nullptr, - tconv.out->info(), - aclPadStride, - 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0]))); -#endif + arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo( + (unsigned int)aclStrides[0], (unsigned int)aclStrides[1], + (unsigned int)aclPads[0], (unsigned int)aclPads[1], + (unsigned int)aclPads[2], (unsigned int)aclPads[3], arm_compute::DimensionRoundingType::FLOOR); + size_t aclDilation0 = (dilations.size() == 2) ? dilations[1] : 1; + + LOGS_DEFAULT(VERBOSE) << "padding: {" << aclPads[0] << "," << aclPads[1] << "," << aclPads[2] << "," << aclPads[3] << "}"; + LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; + + if (config.isDepthwise) { + LOGS_DEFAULT(VERBOSE) << "Depthwise convolution"; + k->info()->set_tensor_shape(ACLReshapeWeightsDepthwise(k.get())); + auto dl = std::make_shared(); + dl->configure(in.get(), k.get(), (has_bias) ? b.get() : nullptr, out.get(), + aclPadStride, 1 /* depth multiplier */, + acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + arm_compute::Size2D(aclDilation0, dilations[0])); + depthwise_layer = std::move(dl); + isDepthwiseCPU = false; + } else { + LOGS_DEFAULT(VERBOSE) << "ACL 2D convolution"; + auto cl = std::make_shared(); + cl->configure(in->info(), k->info(), (has_bias) ? b->info() : nullptr, out->info(), + aclPadStride, + arm_compute::WeightsInfo(), arm_compute::Size2D(aclDilation0, dilations[0]), + acl_activ_enabled ? 
arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + provider_->info.enable_fast_math, (unsigned int)conv_attrs_.group); + conv_layer = std::move(cl); + + memory_group = arm_compute::MemoryGroup(provider_->memory_manager); + run_pack = {{arm_compute::ACL_SRC_0, in.get()}, {arm_compute::ACL_SRC_1, k.get()}, {arm_compute::ACL_SRC_2, b.get()}, {arm_compute::ACL_DST, out.get()}}; + prep_pack = {{arm_compute::ACL_SRC_1, k.get()}, {arm_compute::ACL_SRC_2, b.get()}}; + + PopulateWorkspace(conv_layer->workspace(), workspace, memory_group, run_pack, prep_pack); + } - if (optimizable) { - LOGS_DEFAULT(VERBOSE) << "ACL optimized depthwise convolution"; -#if defined(ACL_1902) || defined(ACL_1905) - auto layer = std::make_shared(); -#elif defined(ACL_1908) - auto layer = std::make_shared(); -#elif defined(ACL_2002) || defined(ACL_2308) - auto layer = std::make_shared(); -#endif + ACLPrintTensorShape("X", *in.get()); + ACLPrintTensorShape("Y", *out.get()); +} -#ifdef ACL_1902 - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo()); -#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) || defined(ACL_2308) - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0])); -#endif - tconv.layer = std::move(layer); - tconv.isDepthwiseCPU = false; - } else { - LOGS_DEFAULT(VERBOSE) << "CPU depthwise convolution"; - Status s = onnxruntime::Conv::Compute(context); - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - return s; - } -#endif // DEPTHWISE_CPU - } else { - if (tconv.k->info()->tensor_shape()[0] == 1 && tconv.k->info()->tensor_shape()[1] == 1) { - LOGS_DEFAULT(VERBOSE) << "CPU pointwise convolution"; - Status s = onnxruntime::Conv::Compute(context); - return s; - } else { - if (tconv.k->info()->tensor_shape()[0] == 9 && tconv.k->info()->tensor_shape()[1] == 9) { - LOGS_DEFAULT(WARNING) << "9x9 DirectConvolution does not have an implementation in NCHW layout; defaulting to cpu implementation"; - Status s = onnxruntime::Conv::Compute(context); - return s; - } - LOGS_DEFAULT(VERBOSE) << "ACL 2D convolution"; - auto layer = std::make_shared(mm_layer); - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, - arm_compute::WeightsInfo(), arm_compute::Size2D(aclDilation0, dilations[0]), - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - false, conv_attrs_.group); - tconv.layer = std::move(layer); - } +#ifdef CONV_ACL +Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (isQuantized ? 
(input_idx != 3) : (input_idx != 1)) { + return Status::OK(); + } + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + auto buffSize = packedSize + alignment; + + pkRaw = IAllocator::MakeUniquePtr(alloc, buffSize, true); + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, pkRaw.get(), packedSize, alignment)); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(pkRaw)); + prepacked_weights->buffer_sizes_.push_back(buffSize); } - tconv.out->info()->set_format(tconv.in->info()->format()); + is_packed = true; + } - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - pConv = &ret.first->second; + bool free_k = false; + const void* k_data = tensor.DataRaw(); + if (is_channels_last) { + TensorShape shape = tensor.Shape(); + if (shape.NumDimensions() < 4) { + shape = TensorShape({shape[0], shape[1], shape[2], 1}); + } - ACLPrintTensorShape("X", *tconv.in.get()); - ACLPrintTensorShape("Y", *tconv.out.get()); + arm_compute::Tensor kIn; + kIn.allocator()->init(arm_compute::TensorInfo(ACLTensorShape(shape), 1, + k->info()->data_type(), arm_compute::DataLayout::NCHW)); + kIn.info()->set_quantization_info(k->info()->quantization_info()); + ORT_RETURN_IF_ERROR(ACLImportMemory(kIn.allocator(), (void*)k_data, 0)); + k->allocator()->allocate(); + free_k = is_packed; + is_packed = true; + + arm_compute::NEPermute perm_layer; + perm_layer.configure(&kIn, k.get(), {2, 0, 1, 3}); + perm_layer.run(); } else { - // TODO: valildate shapes - pConv = &it->second; + ORT_RETURN_IF_ERROR(ACLImportMemory(k->allocator(), (void*)k_data, 0)); } - const T* x_data = X->Data(); - if (X->Shape().Size() != 0 && pConv->in->info()->has_padding()) { - pConv->in->allocator()->allocate(); - importDataToTensor(pConv->in.get(), x_data); - } else { - ACLImportMemory(pConv->in->allocator(), (void*)x_data, X->Shape().Size() * 4); + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->allocate(); } - const T* k_data = W->Data(); - ACLImportMemory(pConv->k->allocator(), (void*)k_data, W->Shape().Size() * 4); + if (conv_layer) { + conv_layer->prepare(prep_pack); + } else { + depthwise_layer->prepare(); + } - if (B != nullptr) { - const T* b_data = B->Data(); - ACLImportMemory(pConv->b->allocator(), (void*)b_data, B->Shape().Size() * 4); + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->free(); } - T* y_data = Y->MutableData(); - if (Y->Shape().Size() != 0 && pConv->out->info()->has_padding()) { - pConv->out->allocator()->allocate(); - } else { - ACLImportMemory(pConv->out->allocator(), (void*)y_data, Y->Shape().Size() * 4); + if (free_k) { + k->allocator()->free(); } - arm_compute::Allocator alloc_mm{}; - pConv->mm_layer->populate(alloc_mm, 1); - pConv->layer->run(); - pConv->mm_layer->clear(); + return Status::OK(); +} - if (Y->Shape().Size() != 0 && pConv->out->info()->has_padding()) { - importDataFromTensor(pConv->out.get(), y_data); +Status Conv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (isQuantized ? 
(input_idx != 3) : (input_idx != 1)) { + return Status::OK(); } - pConv->in->allocator()->free(); - pConv->k->allocator()->free(); - if (B != nullptr) - pConv->b->allocator()->free(); - pConv->out->allocator()->free(); + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); - LOGS_DEFAULT(VERBOSE) << std::endl; + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, prepacked_buffers[0].get(), + packedSize, alignment)); + + used_shared_buffers = true; + } return Status::OK(); } -#else -template -Status Conv::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); + +Status Conv::Compute(OpKernelContext* context) const { + provider_->SetThreadPool(context->GetOperatorThreadPool()); const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; - LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str(); - LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str(); - if (B != nullptr) - LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); + Tensor* Y = context->Output(0, outShape); + + const void* x_data = X->DataRaw(); + ORT_RETURN_IF(X->Shape().Size() != 0 && in->info()->has_padding(), "Padded ACL input tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(in->allocator(), (void*)x_data, 0)); + + void* y_data = Y->MutableDataRaw(); + ORT_RETURN_IF(Y->Shape().Size() != 0 && out->info()->has_padding(), "Padded ACL output tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(out->allocator(), (void*)y_data, 0)); + + if (conv_layer) { + arm_compute::MemoryGroupResourceScope scope_mg(const_cast(memory_group)); + conv_layer->run(const_cast(run_pack)); + } else { + depthwise_layer->run(); + } + + in->allocator()->free(); + k->allocator()->free(); + out->allocator()->free(); LOGS_DEFAULT(VERBOSE) << std::endl; - Status s = onnxruntime::Conv::Compute(context); - return s; + return Status::OK(); } #endif ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, + 11, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_KERNEL_EX( + NhwcConv, + kMSDomain, 1, kAclExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Conv); + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kOnnxDomain, + 10, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kOnnxDomain, + 10, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kMSDomain, + 1, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + 
.TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kMSDomain, + 1, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); } // namespace acl } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index 660d47b4172df..b05ba5363542f 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv.h" +#include "core/providers/acl/acl_common.h" #include "core/providers/acl/acl_execution_provider.h" // ACL -#ifdef ACL_2308 #include "arm_compute/runtime/Tensor.h" -#endif #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" #include "arm_compute/runtime/Allocator.h" #include "arm_compute/runtime/PoolManager.h" @@ -19,45 +21,50 @@ #include "arm_compute/runtime/MemoryManagerOnDemand.h" // NEON -#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h" namespace onnxruntime { namespace acl { -typedef struct -{ - std::shared_ptr layer; - std::shared_ptr mm_layer; - std::shared_ptr in; - std::shared_ptr k; - std::shared_ptr b; - std::shared_ptr out; - bool isDepthwiseCPU; -} ACLNEConv; - -typedef std::map::iterator ConvLayersIterator; +Status ValidateConv(const onnxruntime::Node& node); -template -class Conv : public onnxruntime::Conv { +class Conv : public onnxruntime::OpKernel { public: - explicit Conv(const OpKernelInfo& info) : onnxruntime::Conv(info), conv_attrs_(info) { - provider_ = (const_cast( - static_cast(info.GetExecutionProvider()))); - } + explicit Conv(const OpKernelInfo& info); - ~Conv() { - Conv::convLayers.erase(this); - } + Status PrePack(const Tensor&, int, AllocatorPtr, + bool& is_packed, PrePackedWeights*) override; + + Status UseSharedPrePackedBuffers(std::vector&, + int, bool&) override; Status Compute(OpKernelContext* context) const override; protected: - static thread_local std::map convLayers; ConvAttributes conv_attrs_; ACLExecutionProvider* provider_; std::string activation_type; + std::shared_ptr depthwise_layer; + + std::shared_ptr conv_layer; + arm_compute::MemoryGroup memory_group; + arm_compute::ITensorPack run_pack; + arm_compute::ITensorPack prep_pack; + + Workspace workspace; + + std::shared_ptr in; + std::shared_ptr k; + IAllocatorUniquePtr pkRaw; + std::shared_ptr b; + std::shared_ptr out; + TensorShape outShape; + bool is_channels_last; + bool isQuantized; + bool isDepthwiseCPU; + bool has_bias; + arm_compute::TensorShape ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const; }; } // namespace acl diff --git a/onnxruntime/core/providers/acl/nn/fused_conv.cc 
b/onnxruntime/core/providers/acl/nn/fused_conv.cc index 3cf18394b5c4c..34e50ebdf6921 100644 --- a/onnxruntime/core/providers/acl/nn/fused_conv.cc +++ b/onnxruntime/core/providers/acl/nn/fused_conv.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #ifdef _WIN32 @@ -17,11 +18,13 @@ namespace onnxruntime { namespace acl { -class FusedConv final : public acl::Conv { +class FusedConv final : public acl::Conv { public: - explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { + explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { ORT_ENFORCE(info.GetAttr("activation", &(this->activation_type)).IsOK()); - ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + ORT_ENFORCE(GetFusedActivationAttr(info, activation).IsOK()); } }; diff --git a/onnxruntime/core/providers/acl/nn/pool.cc b/onnxruntime/core/providers/acl/nn/pool.cc index 01d9bc0302c3a..cbbecef6bbfac 100644 --- a/onnxruntime/core/providers/acl/nn/pool.cc +++ b/onnxruntime/core/providers/acl/nn/pool.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -63,12 +64,7 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, if (pool_attrs.global_pooling) { layer->configure(tpool.in.get(), tpool.out.get(), - arm_compute::PoolingLayerInfo(pool_type -#ifdef ACL_2308 - , - arm_compute::DataLayout::NCHW -#endif - )); + arm_compute::PoolingLayerInfo(pool_type, arm_compute::DataLayout::NCHW)); } else { TensorShapeVector aclStrides(2); aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; @@ -95,8 +91,11 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, aclPads[3] = pads[2]; } - arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo(aclStrides[0], aclStrides[1], - aclPads[0], aclPads[1], aclPads[2], aclPads[3], arm_compute::DimensionRoundingType::FLOOR); + arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo( + (unsigned int)aclStrides[0], (unsigned int)aclStrides[1], + (unsigned int)aclPads[0], (unsigned int)aclPads[1], + (unsigned int)aclPads[2], (unsigned int)aclPads[3], + arm_compute::DimensionRoundingType::FLOOR); TensorShapeVector aclKernelShape(2); aclKernelShape[0] = (kernel_shape.size() > 1) ? 
kernel_shape[1] : 1; @@ -113,9 +112,7 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, arm_compute::PoolingLayerInfo pool_info(pool_type, aclSize, -#ifdef ACL_2308 arm_compute::DataLayout::NCHW, -#endif aclPadStride, excludePadding); layer->configure(tpool.in.get(), tpool.out.get(), pool_info); @@ -133,8 +130,8 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, aclInpuWindow.use_tensor_dimensions(tpool.in->info()->tensor_shape()); arm_compute::Iterator aclInputIt(tpool.in.get(), aclInpuWindow); - const unsigned int aclWidth = tpool.in->info()->dimension(0); - const unsigned int aclHeight = tpool.in->info()->dimension(1); + const size_t aclWidth = tpool.in->info()->dimension(0); + const size_t aclHeight = tpool.in->info()->dimension(1); // copy input tensor into the larger buffer arm_compute::execute_window_loop( diff --git a/onnxruntime/core/providers/acl/scheduler.cc b/onnxruntime/core/providers/acl/scheduler.cc new file mode 100644 index 0000000000000..e1bab6adb5a1f --- /dev/null +++ b/onnxruntime/core/providers/acl/scheduler.cc @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "core/common/common.h" +#include "scheduler.h" + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace acl { + +void ORTScheduler::set_num_threads(unsigned int num_threads) { + ORT_THROW("Not supported"); +} + +unsigned int ORTScheduler::num_threads() const { + // We can't check the size of the thread pool during kernel initialization, + // as required by ACL. Therefore we have to choose a fixed thread count and + // let some cores run multiple workloads if there are fewer than 32 cores. + // This doesn't seem to cause performance issues with fewer cores in practice. 
+ return 32; +} + +void ORTScheduler::schedule(arm_compute::ICPPKernel* kernel, const Hints& hints) { + arm_compute::ITensorPack tensors; + schedule_op(kernel, hints, kernel->window(), tensors); +} + +void ORTScheduler::schedule_op(arm_compute::ICPPKernel* kernel, const Hints& hints, + const arm_compute::Window& window, arm_compute::ITensorPack& tensors) { + schedule_common(kernel, hints, window, tensors); +} + +void ORTScheduler::run_workloads(std::vector& workloads) { + ThreadPool::TrySimpleParallelFor(_provider->GetThreadPool(), workloads.size(), + [&](std::ptrdiff_t id) { + const arm_compute::ThreadInfo info{ + (int)id, (int)workloads.size(), &cpu_info()}; + workloads[id](info); + }); +} + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/scheduler.h b/onnxruntime/core/providers/acl/scheduler.h new file mode 100644 index 0000000000000..c66700a48f3d5 --- /dev/null +++ b/onnxruntime/core/providers/acl/scheduler.h @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "acl_execution_provider.h" + +#include "arm_compute/runtime/IScheduler.h" +#include "arm_compute/core/CPP/ICPPKernel.h" + +namespace onnxruntime { +namespace acl { + +class ORTScheduler : public arm_compute::IScheduler { + public: + ORTScheduler(ACLExecutionProvider* provider) : _provider(provider) { + } + + void set_num_threads(unsigned int num_threads) override; + + unsigned int num_threads() const override; + + void schedule(arm_compute::ICPPKernel* kernel, const Hints& hints) override; + + void schedule_op(arm_compute::ICPPKernel* kernel, const Hints& hints, + const arm_compute::Window& window, arm_compute::ITensorPack& tensors) override; + + void run_workloads(std::vector& workloads) override; + + private: + ACLExecutionProvider* _provider = nullptr; +}; + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/tensor/concat.cc b/onnxruntime/core/providers/acl/tensor/concat.cc index 75eedaac80aea..0cf02ab8762b9 100644 --- a/onnxruntime/core/providers/acl/tensor/concat.cc +++ b/onnxruntime/core/providers/acl/tensor/concat.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
#include "core/providers/acl/tensor/concat.h" @@ -76,11 +77,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { LOGS_DEFAULT(VERBOSE) << "Concat ACL:"; arm_compute::Tensor output; -#ifdef ACL_2308 std::vector inputs_vector; -#else - std::vector inputs_vector; -#endif for (int i = 0; i < input_count; i++) { arm_compute::Tensor* input = new arm_compute::Tensor(); auto X = input_tensors[i]; @@ -101,11 +98,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { for (int i = 0; i < input_count; i++) { auto X = input_tensors[i]; const T* x_data = X->Data(); -#ifdef ACL_2308 arm_compute::Tensor* in = const_cast(static_cast(inputs_vector[i])); -#else - arm_compute::Tensor* in = static_cast(inputs_vector[i]); -#endif if (X->Shape().Size() != 0 && in->info()->has_padding()) { in->allocator()->allocate(); diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 4d20061820e71..68460ff7c9b31 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -30,8 +30,8 @@ // to manually do this asm(".linker_option \"-framework\", \"CoreML\""); -using namespace onnxruntime; -using namespace onnxruntime::coreml; +namespace onnxruntime { +namespace coreml { namespace { /** @@ -247,213 +247,6 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff } } // namespace -NS_ASSUME_NONNULL_BEGIN - -// Execution for a CoreML model, it performs -// 1. Compile the model by given path for execution -// 2. Predict using given OnnxTensorFeatureProvider input and copy the output data back ORT -// 3. The compiled model will be removed in dealloc or removed using cleanup function -@interface CoreMLExecution : NSObject { - NSString* coreml_model_path_; - NSString* compiled_model_path_; - const logging::Logger* logger_; - uint32_t coreml_flags_; -} - -- (instancetype)initWithPath:(const std::string&)path - logger:(const logging::Logger&)logger - coreml_flags:(uint32_t)coreml_flags; -- (void)cleanup; -- (void)dealloc; -- (Status)loadModel API_AVAILABLE_COREML3; -- (Status)predict:(const std::unordered_map&)inputs - outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_COREML3; - -@property(nullable) MLModel* model API_AVAILABLE_COREML3; - -@end - -@implementation CoreMLExecution - -- (instancetype)initWithPath:(const std::string&)path - logger:(const logging::Logger&)logger - coreml_flags:(uint32_t)coreml_flags { - if (self = [super init]) { - coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); - logger_ = &logger; - coreml_flags_ = coreml_flags; - } - return self; -} - -- (void)cleanup { - NSError* error = nil; - if (compiled_model_path_ != nil) { - [[NSFileManager defaultManager] removeItemAtPath:compiled_model_path_ error:&error]; - if (error != nil) { - LOGS(*logger_, ERROR) << "Failed cleaning up the compiled model: " << [compiled_model_path_ UTF8String] - << ", error message: " << [[error localizedDescription] UTF8String]; - } - compiled_model_path_ = nil; - } - -#if !defined(NDEBUG) - std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); - if (!path_override.empty()) { - // don't cleanup - coreml_model_path_ = nil; - } -#endif - - if (coreml_model_path_ != nil) { - error = nil; - [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; - if (error != nil) { - LOGS(*logger_, 
ERROR) << "Failed cleaning up the coreml model: " << [coreml_model_path_ UTF8String] - << ", error message: " << [[error localizedDescription] UTF8String]; - } - coreml_model_path_ = nil; - } -} - -- (void)dealloc { - [self cleanup]; -} - -- (Status)loadModel { - NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; - if (modelUrl == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); - } - - // TODO: Update this to version with callback handler as the API used here is deprecated. - // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl - // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the - // background. We will have to check for completion in `predict` and block until it is done. - NSError* error = nil; - NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; - - if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", - [[error localizedDescription] UTF8String]); - } - - compiled_model_path_ = [compileUrl path]; - - MLModelConfiguration* config = [MLModelConfiguration alloc]; - config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) - ? MLComputeUnitsCPUOnly - : MLComputeUnitsAll; - _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; - - if (error != nil || _model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", - (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); - } - - return Status::OK(); -} - -- (Status)predict:(const std::unordered_map&)inputs - outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn { - Status status = Status::OK(); - ORT_TRY { - if (_model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); - } - - id input_features; - InlinedVector> conversion_buffers; - ORT_RETURN_IF_ERROR(CreateInputFeatureProvider(inputs, *logger_, &input_features, conversion_buffers)); - - MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; - NSError* error = nil; - id output_features = [_model predictionFromFeatures:input_features - options:options - error:&error]; - - if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", - [[error localizedDescription] UTF8String]); - } - - for (const auto& [output_name, output_tensor_info] : outputs) { - MLFeatureValue* output_value = - [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())]; - - if (output_value == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); - } - - MLMultiArray* data = [output_value multiArrayValue]; - - const auto coreml_static_output_shape = [data]() { - InlinedVector result; - result.reserve(data.shape.count); - for (NSNumber* dim in data.shape) { - const auto dim_value = dim.longLongValue; - result.push_back(dim_value); - } - return result; - }(); - - const auto static_output_shape = GetStaticOutputShape(output_tensor_info.shape, coreml_static_output_shape, - *logger_); - - void* output_buffer = get_output_tensor_mutable_raw_data_fn(output_name, output_tensor_info.data_type, - static_output_shape); - - if (const size_t num_elements = data.count; num_elements > 0) { - if (const auto shape_size = ShapeSize(static_output_shape); - shape_size < 0 || num_elements != static_cast(shape_size)) 
{ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "CoreML MLMultiArray count (", num_elements, ") and shape size (", shape_size, - ") do not match"); - } - - // support a non-contiguous array, provided only one dimension is not contiguous - int64_t num_blocks = 0; - int64_t block_size = 0; - int64_t stride = 0; - - ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); - - __block Status copy_status; - const auto* tensor_info = &output_tensor_info; - // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions - if (@available(macOS 12.3, iOS 15.4, *)) { - [data getBytesWithHandler:^(const void* bytes, NSInteger size) { - copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, - num_blocks, block_size, stride, tensor_info); - }]; - } else { - copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, - num_blocks, block_size, stride, tensor_info); - } - - ORT_RETURN_IF_ERROR(copy_status); - } - } - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception: ", e.what()); - }); - } - - return status; -} - -@end - -NS_ASSUME_NONNULL_END - -namespace onnxruntime { -namespace coreml { - Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, int64_t& num_blocks, int64_t& block_size, int64_t& stride) { const auto* shape = array.shape; @@ -498,11 +291,14 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, } // Internal Execution class -// This class will bridge Model (c++) with CoreMLExecution (objective c++) +// This class is part of the model class and handles the calls into CoreML. Specifically, it performs +// 1. Compile the model by given path for execution +// 2. Predict using given OnnxTensorFeatureProvider input and copy the output data back ORT +// 3. 
The compiled model will be removed in dealloc or removed using cleanup function class Execution { public: Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - ~Execution() {}; + ~Execution(); Status LoadModel(); Status Predict(const std::unordered_map& inputs, @@ -510,30 +306,97 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); private: - bool model_loaded{false}; - CoreMLExecution* execution_; + void cleanup(); + NSString* coreml_model_path_{nil}; + NSString* compiled_model_path_{nil}; + const logging::Logger& logger_; + uint32_t coreml_flags_{0}; + MLModel* model_{nil}; }; -Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) { +Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) + : logger_(logger), + coreml_flags_(coreml_flags) { @autoreleasepool { - execution_ = [[CoreMLExecution alloc] initWithPath:path - logger:logger - coreml_flags:coreml_flags]; + coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); + } +} + +Execution::~Execution() { + @autoreleasepool { + cleanup(); + } +} + +void Execution::cleanup() { + NSError* error = nil; + if (compiled_model_path_ != nil) { + [[NSFileManager defaultManager] removeItemAtPath:compiled_model_path_ error:&error]; + if (error != nil) { + LOGS(logger_, ERROR) << "Failed cleaning up the compiled model: " << [compiled_model_path_ UTF8String] + << ", error message: " << [[error localizedDescription] UTF8String]; + } + compiled_model_path_ = nil; + } + +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + // don't cleanup + coreml_model_path_ = nil; + } +#endif + + if (coreml_model_path_ != nil) { + error = nil; + [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; + if (error != nil) { + LOGS(logger_, ERROR) << "Failed cleaning up the coreml model: " << [coreml_model_path_ UTF8String] + << ", error message: " << [[error localizedDescription] UTF8String]; + } + coreml_model_path_ = nil; } } Status Execution::LoadModel() { - if (model_loaded) { + if (model_ != nil) { return Status::OK(); } if (HAS_COREML3_OR_LATER) { - Status status{}; @autoreleasepool { - status = [execution_ loadModel]; + NSError* error = nil; + + NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; + if (modelUrl == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); + } + + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. + NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; + if (error != nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", + [[error localizedDescription] UTF8String]); + } + + compiled_model_path_ = [compileUrl path]; + + MLModelConfiguration* config = [MLModelConfiguration alloc]; + config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) + ? 
MLComputeUnitsCPUOnly + : MLComputeUnitsAll; + model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; + + if (error != nil || model_ == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", + (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); + } + + return Status::OK(); } - model_loaded = status.IsOK(); - return status; } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+"); @@ -542,13 +405,94 @@ Status Predict(const std::unordered_map& inputs, Status Execution::Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { - ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_COREML3_OR_LATER) { @autoreleasepool { - return [execution_ predict:inputs - outputs:outputs - getOutputTensorDataFn:get_output_tensor_mutable_raw_data_fn]; + Status status = Status::OK(); + ORT_TRY { + if (model_ == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); + } + + id input_features; + InlinedVector> conversion_buffers; + ORT_RETURN_IF_ERROR(CreateInputFeatureProvider(inputs, logger_, &input_features, conversion_buffers)); + + MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; + NSError* error = nil; + id output_features = [model_ predictionFromFeatures:input_features + options:options + error:&error]; + + if (error != nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", + [[error localizedDescription] UTF8String]); + } + + for (const auto& [output_name, output_tensor_info] : outputs) { + MLFeatureValue* output_value = + [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())]; + + if (output_value == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); + } + + MLMultiArray* data = [output_value multiArrayValue]; + + const auto coreml_static_output_shape = [data]() { + InlinedVector result; + result.reserve(data.shape.count); + for (NSNumber* dim in data.shape) { + const auto dim_value = dim.longLongValue; + result.push_back(dim_value); + } + return result; + }(); + + const auto static_output_shape = GetStaticOutputShape(output_tensor_info.shape, coreml_static_output_shape, + logger_); + + void* output_buffer = get_output_tensor_mutable_raw_data_fn(output_name, output_tensor_info.data_type, + static_output_shape); + + if (const size_t num_elements = data.count; num_elements > 0) { + if (const auto shape_size = ShapeSize(static_output_shape); + shape_size < 0 || num_elements != static_cast(shape_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "CoreML MLMultiArray count (", num_elements, ") and shape size (", shape_size, + ") do not match"); + } + + // support a non-contiguous array, provided only one dimension is not contiguous + int64_t num_blocks = 0; + int64_t block_size = 0; + int64_t stride = 0; + + ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); + + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, + num_blocks, block_size, stride, tensor_info); + }]; + } else { + 
copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, + num_blocks, block_size, stride, tensor_info); + } + + ORT_RETURN_IF_ERROR(copy_status); + } + } + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception: ", e.what()); + }); + } + + return status; } } diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index 84952ac8620ca..b031a6f0cefa3 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -23,7 +23,9 @@ struct TreeNodeElementId { } struct hash_fn { std::size_t operator()(const TreeNodeElementId& key) const { - return static_cast(static_cast(key.tree_id) << 32 | static_cast(key.node_id)); + // unordered_map has poor performance on Windows when inserting consecutive keys. + // keys are usually inserted with key.node_id being incremented at each iteration. + return static_cast(static_cast(key.tree_id) | static_cast(key.node_id) << 32); } }; }; @@ -329,8 +331,7 @@ class TreeAggregatorMax : public TreeAggregator& base_values) : TreeAggregator(n_trees, n_targets_or_classes, - post_transform, base_values) {} + const std::vector& base_values) : TreeAggregator(n_trees, n_targets_or_classes, post_transform, base_values) {} // 1 output diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index f2aaa75cadd8d..35f3b12aeba35 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -258,7 +258,7 @@ struct TensorCaster { auto out_data = out.MutableData(); auto in_data = in.Data(); const size_t shape_size = narrow(shape.Size()); - MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); + MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size); } }; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b54c572556220..82b29c7b0562e 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2473,7 +2473,8 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, +static bool ConvTransposeNeedFallbackToCPU([[maybe_unused]] const onnxruntime::Node& node, + [[maybe_unused]] const logging::Logger& logger, [[maybe_unused]] const GraphViewer& graph_viewer, [[maybe_unused]] const bool prefer_nhwc) { const auto& node_attributes = node.GetAttributes(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 95ba698b707ac..cc76198dc3ae9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -385,7 +385,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected if (cuda_ep->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); - w_dims_cudnn.insert(w_dims.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc 
b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index bac99d6a81ed2..d4876e1714861 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -7,6 +7,11 @@ #include "conv_transpose.h" #include "core/providers/cuda/tensor/transpose.h" +#if CUDNN_MAJOR < 9 +// if compiled with cuDNN 8 we want to use the legacy cuDNN API +#include "conv_transpose_8.h" +#endif + // To suppress FP static analyzer warnings: // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 @@ -38,48 +43,42 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) #endif -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { - return DoConvTranspose(context, false); -} - +// If NHWC == true, the first input (X) is also in NHWC format; the other inputs remain in NCHW template Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, - PrePackedWeights* prepacked_weights) { + [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack - if constexpr (NHWC) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); - const auto rank = orig_shape.NumDimensions(); - - InlinedVector perm; - TensorShapeVector new_dims; - - // Input is { N, C, ...}. Output is { N, M, ...}. 'input channels' is C. 'output channels' is M. - // Transpose the output channels related dimension (M/group) to be last. Leave the input channels as-is.
- if (rank == 3) { - // Transpose from {C, M/group, k1} to {C, k1, M/group} - perm = {0, 2, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[1]}; - } else if (rank == 4) { - // Transpose from {C, M/group, kH, kW} to {C, kH, kW, M/group} - perm = {0, 2, 3, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; - } else if (rank == 5) { - // Transpose from {C, M/group, k1, k2, k3} to {C, k1, k2, k3, M/group} - perm = {0, 2, 3, 4, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[4], orig_shape[1]}; - } + auto shape_size = orig_shape.GetDims().size(); + + InlinedVector perm; + perm.push_back(0); + for (size_t i = 2; i < shape_size; i++) perm.push_back(i); + perm.push_back(1); + gsl::span permutation(perm.data(), shape_size); - gsl::span permutation(perm.data(), rank); - W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + TensorShapeVector nhwc_dims; + for (size_t i = 0; i < shape_size; i++) { + nhwc_dims.push_back(orig_shape[perm[i]]); + } - ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), - permutation, tensor, *W_)); + W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc)); + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), + DefaultCudaStream(), + DefaultCublasHandle(), + permutation, tensor, *W_); + if (!status.IsOK()) { + return status; + } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; + } else { + W_already_nhwc = true; } } else { ORT_UNUSED_PARAMETER(tensor); @@ -91,236 +90,419 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Allo return Status::OK(); } -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { - typedef typename ToCudaType::MappedType CudaT; +#if CUDNN_MAJOR >= 9 +#if !defined(__CUDACC__) + +template +Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const { + s_.bias_fused = fuse_bias; + s_.act_fused = fuse_act; + s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change + s_.cudnn_fe_graph = std::make_unique(); + cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType(); + s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type); + if (data_type == cudnn_frontend::DataType_t::HALF) { + s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + } else { + s_.cudnn_fe_graph->set_compute_data_type(data_type); + } - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - auto x_dims = x_shape.AsShapeVector(); - auto x_data = reinterpret_cast(X->Data()); - - auto x_dimensions = X->Shape().NumDimensions(); - if (x_dimensions < 3 || x_dimensions > 5) { - // TODO: the error message should tell which operator raises it. 
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", - " X: ", X->Shape().ToString().c_str()); + s_.cudnn_fe_X = s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get()); + s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get()); + + auto conv_options = cudnn_frontend::graph::Conv_dgrad_attributes() + .set_pre_padding(std::vector(pads.begin(), + pads.begin() + pads.size() / 2)) + .set_post_padding(std::vector(pads.begin() + pads.size() / 2, pads.end())) + .set_stride(strides) + .set_dilation(dilations); + s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_dgrad(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options); + auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get(); + + if (B == nullptr) { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + } else { + int64_t bias_size; + if (B != nullptr) { + bias_size = B->Shape()[0]; + } else { + bias_size = w_dims[0]; + } + + if (fuse_bias) { + onnxruntime::TensorShapeVector b_dims; + for (size_t i = 0; i < x_dims.size(); i++) { + b_dims.push_back(i == 1 ? bias_size : 1); + } + auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get(); + auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_B, bias_options); + } else { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + + TensorShapeVector b_dims(y_dims.size(), 1); + TensorShapeVector b_strides(y_dims.size(), 1); + b_dims[1] = bias_size; + b_strides[0] = bias_size; + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), b_strides)); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType(), cudnn_fe_y_tensor.get_stride())); + + /* Creating an own CUDNN Frontend graph for the bias addition. + s_.cudnn_fe_bias_graph = std::make_unique(); + s_.cudnn_fe_bias_graph->set_io_data_type(data_type) + .set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ? + cudnn_frontend::DataType_t::FLOAT : data_type) + .set_intermediate_data_type(data_type); + s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor(y_dims, "x", data_type).Get()); + + s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor); + s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options); + s_.cudnn_fe_bias_Y->set_output(true); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate()); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode})); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/ + } + } + if (fuse_act && s_.cudnn_fe_act_attr.has_value()) { + auto& activation_attr = s_.cudnn_fe_act_attr.value(); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr); } - // use pre-packed W if available - const Tensor* W = W_ ? 
W_.get() : context->Input(1); + s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim()); + s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride()); + s_.cudnn_fe_Y->set_output(true); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate()); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); + } catch (const std::exception& ex) { + std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } - const TensorShape& w_shape = W->Shape(); - TensorShapeVector w_dims = w_shape.AsShapeVector(); - auto w_data = reinterpret_cast(W->Data()); + if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE}); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); + } catch (const std::exception& ex) { + if (!fuse_bias && !fuse_act && use_tf32) { + std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + // Try fallback. + return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, y_dims, handle, heur_mode, + pads, strides, dilations, false, false, w_in_nhwc, true); + } + + s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size(); + return Status::OK(); +} + +#endif + +template +Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { + constexpr bool channels_last = Layout == LAYOUT_NHWC; size_t num_inputs = OpKernel::Node().InputDefs().size(); bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; - CudaT* y_data = nullptr; + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + // X incl. x_dims is in NHWC Format iff. NHWC == true + const auto x_dims = x_shape.AsShapeVector(); + + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + + // set W + bool w_in_nhwc; + const Tensor* W; + if (!W_) { + W = context->Input(1); + w_in_nhwc = false; + // Dims and memory layout are in NCHW format + } else { + W = W_.get(); + w_in_nhwc = channels_last; + // W got prepacked, therefore if NHWC == true, then dims and memory layout are in NHWC + } + const TensorShape& w_shape = W->Shape(); + onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // set B + // Always in NCHW format + const Tensor* B = nullptr; + if (has_bias) { + B = context->Input(dynamic_padding ? 3 : 2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } - const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - // convert 1D to 2D - if (x_dimensions == 3) { - // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use - // GetCudnnConv1dPadToNc1d to determine which is added. - // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension - const auto insert_at = NHWC ? 
1 : 2; + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); - // NCHW: N, C, d1 -> N, C, 1, d1 - // NHWC: N, d1, C -> N, 1, d1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + } - // 'M' is channels dim in CUDA implementation - // NCHW: C, M/g, k1 -> C, M/g, 1, k1 - // NHWC: C, k1, M/g -> C, 1, k1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); - } else { - // add fake W dimension - const auto insert_at = NHWC ? 2 : 3; + // The following code is from ConvTransposeAttributes::PrepareForCompute - // NCHW: N, C, d1 -> N, C, d1, 1 - // NHWC: N, d1, C -> N, d1, 1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + const int rank = static_cast(X->Shape().NumDimensions()); + TensorShape input_shape = X->Shape().Slice(channels_last ? 1 : 2, channels_last ? rank - 1 : rank); + const int64_t num_input_channels = channels_last ? X->Shape()[rank - 1] : X->Shape()[1]; + const int64_t N = X->Shape()[0]; + const int64_t num_output_channels_multiplier = w_in_nhwc ? w_shape[rank - 1] : w_shape[1]; + const int64_t num_output_channels = num_output_channels_multiplier * conv_transpose_attrs_.group; - // NCHW: C, M/g, k1 -> C, M/g, k1, 1 - // NHWC: C, k1, M/g -> C, k1, 1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); + if (conv_transpose_attrs_.group <= 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group count is <= 0", + " group: ", conv_transpose_attrs_.group); } - } - { - std::lock_guard lock(s_.mutex); - // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); - // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size - bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); - bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) { - s_.last_x_dims = gsl::make_span(x_dims); - } + if (X->Shape().NumDimensions() != w_shape.NumDimensions()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.", + " X: ", X->Shape().ToString().c_str(), + " W: ", w_shape.ToString().c_str()); + } - if (w_dims_changed) { - s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_results.clear(); - } + if (w_shape[0] != num_input_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", + " filter_number: ", w_shape[0], + " num_input_channels: ", num_input_channels); + } - ConvTransposeAttributes::Prepare p; - // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' - const bool transposed_input_channels = false; - ORT_RETURN_IF_ERROR( - conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC, transposed_input_channels)); - - auto y_dims = p.Y->Shape().AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension of 1 - // NCHW: N, M, d1 -> N, M, 1, d1 or - // NHWC: N, d1, M -> N, 1, d1, M - y_dims.insert(y_dims.begin() + (NHWC ? 
1 : 2), 1); - p.kernel_shape.insert(p.kernel_shape.begin(), 1); - p.pads.insert(p.pads.begin(), 0); - p.pads.insert(p.pads.begin() + 2, 0); - p.strides.insert(p.strides.begin(), 1); - p.dilations.insert(p.dilations.begin(), 1); - } else { - // add fake W dimension of 1 - // NCHW: N, M, d1 -> N, M, d1, 1 or - // NHWC: N, d1, M -> N, d1, 1, M - y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); - p.kernel_shape.push_back(1); - p.pads.insert(p.pads.begin() + 1, 0); - p.pads.push_back(0); - p.strides.push_back(1); - p.dilations.push_back(1); - } - } + // it looks like num_output_channels is really k*group similar to how in the conv case + // num_input_channels is k*group. hence removing the check for num_output_channels here. - s_.y_dims = gsl::make_span(y_dims); + if (num_input_channels % conv_transpose_attrs_.group != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input channels is not divisible by group.", + " num_input_channels: ", num_input_channels, + " group: ", conv_transpose_attrs_.group); + } - if (w_dims_changed) { - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(w_dims[0]), static_cast(w_dims[3]), - static_cast(w_dims[1]), static_cast(w_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); - } - } + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc)); - // Special case when there is a dim value of 0 in the shape. - // Return only after we have cached the following for subsequent runs : - // 1) `w_dims` in the `w_desc` - // 2) `y_dims` in s_.y_dims - if (p.Y->Shape().Size() == 0) { - return Status::OK(); - } + const size_t kernel_rank = kernel_shape.size(); - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(x_dims[0]), static_cast(x_dims[3]), - static_cast(x_dims[1]), static_cast(x_dims[2]))); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(y_dims[0]), static_cast(y_dims[3]), - static_cast(y_dims[1]), static_cast(y_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + TensorShapeVector local_output_padding(conv_transpose_attrs_.output_padding); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape.size(), 0); + } + ConvPadVector pads; + pads.reserve(2 * (input_shape.NumDimensions())); + if (dynamic_padding) { + for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) { + pads.push_back(Pads->Data()[i]); } + } else { + pads.assign(conv_transpose_attrs_.pads.begin(), conv_transpose_attrs_.pads.end()); + } + if (pads.empty()) { + pads.resize(kernel_shape.size() * 2, 0); + } + TensorShapeVector dilations(conv_transpose_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_shape.size(), 1); + } + TensorShapeVector strides(conv_transpose_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_shape.size(), 1); + } - cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType(), - UseTF32())); - - if (has_bias) { - const auto& b_shape = p.B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 
1D"); - TensorShapeVector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[NHWC ? 3 : 1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) { - b_dims[(NHWC ? 1 : 2) + i] = 1; - } - - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); - } + TensorShapeVector y_dims; - y_data = reinterpret_cast(p.Y->MutableData()); - - if (!s_.cached_benchmark_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = - GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); - } else if constexpr (std::is_same::value) { - if (!UseTF32()) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); - } - } - - cudnnConvolutionBwdDataAlgoPerf_t perf; - int algo_count = 1; - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, - &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); - s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); - } + conv_transpose_attrs_.ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, + strides, dilations, local_output_padding, N, &pads, &y_dims, channels_last); - const auto& perf = s_.cached_benchmark_results.at(x_dims); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); - s_.algo = perf.algo; - s_.workspace_bytes = perf.memory; - } + s_.y_dims = gsl::make_span(y_dims); + s_.Y = context->Output(0, s_.y_dims); - // The following block will be executed in case there has been no change in the shapes of the - // input and the filter compared to the previous run - if (!y_data) { - auto y_dims = s_.y_dims.AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // erase the fake H dimension - y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); - } else { - // erase the fake W dimension - y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); - } - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); - Tensor* Y = context->Output(0, TensorShape(y_dims)); - y_data = reinterpret_cast(Y->MutableData()); + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; + TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()}; - // Bail out early if one of the output dimensions is zero. - if (Y->Shape().Size() == 0) { - return Status::OK(); + if constexpr (channels_last) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1)); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1)); + x_dims_cudnn.erase(x_dims_cudnn.end() - 1); + y_dims_cudnn.erase(y_dims_cudnn.end() - 1); + + if (w_in_nhwc) { + w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1)); + w_dims_cudnn.erase(w_dims_cudnn.end() - 1); } } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. 
+ // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims_cudnn.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + auto handle = GetCudnnHandle(context); + + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); +#if !defined(__CUDACC__) + cudnn_frontend::HeurMode_t heur_mode; + switch (cudnn_conv_algo) { + case 0: + heur_mode = cudnn_frontend::HeurMode_t::B; + break; + case 1: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + case 2: + heur_mode = cudnn_frontend::HeurMode_t::FALLBACK; + break; + default: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + } - CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, - x_data, s_.conv_desc, s_.algo, workspace.get(), - s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + auto use_tf32 = cuda_ep->UseTF32(); + const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_act = is_fused_node_; + + ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, y_dims_cudnn, handle, heur_mode, + std::vector(pads.begin(), + pads.end()), + std::vector(strides.begin(), + strides.end()), + std::vector(dilations.begin(), + dilations.end()), + fuse_bias, fuse_act, w_in_nhwc, use_tf32)); +#endif + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + return Status::OK(); +} - if (has_bias) { - const Tensor* B = dynamic_padding ? 
context->Input(3) : context->Input(2);
-    auto b_data = reinterpret_cast(B->Data());
-    CUDNN_RETURN_IF_ERROR(
-        cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data));
+template
+Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const {
+  std::lock_guard lock(s_.mutex);
+  ORT_RETURN_IF_ERROR(UpdateState(context, dynamic_padding));
+  if (s_.Y->Shape().Size() == 0) {
+    return Status::OK();
+  }
+  const auto alpha = onnxruntime::cuda::Consts::One;
+  auto cudnn_handle = GetCudnnHandle(context);
+#if !defined(__CUDACC__)
+  s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast(s_.x_data));
+  s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast(s_.w_data));
+  s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data);
+  if (s_.bias_fused && s_.b_data != nullptr) {
+    s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data));
+  }
+  if (s_.bias_fused && s_.z_data != nullptr) {
+    s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast(s_.z_data));
+    if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) {
+      // memset Z if it's required for a successful fusion
+      CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes()));
+    }
+  }
+  auto ws = GetWorkSpace(context->GetComputeStream());
+
+  CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle,
+                                                      s_.variant_pack,
+                                                      ws.get()));
+
+  if (!s_.bias_fused && s_.z_data != nullptr) {
+    CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data,
+                                         &alpha, s_.y_tensor, s_.y_data));
+  }
+  if (!s_.bias_fused && s_.b_data != nullptr) {
+    CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data,
+                                         &alpha, s_.y_tensor, s_.y_data));
+
+    /* For the standalone bias addition graph.
+ s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle, + s_.variant_pack_bias, + GetWorkSpace(context->GetComputeStream()).get()));*/ + } +#endif return Status::OK(); } +#endif + +template +Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { + return DoConvTranspose(context, false); +} } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 71ad3ee6e2147..1a6957164d22f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -18,6 +18,8 @@ namespace cuda { template class ConvTranspose : public CudaKernel { public: + using CudaT = typename ToCudaType::MappedType; + ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; @@ -29,6 +31,33 @@ class ConvTranspose : public CudaKernel { mutable CudnnConvState s_; std::unique_ptr W_; + + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_fused_node_ = false; // ensures the node is fused although the session option is not set + bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain + + protected: + inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + return GetScratchBuffer(s_.workspace_bytes, stream); + } + + Status UpdateState(OpKernelContext* context, bool bias_expected) const; + +#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9 + Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h new file mode 100644 index 0000000000000..b46d41b887e41 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. 
+ +#include + +#include "conv_transpose.h" +#include + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" +#include "core/providers/cpu/nn/conv_transpose_attributes.h" + +#include "core/providers/cuda/tensor/transpose.h" + +// To suppress FP static analyzer warnings: +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 26110) +#pragma warning(disable : 26117) +#endif + +namespace onnxruntime { +namespace cuda { + +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + auto x_dims = x_shape.AsShapeVector(); + auto x_data = reinterpret_cast(X->Data()); + + auto x_dimensions = X->Shape().NumDimensions(); + if (x_dimensions < 3 || x_dimensions > 5) { + // TODO: the error message should tell which operator raises it. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", + " X: ", X->Shape().ToString().c_str()); + } + + // use pre-packed W if available + const Tensor* W = W_ ? W_.get() : context->Input(1); + + const TensorShape& w_shape = W->Shape(); + TensorShapeVector w_dims = w_shape.AsShapeVector(); + auto w_data = reinterpret_cast(W->Data()); + + size_t num_inputs = OpKernel::Node().InputDefs().size(); + bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; + + CudaT* y_data = nullptr; + + const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + + // convert 1D to 2D + if (x_dimensions == 3) { + // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use + // GetCudnnConv1dPadToNc1d to determine which is added. + // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension + const auto insert_at = NHWC ? 1 : 2; + + // NCHW: N, C, d1 -> N, C, 1, d1 + // NHWC: N, d1, C -> N, 1, d1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // 'M' is channels dim in CUDA implementation + // NCHW: C, M/g, k1 -> C, M/g, 1, k1 + // NHWC: C, k1, M/g -> C, 1, k1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } else { + // add fake W dimension + const auto insert_at = NHWC ? 
2 : 3; + + // NCHW: N, C, d1 -> N, C, d1, 1 + // NHWC: N, d1, C -> N, d1, 1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // NCHW: C, M/g, k1 -> C, M/g, k1, 1 + // NHWC: C, k1, M/g -> C, k1, 1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } + } + + { + std::lock_guard lock(s_.mutex); + // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); + // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with + // different batch_size + bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); + bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) { + s_.last_x_dims = gsl::make_span(x_dims); + } + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ConvTransposeAttributes::Prepare p; + // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' + const bool transposed_input_channels = false; + ORT_RETURN_IF_ERROR( + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, + &w_shape, NHWC, transposed_input_channels)); + + auto y_dims = p.Y->Shape().AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension of 1 + // NCHW: N, M, d1 -> N, M, 1, d1 or + // NHWC: N, d1, M -> N, 1, d1, M + y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); + p.kernel_shape.insert(p.kernel_shape.begin(), 1); + p.pads.insert(p.pads.begin(), 0); + p.pads.insert(p.pads.begin() + 2, 0); + p.strides.insert(p.strides.begin(), 1); + p.dilations.insert(p.dilations.begin(), 1); + } else { + // add fake W dimension of 1 + // NCHW: N, M, d1 -> N, M, d1, 1 or + // NHWC: N, d1, M -> N, d1, 1, M + y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); + p.kernel_shape.push_back(1); + p.pads.insert(p.pads.begin() + 1, 0); + p.pads.push_back(0); + p.strides.push_back(1); + p.dilations.push_back(1); + } + } + + s_.y_dims = gsl::make_span(y_dims); + + if (w_dims_changed) { + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(w_dims[0]), static_cast(w_dims[3]), + static_cast(w_dims[1]), static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } + } + + // Special case when there is a dim value of 0 in the shape. 
+ // Return only after we have cached the following for subsequent runs : + // 1) `w_dims` in the `w_desc` + // 2) `y_dims` in s_.y_dims + if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(x_dims[0]), static_cast(x_dims[3]), + static_cast(x_dims[1]), static_cast(x_dims[2]))); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(y_dims[0]), static_cast(y_dims[3]), + static_cast(y_dims[1]), static_cast(y_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + } + + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, + gsl::narrow_cast(conv_transpose_attrs_.group), mode, + CudnnTensor::GetDataType(), + UseTF32())); + + if (has_bias) { + const auto& b_shape = p.B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + p.kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[NHWC ? 3 : 1] = b_shape[0]; // C + for (size_t i = 0; i < p.kernel_shape.size(); i++) { + b_dims[(NHWC ? 1 : 2) + i] = 1; + } + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); + } + + y_data = reinterpret_cast(p.Y->MutableData()); + + if (!s_.cached_benchmark_results.contains(x_dims)) { + IAllocatorUniquePtr algo_search_workspace = + GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionBwdDataAlgoPerf_t perf; + int algo_count = 1; + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, + &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); + s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); + } + + const auto& perf = s_.cached_benchmark_results.at(x_dims); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } + + // The following block will be executed in case there has been no change in the shapes of the + // input and the filter compared to the previous run + if (!y_data) { + auto y_dims = s_.y_dims.AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // erase the fake H dimension + y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); + } else { + // erase the fake W dimension + y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); + } + } + + Tensor* Y = context->Output(0, TensorShape(y_dims)); + y_data = reinterpret_cast(Y->MutableData()); + + // Bail out early if one of the output dimensions is zero. 
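// Illustrative aside (not part of the patch): the 1-D handling earlier in this file pads a
// [N, C, D] shape to a 2-D one, either [N, C, 1, D] (when GetCudnnConv1dPadToNc1d() is set)
// or [N, C, D, 1]. A minimal sketch of that choice; the helper name is hypothetical.
#include <cstdint>
#include <vector>

std::vector<int64_t> PadConv1dShapeTo2d(std::vector<int64_t> dims, bool pad_to_nc1d) {
  // dims is [N, C, D] in channels-first layout.
  if (pad_to_nc1d) {
    dims.insert(dims.begin() + 2, 1);  // fake H: [N, C, 1, D]
  } else {
    dims.push_back(1);                 // fake W: [N, C, D, 1]
  }
  return dims;
}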
+ if (Y->Shape().Size() == 0) { + return Status::OK(); + } + } + + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, + s_.x_tensor, x_data, s_.conv_desc, s_.algo, workspace.get(), + s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + + if (has_bias) { + const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); + auto b_data = reinterpret_cast(B->Data()); + CUDNN_RETURN_IF_ERROR( + cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + } + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime + +#ifdef _WIN32 +#pragma warning(pop) +#endif diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.h b/onnxruntime/core/providers/js/operators/unsqueeze.h index 7cbfdc38b742d..f15a3008895aa 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.h +++ b/onnxruntime/core/providers/js/operators/unsqueeze.h @@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase { if (num_inputs == 2) { // axes is an input const Tensor* axes_tensor = context->Input(1); ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); axes.assign(data, data + nDims); diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index be41b125e4440..4fca4037301fb 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -107,12 +108,15 @@ BackendManager::BackendManager(const GlobalContext& global_context, subgraph_context_, ep_ctx_handle_); } catch (const OnnxRuntimeException& ex) { + std::string exception_str = ex.what(); + bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos && + !GetGlobalContext().disable_cpu_fallback && + !ep_ctx_handle_.IsValidOVEPCtxGraph(); #if defined(OPENVINO_DISABLE_NPU_FALLBACK) - ORT_THROW(ex.what()); + eligible_for_cpu_fallback = false; #else - if (device_type.find("NPU") != std::string::npos && - !GetGlobalContext().disable_cpu_fallback) { - LOGS_DEFAULT(WARNING) << ex.what(); + if (eligible_for_cpu_fallback) { + LOGS_DEFAULT(VERBOSE) << exception_str; LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." 
<< "Falling back to OV CPU for execution"; GetGlobalContext().device_type = "CPU"; @@ -125,10 +129,32 @@ BackendManager::BackendManager(const GlobalContext& global_context, } catch (std::string const& msg) { ORT_THROW(msg); } - } else { - ORT_THROW(ex.what()); } #endif + if (!eligible_for_cpu_fallback) { + if (device_type.find("NPU") != std::string::npos && + exception_str.find("intel_npu") != std::string::npos) { + // Handle NPU device related errors +#ifndef NDEBUG + ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); +#else + std::string error_message = "UNKNOWN NPU ERROR"; + std::string error_code = "code 0x0"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); + std::smatch matches; + if (std::regex_search(exception_str, matches, error_message_pattern)) { + error_message = matches[0]; + } + if (std::regex_search(exception_str, matches, error_code_pattern)) { + error_code = matches[0]; + } + throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); +#endif + } else { + ORT_THROW(exception_str); + } + } } } if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 8d340e2daf4b5..1f9c61780f27a 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -48,14 +48,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // Set the inference_num_threads property of the CPU SetNumThreads(device_config); -#ifndef NDEBUG - if (IsDebugEnabled()) { - std::string file_name = subgraph_context.subgraph_name + "_static.onnx"; - std::fstream outfile(file_name, std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(outfile); - } -#endif - try { std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str; @@ -180,6 +172,11 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type); } device_config.emplace(ov::device::properties("NPU", device_property)); +#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3) + if (global_context_.export_ep_ctx_blob) { + global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true)); + } +#endif } } @@ -295,16 +292,104 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque ORT_THROW(msg); } } else { - OVTensorPtr graph_input_blob; - try { - graph_input_blob = infer_request->GetTensor(input_name); - } catch (const char* msg) { - ORT_THROW(msg); + if ((global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + OVTensorPtr graph_input_blob; + try { + graph_input_blob = infer_request->GetTensor(input_name); + } catch (const char* msg) { + ORT_THROW(msg); + } + FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); + } else { + auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); + auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); + ov_tensor_data_t ov_tensor_key; + ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != 
ort_ov_tensor_map.end()) { + ov_tensor_key = it->second; + } else { + // Does this make sense for both types of allocators? + auto input = graph_input_info.at(input_idx); + if (allocator_name == OpenVINO_RT_NPU) { + ov_tensor_key.copy_needed = false; + ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), + (void*)tensor.GetTensorRawData()); + } else { + ov_tensor_key.copy_needed = true; + ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape()); + } + ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_key); + + if (ov_tensor_key.copy_needed) { + const char* ort_tensor_data = tensor.GetTensorData(); + size_t tensor_data_size = ov_tensor_key.tensor_ptr->get_byte_size(); + auto ort_batch_memory_offset = ort_tensor_data + tensor_data_size * batch_slice_idx; + std::memcpy(ov_tensor_key.tensor_ptr->data(), ort_batch_memory_offset, tensor_data_size); + } + + try { + infer_request->SetTensor(input_name, ov_tensor_key.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } } - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); } input_idx++; } + if (global_context_.device_type.find("NPU") != std::string::npos) { + // Set the output blob as remote blob + auto graph_output_info = exe_network_.Get().outputs(); + auto output_idx = 0; + for (auto output_info_iter = graph_output_info.begin(); + output_info_iter != graph_output_info.end(); ++output_info_iter) { + auto output_names = output_info_iter->get_names(); + std::string onnx_output_name; + std::string output_name; + // using the output name retrieved from ONNX original to match with the output names returned by OV tensors + for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { + onnx_output_name = it->first; + if (output_names.find(onnx_output_name) != output_names.end()) { + // Assigning the output_name + output_name = it->first; + break; + } + } + size_t batch_size = 1; + Ort::UnownedValue tensor = GetOutputTensor(context, + batch_size, + infer_request, + output_name, + subgraph_context_.output_names); + auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); + + ov_tensor_data_t ov_tensor_data; + ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_data = it->second; + } else { + auto output = graph_output_info.at(output_idx); + if (allocator_name == OpenVINO_RT_NPU) { + ov_tensor_data.copy_needed = false; + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), + (void*)tensor.GetTensorRawData()); + } else { + ov_tensor_data.copy_needed = true; + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape()); + } + ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_data); + + try { + infer_request->SetTensor(output_name, ov_tensor_data.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } + output_idx++; + } + } + // Start Async inference infer_request->StartAsync(); } catch (const char* msg) { @@ -454,20 +539,42 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe " doesn't exist in the " "list of OpenVINO output tensor names"); } - try { - graph_output_blob = infer_request->GetTensor(output_name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue 
output_tensor = - GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - return; + if ((global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + try { + graph_output_blob = infer_request->GetTensor(output_name); + } catch (const char* msg) { + ORT_THROW(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + return; + } else { + size_t batch_slice = 0; + FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + } } else { - size_t batch_slice = 0; - FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto allocator_name = output_tensor.GetTensorMemoryInfo().GetAllocatorName(); + ov_tensor_data_t ov_tensor_data; + ort_tensor_key_t ort_tensor_key{output_tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_data = it->second; + } else { + ORT_THROW(log_tag + "Expected all outputs to have associated OV::Tensor's"); + } + + if (ov_tensor_data.copy_needed) { + auto ort_tensor_data = output_tensor.GetTensorMutableData(); + size_t tensor_data_size = ov_tensor_data.tensor_ptr->get_byte_size(); + auto ort_batch_memory_offset = ort_tensor_data /*+ tensor_data_size * batch_size*/; + std::memcpy(ort_batch_memory_offset, ov_tensor_data.tensor_ptr->data(), tensor_data_size); + } } } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index cd242a06b27d4..cd69e88f994b9 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/openvino/contexts.h" @@ -20,6 +21,11 @@ namespace onnxruntime { namespace openvino_ep { +struct ov_tensor_data_t { + OVTensorPtr tensor_ptr; + bool copy_needed; +}; + class InferRequestsQueue; class BasicBackend : public IBackend { public: @@ -60,6 +66,9 @@ class BasicBackend : public IBackend { #if defined IO_BUFFER_ENABLED OVRemoteContextPtr remote_context_; #endif + + using ort_tensor_key_t = std::pair; + std::map ort_ov_tensor_map; }; class InferRequestsQueue { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 29c45916795d3..08144651319cf 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -10,6 +10,9 @@ #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/ov_versions/capability.h" #include "openvino/core/version.hpp" +#ifdef USE_OVEP_NPU_MEMORY +#include "core/providers/openvino/ov_allocator.h" +#endif #define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, 
std::min(destsz, srcsz)) @@ -180,4 +183,18 @@ common::Status OpenVINOExecutionProvider::Compile( return Status::OK(); } +#ifdef USE_OVEP_NPU_MEMORY +std::vector OpenVINOExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo npu_allocator_info{ + [this](OrtDevice::DeviceId device_id) { + return std::make_unique(global_context_->ie_core.Get(), OrtDevice::NPU, device_id, OpenVINO_RT_NPU); + }, + 0, + }; + + // fill in allocator + return std::vector{CreateAllocator(npu_allocator_info)}; +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 030e5bba71b67..8b1c62c607f6e 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -189,7 +189,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { const void* GetExecutionHandle() const noexcept override { return nullptr; } - +#ifdef USE_OVEP_NPU_MEMORY + std::vector CreatePreferredAllocators() override; +#endif private: std::unique_ptr global_context_; openvino_ep::EPCtxHandler ep_ctx_handle_{}; diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc new file mode 100644 index 0000000000000..6700244b754d8 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -0,0 +1,55 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#ifdef USE_OVEP_NPU_MEMORY +#include "core/providers/openvino/ov_allocator.h" +#include "core/providers/openvino/ov_interface.h" +#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp" +#include "openvino/runtime/intel_npu/properties.hpp" + +namespace onnxruntime { + +using namespace openvino_ep; + +constexpr size_t default_alignment = 4096; + +static inline size_t align_up(size_t size, size_t pow2_alignment) { + return (size + pow2_alignment - 1) & ~(pow2_alignment - 1); +} + +OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) { + if (device_type == OrtDevice::NPU) { + remote_ctx_ = core_.get_default_context("NPU").as(); + } else { + ORT_THROW("Invalid device type"); + } +} + +void* OVRTAllocator::Alloc(size_t size) { + try { + size_t alloc_size = align_up(size + sizeof(ov::Tensor*) + default_alignment, default_alignment); + ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8, + {alloc_size})); + uintptr_t data_ptr = reinterpret_cast(tensor->data()); + + ov::Tensor** ptr = reinterpret_cast(align_up(data_ptr + sizeof(ov::Tensor*), default_alignment)); + ptr[-1] = tensor; + + return reinterpret_cast(ptr); + + } catch (const ov::Exception& e) { + ORT_THROW(std::string("Alloc failed: ") + e.what()); + } + return nullptr; +} + +void OVRTAllocator::Free(void* p) { + try { + ov::Tensor** ptr = reinterpret_cast(p); + delete ptr[-1]; + } catch (const ov::Exception& e) { + ORT_THROW(std::string("Free failed: ") + e.what()); + } +} + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/openvino/ov_allocator.h b/onnxruntime/core/providers/openvino/ov_allocator.h new file mode 100644 index 0000000000000..083cfc4d5aed3 --- /dev/null +++ 
b/onnxruntime/core/providers/openvino/ov_allocator.h @@ -0,0 +1,24 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#ifdef USE_OVEP_NPU_MEMORY +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include "openvino/runtime/remote_context.hpp" + +namespace onnxruntime { + +class OVRTAllocator : public IAllocator { + public: + OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name); + void* Alloc(size_t size) override; + void Free(void* p) override; + + private: + ov::Core& core_; + ov::RemoteContext remote_ctx_; +}; + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index fa22e0f3cb03d..f4da4ea3e3244 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -10,6 +10,7 @@ #include #include "openvino/openvino.hpp" +#include "openvino/runtime/intel_npu/properties.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" #include "openvino/frontend/manager.hpp" diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index 35a6b7bf40637..5c4608dff9bb1 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,9 +87,9 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18) +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18 || QNN_API_VERSION_MINOR == 19) if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24/2.25 (QNN API version 2.17/2.18) has a validation bug for implicit bias inputs, + // Bias is implicit. QNN SDK 2.24/2.25/2.26 (QNN API version 2.17/2.18/2.19) has a validation bug for implicit bias inputs, // so provide an explicit bias of all 0 (quantized int32). TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 3c029fda9cd52..2c7f3c8b22ddd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -308,8 +308,10 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vectordim()) { + if (!dim.has_dim_value()) { + return false; // Do not support dynamic shapes. 
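// Illustrative aside (not part of the patch): OVRTAllocator::Alloc/Free added earlier in
// this patch over-allocate, align the user pointer up, and stash the owning ov::Tensor*
// immediately before it so Free() can recover and delete it. A generic sketch of that
// pointer-stashing scheme using plain new/delete; names here are hypothetical.
#include <cstddef>
#include <cstdint>

constexpr size_t kAlignment = 4096;

static inline size_t AlignUp(size_t size, size_t pow2_alignment) {
  return (size + pow2_alignment - 1) & ~(pow2_alignment - 1);
}

void* ExampleAlloc(size_t size) {
  // Reserve room for the payload, the stashed owner pointer, and alignment slack.
  size_t total = AlignUp(size + sizeof(void*) + kAlignment, kAlignment);
  char* backing = new char[total];
  uintptr_t data = reinterpret_cast<uintptr_t>(backing);
  // Align past the stashed pointer, then store the backing pointer just below the result.
  void** user = reinterpret_cast<void**>(AlignUp(data + sizeof(void*), kAlignment));
  user[-1] = backing;
  return user;
}

void ExampleFree(void* p) {
  void** user = reinterpret_cast<void**>(p);
  delete[] static_cast<char*>(user[-1]);
}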
+ } shape.push_back(SafeInt(dim.dim_value())); } diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 2088618538de5..5b2f2c1fa1b2e 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -192,6 +192,15 @@ std::vector NodeAttrHelper::Get(const std::string& key, const std::vect return def_val; } +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.strings(); + return std::vector{values.cbegin(), values.cend()}; + } + + return def_val; +} + std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { const auto& values = entry->second.floats(); diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 5813dcc48d72b..ddbae42534711 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -57,6 +57,7 @@ class NodeAttrHelper { std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 43ab08f62af4e..288af2ed27e5f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2255,16 +2255,29 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #endif network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + + // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior. + auto num_subgraphs = trt_parser->getNbSubgraphs(); + parser_nodes_list.reserve(num_subgraphs); + + for (int64_t i = 0; i < num_subgraphs; ++i) { + int64_t subgraph_len = 0; + int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len); + parser_nodes_list.emplace_back(); + parser_nodes_list.back().first.reserve(subgraph_len); + for (int64_t j = 0; j < subgraph_len; ++j) { + parser_nodes_list.back().first.push_back(nodes[j]); + } + parser_nodes_list.back().second = is_model_supported ? 
true : false; + } +#else trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); -#if defined(_MSC_VER) -#pragma warning(pop) #endif SubGraphCollection_t next_nodes_list; @@ -2272,6 +2285,24 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + /* + * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT TRT. + * + * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model proto and to string buffer) to onnx-tensorrt parser. + * The node index in the list returning from onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a node index mapping table here. + * + * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), + * however, once the graph is converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort. + * + * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table: + * subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first + * + * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. + * With the change of using ORT's priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of 0, 1, ... n-1 for most of the cases, + * therefore subgraph_node_index as a mapping table is not needed anymore. 
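// Illustrative aside (not part of the patch): the comment above explains that node indices
// returned by the onnx-tensorrt parser are remapped to ORT node indices through the
// topo-sort table (subgraph_node_index) and the current node group (group.first). A minimal
// sketch with made-up data showing that remapping.
#include <cstddef>
#include <vector>

int main() {
  std::vector<size_t> group_first = {10, 11, 12, 13, 14};      // ORT node indices in this group
  std::vector<size_t> subgraph_node_index = {0, 2, 1, 3, 4};   // parser index -> position in group_first
  std::vector<size_t> parser_supported = {1, 3};               // indices reported by the parser

  for (auto& idx : parser_supported) {
    idx = group_first[subgraph_node_index[idx]];               // -> ORT node indices {12, 13}
  }
  return 0;
}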
+ * + * TODO: Remove the subgraph_node_index + */ next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; } nodes_list_output.push_back(next_nodes_list[i]); diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 4b2b7610cf7ea..872d022e85264 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -9,9 +9,44 @@ #include "core/providers/shared_library/provider_api.h" namespace vaip { using namespace onnxruntime; + +static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { + auto tensor_proto = const_cast(&tensor); + auto file = std::string(); + uintptr_t offset = 0; + size_t size = 0; + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { + auto external_data = tensor_proto->mutable_external_data(); + auto external_data_size = external_data->size(); + for (auto i = 0; i < external_data_size; ++i) { + auto& data = external_data->at(i); + char* end = nullptr; + if (*data.mutable_key() == "location") { + file = *data.mutable_value(); + } else if (*data.mutable_key() == "offset") { + offset = (uintptr_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "length") { + size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "checksum") { + // checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } + } + if (file == "*/_ORT_MEM_ADDR_/*") { + auto addr = reinterpret_cast(offset); + return {addr, size}; + } + } + return {}; +} + gsl::span tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { + auto maybe_external_memory_address = process_ext_address(tensor); + if (!maybe_external_memory_address.empty()) { + return maybe_external_memory_address; + } + std::vector unpacked_tensor; auto path = graph.ModelPath(); auto s = onnxruntime::utils::UnpackInitializerData(tensor, path, unpacked_tensor); diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d3c1d06818db2..c4a633fcc92bb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector& shape, const loggin return true; } -bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, - const WebnnDeviceType device_type, const logging::Logger& logger) { +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { const auto& op_builders = GetOpBuilders(); if (Contains(op_builders, node.OpType())) { const auto* op_builder = op_builders.at(node.OpType()); - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger); } else { return false; } @@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const 
emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -105,7 +106,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v // Firstly check if platform supports the WebNN op. if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; - supported = IsNodeSupported(*node, graph_viewer, device_type, logger); + supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() @@ -130,10 +131,54 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v return supported_node_groups; } -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types) { - return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != - supported_data_types.end(); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger) { + for (size_t i = 1; i < input_types.size(); i++) { + if (input_types[0] != input_types[i]) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same, but [" + << input_types[0] << "] does not match " + << input_types[i] << "]."; + return false; + } + } + return true; +} + +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { + auto it = onnx_to_webnn_data_type_map.find(static_cast(onnx_data_type)); + if (it == onnx_to_webnn_data_type_map.end()) + return false; + + std::string webnn_data_type = it->second; + + // Check if WebNN supports the data type. + emscripten::val is_supported = webnn_supported_data_types.call("includes", + emscripten::val(webnn_data_type)); + return is_supported.as(); +} + +// Check if the input or output data type of ONNX node is supported by the WebNN operator. 
+bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + std::string webnn_op_type; + if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) + return false; + + if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type + << "] " << onnx_input_output_name + << " type: [" << onnx_data_type + << "] is not supported for now"; + return false; + } + + return true; } bool GetBidirectionalBroadcastShape(std::vector& shape_a, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 4d723a3c59ee2..dd4a8acc662ef 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger); static const InlinedHashMap op_map = { {"Abs", "abs"}, @@ -183,6 +184,7 @@ static const InlinedHashMap op_map = { {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, {"GreaterOrEqual", "greaterOrEqual"}, + {"Gru", "gru"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, {"Identity", "identity"}, @@ -204,6 +206,7 @@ static const InlinedHashMap op_map = { {"Pad", "pad"}, {"Pow", "pow"}, {"PRelu", "prelu"}, + {"QuantizeLinear", "quantizeLinear"}, {"Reciprocal", "reciprocal"}, {"ReduceL1", "reduceL1"}, {"ReduceL2", "reduceL2"}, @@ -249,20 +252,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn return true; } -static const std::unordered_set webnn_supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, +inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) { + auto it = op_map.find(op_type); + // Returns false if the op_type is not listed in the op_map. 
+ if (it == op_map.end()) { + return false; + } + webnn_op_type = it->second; + return true; +} + +static const InlinedHashMap onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, }; -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger); +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 626aaf5c71b74..781ddcb896155 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi return true; } -bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - // WebNN relu op supports float32, float16, int32, int8 input data types. - if (op_type == "Relu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend does not support int32 data type for relu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { // Others only support float32 and float16. 
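// Illustrative aside (not part of the patch): the new helpers above map an ONNX tensor
// element type to its WebNN string via onnx_to_webnn_data_type_map and then test it against
// the op's supported dataTypes list taken from opSupportLimits. A simplified sketch using
// plain STL containers instead of emscripten::val; the function name and the subset of
// dtypes shown are assumptions for this example only.
#include <algorithm>
#include <map>
#include <string>
#include <vector>

bool ExampleIsSupportedDataType(int onnx_data_type,
                                const std::vector<std::string>& webnn_supported_data_types) {
  static const std::map<int, std::string> onnx_to_webnn = {
      {1, "float32"},   // TensorProto_DataType_FLOAT
      {10, "float16"},  // TensorProto_DataType_FLOAT16
      {6, "int32"},     // TensorProto_DataType_INT32
      {7, "int64"},     // TensorProto_DataType_INT64
  };
  auto it = onnx_to_webnn.find(onnx_data_type);
  if (it == onnx_to_webnn.end()) return false;
  return std::find(webnn_supported_data_types.begin(), webnn_supported_data_types.end(),
                   it->second) != webnn_supported_data_types.end();
}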
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 05f3a742a3775..d61ae1a1f6be7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index fa535889299ea..8da255a288f17 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { ORT_RETURN_IF_NOT( - IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger), - "Unsupported operator ", - node.OpType()); + IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), + model_builder.GetOpSupportLimits(), logger), + "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; @@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& // Operator support related. bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, device_type, logger)) + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, wnn_limits, logger)) + return false; + + if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; // We do not support external initializers for now. @@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d } } - // WebNN CPU backend (TFLite) will enable float16 input data type soon, - // temporarily fallback float16 input data type for WebNN CPU. 
- if (device_type == WebnnDeviceType::CPU) { - const auto& input = *node.InputDefs()[0]; - - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) - return false; - } - - return HasSupportedInputsImpl(node, device_type, logger); + return HasSupportedInputsImpl(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, - const WebnnDeviceType /* device_type */, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. const auto& input = *node.InputDefs()[0]; - + const auto& op_type = node.OpType(); int32_t input_type; if (!GetType(input, input_type, logger)) return false; - if (!IsSupportedDataType(input_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); +} + +bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + // We only check the type of output 0 by default, specific op builder can override this. + const auto& output = *node.OutputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) return false; - } - return true; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 85e38b668cee4..584455f62cb4e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related. public: bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; protected: virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, @@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const; // ONNX Runtime only *guarantees* support for models stamped // with opset version 7 or above for opset domain 'ai.onnx'. 
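The hunks above replace hard-coded, per-backend data type sets with checks against wnn_limits, the JavaScript object a WebNN MLContext reports via opSupportLimits(). The IsDataTypeSupportedByOp helper used throughout these hunks lives in the provider's shared helper code and is not shown in this diff; the following is only a rough sketch of what such a lookup could look like, assuming the limits object holds per-operator entries with dataTypes arrays (the name and signature here are illustrative, not the actual helper).

// Illustrative sketch; not the actual helper.h implementation.
// Assumes wnn_limits looks like: { "gru": { "input": { "dataTypes": ["float32", "float16"] } }, ... }
#include <emscripten/val.h>
#include <string>

inline bool IsWebnnDataTypeSupportedSketch(const emscripten::val& wnn_limits,
                                           const std::string& webnn_op_name,      // e.g. "gru"
                                           const std::string& webnn_io_name,      // e.g. "input" or "output"
                                           const std::string& webnn_data_type) {  // e.g. "float32"
  emscripten::val op_limits = wnn_limits[webnn_op_name];
  if (op_limits.isUndefined()) return false;
  emscripten::val io_limits = op_limits[webnn_io_name];
  if (io_limits.isUndefined()) return false;
  emscripten::val data_types = io_limits["dataTypes"];
  if (data_types.isUndefined()) return false;
  // Array.prototype.includes on the JS dataTypes array.
  return data_types.call<bool>("includes", emscripten::val(webnn_data_type));
}

In the real helper the ONNX TensorProto element type would first be mapped to its WebNN type name ("float32", "int64", and so on) before a lookup of this kind.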
@@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 555de68cd60fe..af82a01b14de5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice !GetType(*input_defs[1], input1_type, logger)) return false; - std::unordered_set supported_data_types; - // WebNN prelu op only supports float32, float16, int32, int8 input data types. - if (op_type == "Prelu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend doesn't support int32 for prelu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { - supported_data_types = webnn_supported_data_types; - } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; - } - - return true; + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? 
"X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a08e1681a8464..3c4fc822f3d01 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input_type; -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType device_type, - const logging::Logger& logger) const { - NodeAttrHelper helper(node); - // Check cast output type. - const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); - - // WebNN CPU backend doesn't support casting to uint64 data type. - if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) { - LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend."; + if (!GetType(*input_defs[0], input_type, logger)) return false; - } - if (!IsSupportedDataType(to_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << "."; + + if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger)) return false; - } - return true; + NodeAttrHelper helper(node); + // Check cast to type. + const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); + return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger); } void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index b5c3206072d50..374143c886849 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. 
@@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, }; } -bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index dedc76b80e978..48dd6f3beb020 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. 
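The Concat hunk that follows, like the binary, conv, and gemm builders elsewhere in this change, routes the "all inputs share one data type" check through AreInputDataTypesSame rather than comparing types inline. That helper is defined in the provider's shared helper code and is not part of this diff; a minimal sketch of the idea, under an assumed gsl::span signature, might look like this:

// Hedged sketch of an AreInputDataTypesSame-style helper; the actual ORT helper may differ.
#include <cstddef>
#include <string>
#include <gsl/gsl>                         // gsl::span (include path assumed)
#include "core/common/logging/logging.h"   // LOGS macro

inline bool AreInputDataTypesSameSketch(const std::string& op_type,
                                        gsl::span<const int32_t> input_types,
                                        const onnxruntime::logging::Logger& logger) {
  for (size_t i = 1; i < input_types.size(); ++i) {
    if (input_types[i] != input_types[0]) {
      LOGS(logger, VERBOSE) << "[" << op_type << "] input " << i << " has data type ["
                            << input_types[i] << "], which differs from input 0's type ["
                            << input_types[0] << "]";
      return false;
    }
  }
  return true;
}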
@@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + + if (!GetType(*input_defs[0], input0_type, logger)) + return false; + + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } + + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); +} + void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 76a8a178678df..f03e5b90ff6db 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.GetZeroConstant("uint8"); + x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } if (input_defs.size() >= 4) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - w_zero_point = model_builder.GetZeroConstant("uint8"); + w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } output = model_builder.GetBuilder().call("conv2dInteger", input, x_zero_point, filter, w_zero_point, options); @@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Conv" || op_type == "ConvTranspose") { - // WebNN conv2d and convTranspose2d only support float32 and float16 input data types. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "ConvInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc deleted file mode 100644 index 93a12a696cce1..0000000000000 --- a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) Intel Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/safeint.h" -#include "core/optimizer/initializer.h" -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/providers/webnn/builders/helper.h" -#include "core/providers/webnn/builders/model_builder.h" -#include "core/providers/webnn/builders/op_builder_factory.h" - -#include "core/providers/webnn/builders/impl/base_op_builder.h" - -namespace onnxruntime { -namespace webnn { - -class DequantizeLinearOpBuilder : public BaseOpBuilder { - // Add operator related. - private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const override ORT_MUST_USE_RESULT; -}; - -Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val zero_point = emscripten::val::null(); - if (input_defs.size() == 3) { - zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); - } else { - zero_point = model_builder.GetZeroConstant("uint8"); - } - emscripten::val output; - std::vector input_shape; - std::vector scale_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); - NodeAttrHelper helper(node); - int32_t axis = helper.Get("axis", 1); - // axis is valid for input shape greater than 1D. - if (input_shape.size() > 1) { - axis = static_cast(HandleNegativeAxis(axis, input_shape.size())); - } - // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor. 
- if (1 == scale_shape.size() && 1 < input_shape.size()) { - std::vector target_shape{static_cast(input_shape[axis])}; - target_shape.insert(target_shape.begin(), axis, 1); - target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); - emscripten::val reshape_scale_options = emscripten::val::object(); - reshape_scale_options.set("label", node.Name() + "_reshape_scale"); - scale = model_builder.GetBuilder().call("reshape", - scale, - emscripten::val::array(target_shape), - reshape_scale_options); - emscripten::val reshape_zero_point_options = emscripten::val::object(); - reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); - zero_point = model_builder.GetBuilder().call("reshape", - zero_point, - emscripten::val::array(target_shape), - reshape_zero_point_options); - } - emscripten::val options = emscripten::val::object(); - options.set("label", node.Name()); - output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point, options); - - model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); - - return Status::OK(); -} - -void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { - op_registrations.builders.push_back(std::make_unique()); - op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); -} - -} // namespace webnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 23233539d34c7..ae9fe3e3f3bd1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t indices_type; + if (!GetType(input, input_type, logger) || + !GetType(indices, indices_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for gather. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index bd452b118fe3e..1477530ce1894 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - a_zero_point = model_builder.GetZeroConstant("uint8"); + a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } if (input_defs.size() >= 4) { b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - b_zero_point = model_builder.GetZeroConstant("uint8"); + b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } output = model_builder.GetBuilder().call("matmulInteger", a, @@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -233,35 +233,18 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Gemm" || op_type == "MatMul") { - // WebNN gemm and matmul only support float32 and float16 input data types. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "MatMulInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc new file mode 100644 index 0000000000000..c92fe7366d494 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -0,0 +1,234 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GruOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. 
+ private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.InputDefs().size() > 4 && node.InputDefs()[4]->Exists()) { + model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); // sequence_lens + model_builder.AddInputToSkip(node.InputDefs()[4]->Name()); + } +} + +Status GruOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + uint32_t hidden_size = helper.Get("hidden_size", 1); + + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape"); + uint32_t steps = static_cast(input_shape[0]); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val weight = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val recurrent_weight = model_builder.GetOperand(input_defs[2]->Name()); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + options.set("layout", emscripten::val("zrn")); + + if (input_defs.size() > 3 && input_defs[3]->Exists()) { + emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); + emscripten::val split_options = emscripten::val::object(); + split_options.set("label", node.Name() + "_split"); + split_options.set("axis", 1); + // Split it to bias and recurrentBias. 
+    emscripten::val splitted_biases =
+        model_builder.GetBuilder().call<emscripten::val>("split", bias, /*splits*/ 2, split_options);
+    options.set("bias", splitted_biases[0]);
+    options.set("recurrentBias", splitted_biases[1]);
+  }
+
+  if (input_defs.size() > 5 && input_defs[5]->Exists()) {
+    options.set("initialHiddenState", model_builder.GetOperand(input_defs[5]->Name()));
+  }
+
+  // ONNX "linear_before_reset" maps to WebNN's "resetAfter" option.
+  bool linear_before_reset = !!helper.Get("linear_before_reset", 0);
+  options.set("resetAfter", linear_before_reset);
+
+  const auto& output_defs = node.OutputDefs();
+  bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists();
+  bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists();
+  options.set("returnSequence", has_Y);
+
+  std::string direction = helper.Get("direction", "forward");
+  if (direction == "forward") {
+    options.set("direction", emscripten::val("forward"));
+  } else if (direction == "reverse") {
+    options.set("direction", emscripten::val("backward"));
+  } else if (direction == "bidirectional") {
+    options.set("direction", emscripten::val("both"));
+  }
+
+  if (helper.HasAttr("activations")) {
+    const auto activations = helper.Get("activations", std::vector<std::string>{"Sigmoid", "Tanh"});
+    emscripten::val recurrent_network_activations = emscripten::val::array();
+    for (size_t i = 0; i < 2; ++i) {
+      const std::string& activation = activations[i];
+      if (activation == "Relu") {
+        recurrent_network_activations.call<void>("push", emscripten::val("relu"));
+      } else if (activation == "Sigmoid") {
+        recurrent_network_activations.call<void>("push", emscripten::val("sigmoid"));
+      } else if (activation == "Tanh") {
+        recurrent_network_activations.call<void>("push", emscripten::val("tanh"));
+      }
+    }
+
+    options.set("activations", recurrent_network_activations);
+  }
+
+  emscripten::val outputs = model_builder.GetBuilder().call<emscripten::val>("gru", input, weight, recurrent_weight,
+                                                                             steps, hidden_size, options);
+
+  if (has_Y) {
+    model_builder.AddOperand(output_defs[0]->Name(), outputs[1]);
+  }
+  if (has_Y_h) {
+    model_builder.AddOperand(output_defs[1]->Name(), outputs[0]);
+  }
+
+  return Status::OK();
+}
+
+bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
+                                     const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  if (input_defs.size() < 3) {
+    LOGS(logger, ERROR) << "GRU: input size must be greater than or equal to 3";
+    return false;
+  }
+
+  std::vector<int64_t> input_shape;
+  if (!GetShape(*input_defs[0], input_shape, logger) || input_shape.empty()) {
+    LOGS(logger, ERROR) << "Cannot get input's shape";
+    return false;
+  }
+  int32_t steps = static_cast<int32_t>(input_shape[0]);
+
+  if (input_defs.size() > 4 && input_defs[4]->Exists()) {
+    if (!Contains(initializers, input_defs[4]->Name())) {
+      LOGS(logger, ERROR) << "GRU: sequence_lens must be constant";
+      return false;
+    }
+
+    const auto& sequence_lens_tensor = *initializers.at(input_defs[4]->Name());
+    std::vector<int32_t> sequence_lens;
+    if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) {
+      LOGS(logger, ERROR) << "Cannot read sequence lens tensor";
+      return false;
+    }
+    if (!std::all_of(sequence_lens.begin(), sequence_lens.end(),
+                     [steps](int32_t lens) -> bool { return steps == lens; })) {
+      LOGS(logger, ERROR) << "GRU: every sequence length must be equal to input shape[0]";
+      return false;
+    }
+  }
+
+  NodeAttrHelper helper(node);
+  if (helper.HasAttr("activations")) {
+    const auto activations = helper.Get("activations", std::vector<std::string>{"Sigmoid", "Tanh"});
+
+    if (activations.size() >= 4) {
+      if (activations[0] != activations[2] || activations[1] != activations[3]) {
+        LOGS(logger, ERROR) << "GRU: forward and reverse directions must have the same activations";
+        return false;
+      }
+    }
+
+    const InlinedHashSet<std::string> supported_activations = {"Relu", "Tanh", "Sigmoid"};
+    if (!std::all_of(activations.begin(), activations.end(),
+                     [&supported_activations](const std::string& activation) -> bool {
+                       return supported_activations.contains(activation);
+                     })) {
+      LOGS(logger, ERROR) << "GRU: activations must be one of Relu, Tanh, Sigmoid";
+      return false;
+    }
+  }
+
+  if (helper.Get("clip", std::numeric_limits<float>::max()) != std::numeric_limits<float>::max()) {
+    LOGS(logger, ERROR) << "GRU: clip is not supported";
+    return false;
+  }
+
+  if (helper.Get("layout", 0) != 0) {
+    LOGS(logger, ERROR) << "GRU: batchwise (layout == 1) is not supported";
+    return false;
+  }
+
+  return true;
+}
+
+bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                                          const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  const auto& op_type = node.OpType();
+  int32_t input0_type = 0;  // input data type
+  int32_t input1_type = 0;  // weight data type
+  int32_t input2_type = 0;  // recurrentWeight data type
+  int32_t input3_type = 0;  // bias data type
+  int32_t input4_type = 0;  // sequence_lens data type
+  int32_t input5_type = 0;  // initialHiddenState data type
+  bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
+  bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists();
+  bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists();
+
+  if (!GetType(*input_defs[0], input0_type, logger) ||
+      !GetType(*input_defs[1], input1_type, logger) ||
+      !GetType(*input_defs[2], input2_type, logger) ||
+      (has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
+      (has_input4 && !GetType(*input_defs[4], input4_type, logger)) ||
+      (has_input5 && !GetType(*input_defs[5], input5_type, logger))) {
+    return false;
+  }
+
+  InlinedVector<int32_t> input_types = {input0_type, input1_type, input2_type};
+  if (has_input3) {
+    input_types.push_back(input3_type);
+  }
+  if (has_input4) {
+    input_types.push_back(input4_type);
+  }
+  if (has_input5) {
+    input_types.push_back(input5_type);
+  }
+  if (!AreInputDataTypesSame(op_type, input_types, logger)) {
+    return false;
+  }
+
+  return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
+}
+
+void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
+  op_registrations.builders.push_back(std::make_unique<GruOpBuilder>());
+  op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
+}
+
+}  // namespace webnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc
index 23f3a938fee5e..ea7f70b4598e6 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc
@@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder {
   // Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder { Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { + const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input1 = emscripten::val::undefined(); + if (input_defs.size() > 1) { + input1 = model_builder.GetOperand(input_defs[1]->Name()); + } + emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); @@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); + } else if (op_type == "Not") { + output = model_builder.GetBuilder().call("logicalNot", input0, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2) { + if (input_defs.size() < 2 && op_type != "Not") { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " << input_defs.size(); return false; @@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) - return false; - - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + if (!GetType(*input_defs[0], input0_type, logger)) return false; - } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + if (op_type != "Not") { + if (!GetType(*input_defs[1], input1_type, logger)) + return false; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + std::string onnx_input_name = op_type == "Not" ? 
"X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { @@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& "GreaterOrEqual", "Less", "LessOrEqual", + "Not", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 5d88afda7b6a7..e111ca412c6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; - int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) + if (!GetType(*input_defs[0], input0_type, logger)) return false; - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; - } + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d068baf35e72..a3c6b8fdcea9b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& 
logger) const override; }; @@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn return false; } - // WebNN batchNormalization, instanceNormalization, layerNormalization - // only support float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 071155a2fb372..d8373a45e4423 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } // namespace webnn -bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for pad. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc new file mode 100644 index 0000000000000..13dee667f6fd9 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "core/providers/webnn/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class QDQOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); + + std::vector input_shape; + std::vector scale_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); + int32_t input_type = 0; + int32_t output_type = 0; + int32_t zero_point_type = 0; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type"); + ORT_RETURN_IF_NOT(GetType(*output_defs[0], output_type, logger), "Cannot get output data type"); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); + + emscripten::val zero_point = emscripten::val::null(); + if (input_defs.size() == 3 && input_defs[2]->Exists()) { + zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + } else { + // DequantizeLinear: x_zero_point's data type equals to input data type + // QuantizeLinear: x_zero_point's data type equals to output data type + zero_point_type = op_type == "DequantizeLinear" ? 
                                                  input_type : output_type;
+    zero_point = model_builder.GetZeroConstant(zero_point_type);
+  }
+
+  emscripten::val output;
+  NodeAttrHelper helper(node);
+  int32_t axis = helper.Get("axis", 1);
+  int32_t block_size = helper.Get("block_size", 0);
+  // axis is valid for input shape greater than 1D.
+  if (input_shape.size() > 1) {
+    axis = static_cast<int32_t>(HandleNegativeAxis(axis, input_shape.size()));
+  }
+  // Insert ones before and after the axis dimension for broadcasting of the 1D scale tensor.
+  if (1 == scale_shape.size() && 1 < input_shape.size()) {
+    std::vector<uint32_t> target_shape{static_cast<uint32_t>(input_shape[axis])};
+    target_shape.insert(target_shape.begin(), axis, 1);
+    target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1);
+    emscripten::val reshape_scale_options = emscripten::val::object();
+    reshape_scale_options.set("label", node.Name() + "_reshape_scale");
+    scale = model_builder.GetBuilder().call<emscripten::val>("reshape",
+                                                             scale,
+                                                             emscripten::val::array(target_shape),
+                                                             reshape_scale_options);
+    emscripten::val reshape_zero_point_options = emscripten::val::object();
+    reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point");
+    zero_point = model_builder.GetBuilder().call<emscripten::val>("reshape",
+                                                                  zero_point,
+                                                                  emscripten::val::array(target_shape),
+                                                                  reshape_zero_point_options);
+  }
+
+  // If block_size is specified, we need to expand the scale and zero_point tensors.
+  if (block_size > 1) {
+    emscripten::val concat_scale_inputs = emscripten::val::array();
+    emscripten::val concat_zero_point_inputs = emscripten::val::array();
+    for (int i = 0; i < block_size; i++) {
+      concat_scale_inputs.call<void>("push", scale);
+      concat_zero_point_inputs.call<void>("push", zero_point);
+    }
+
+    emscripten::val concat_scale_options = emscripten::val::object();
+    concat_scale_options.set("label", node.Name() + "_concat_scale");
+    scale = model_builder.GetBuilder().call<emscripten::val>("concat", concat_scale_inputs, axis, concat_scale_options);
+
+    emscripten::val concat_zero_point_options = emscripten::val::object();
+    concat_zero_point_options.set("label", node.Name() + "_concat_zero_point");
+    zero_point = model_builder.GetBuilder().call<emscripten::val>(
+        "concat", concat_zero_point_inputs, axis, concat_zero_point_options);
+  }
+
+  emscripten::val options = emscripten::val::object();
+  options.set("label", node.Name());
+  std::string webnn_op_type;
+  ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type");
+  output = model_builder.GetBuilder().call<emscripten::val>(webnn_op_type.c_str(), input, scale, zero_point, options);
+
+  model_builder.AddOperand(output_defs[0]->Name(), std::move(output));
+
+  return Status::OK();
+}
+
+bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                                          const logging::Logger& logger) const {
+  const auto& input_defs = node.InputDefs();
+  const auto& op_type = node.OpType();
+  int32_t input0_type = 0;  // input data type
+  int32_t input1_type = 0;  // x_scale data type
+  int32_t input2_type = 0;  // x_zero_point data type
+  bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists();
+
+  if (!GetType(*input_defs[0], input0_type, logger) ||
+      !GetType(*input_defs[1], input1_type, logger) ||
+      (has_input2 && !GetType(*input_defs[2], input2_type, logger))) {
+    return false;
+  }
+
+  return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) &&
+         IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) &&
+         (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point",
logger)); +} + +void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.count(op_type) > 0) + return; + + static std::vector op_types = + { + "DequantizeLinear", + "QuantizeLinear", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 3e6d4d9820e9a..93ad933d71c34 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return true; } -bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "ReduceL1" || op_type == "ReduceProd" || - op_type == "ReduceSum" || op_type == "ReduceSumSquare") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, - }; - - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend doesn't support uint32 and uint64 for reduceL1, - // reduceProd, reduceSum and reduceSumSquare. - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" || - op_type == "ReduceLogSumExp" || op_type == "ReduceMean") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else { // ReduceMax and ReduceMin - supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 2218c858951d3..9dc79f4f52f46 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Helper functions @@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN resample2d op only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 0eb7dafdffe4d..6b56d2c740f40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. 
- -bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (!IsSupportedDataType(output_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Output type: [" << output_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index bef13841c646c..3f0d633ac888b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 input data type for slice. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 798cfabae65db..b1b737b114998 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -63,30 +61,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN softmax only supports float32 and float16 input data types. 
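For context, the HasSupportedInputsImpl overrides removed in these hunks all implemented the same idea: a hard-coded allow-list of ONNX element types per op, shrunk further for the WebNN CPU backend, now replaced by checks against the context's opSupportLimits(). A simplified, self-contained sketch of that old pattern (toy enum values and names, not the real ORT/WebNN declarations):

#include <iostream>
#include <string>
#include <unordered_set>

// Toy stand-ins for ONNX_NAMESPACE::TensorProto_DataType_* and WebnnDeviceType.
enum DataType { kFloat = 1, kFloat16 = 10, kUint64 = 13 };
enum class DeviceType { CPU, GPU };

bool HasSupportedInput(const std::string& op_type, DataType input_type, DeviceType device) {
  std::unordered_set<int> supported = {kFloat, kFloat16, kUint64};
  if (device == DeviceType::CPU) {
    supported.erase(kUint64);  // e.g. the CPU backend lacked uint64 for Slice
  }
  if (supported.count(input_type) == 0) {
    std::cout << "[" << op_type << "] input type " << input_type << " not supported\n";
    return false;
  }
  return true;
}

int main() {
  HasSupportedInput("Slice", kUint64, DeviceType::CPU);  // rejected
  HasSupportedInput("Slice", kFloat, DeviceType::CPU);   // accepted
}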
- std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 2ed8330bf25be..4b6cf312074ba 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic !GetType(*input_defs[2], input2_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 X, Y data type for where. - if (device_type == WebnnDeviceType::CPU && op_type == "Where") { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } // ONNX's condition data type is bool which is same as WebNN. // Only need to check X, Y data types. 
- if (!IsSupportedDataType(input1_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input1_type - << "] is not supported for now"; - return false; - } - - if (input1_type != input2_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input X, Y data types should be the same."; + std::array input_types{input1_type, input2_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 03c88ad9db88a..3a5e39f7f7a56 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for transpose. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 061404c8a9ce0..8e64e98445f03 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Add operator related. 
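The ternary builder change above swaps the hand-written type comparison for AreInputDataTypesSame plus an opSupportLimits-based lookup; the exact helper signatures are not shown in this diff. A rough standalone sketch of the "all inputs share one data type" part:

#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>

// Returns true when every input element type equals the first one,
// logging the op type on mismatch (simplified illustration).
bool AllInputTypesMatch(const std::string& op_type, const int32_t* types, std::size_t count) {
  if (count == 0) return true;
  const bool same = std::all_of(types, types + count,
                                [first = types[0]](int32_t t) { return t == first; });
  if (!same) {
    std::cout << "[" << op_type << "] input data types should be the same\n";
  }
  return same;
}

int main() {
  std::array<int32_t, 2> where_xy{1 /*float*/, 10 /*float16*/};
  std::cout << AllInputTypesMatch("Where", where_xy.data(), where_xy.size()) << "\n";  // 0
}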
@@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { output = model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { @@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "Identity") { - supported_data_types = webnn_supported_data_types; - } else if (op_type == "Abs" || op_type == "Neg") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - } else if (op_type == "Not") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - }; - } else { // Others only support float32, float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; @@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Identity", "Log", "Neg", - "Not", "Reciprocal", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 44bec1fb6fd48..f9f8264b234bb 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -21,12 +21,13 @@ namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type) + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), preferred_layout_(preferred_layout), - wnn_device_type_(wnn_device_type) { + wnn_device_type_(wnn_device_type), + wnn_limits_(wnn_limits) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. 
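The GetZeroConstant rewrite a little further below keys the cached scalar-zero constant by ONNX element type and allocates a matching one-element JS typed array (Uint8Array, Float32Array, BigInt64Array, and so on) for its backing buffer. A compact sketch of that element-type mapping, reduced to element byte sizes in plain C++ since the real code goes through emscripten:

#include <cstddef>
#include <cstdint>
#include <iostream>

// Illustrative element-size mapping; the real code news up the corresponding
// JS typed array for each ONNX TensorProto data type.
std::size_t ZeroBufferBytes(int32_t onnx_type) {
  switch (onnx_type) {
    case 9:  /* BOOL    */
    case 2:  /* UINT8   */
    case 3:  /* INT8    */ return 1;
    case 10: /* FLOAT16 */ return 2;
    case 1:  /* FLOAT   */
    case 6:  /* INT32   */
    case 12: /* UINT32  */ return 4;
    case 7:  /* INT64   */
    case 13: /* UINT64  */ return 8;
    default: return 0;  // unsupported type
  }
}

int main() {
  std::cout << ZeroBufferBytes(10) << "\n";  // 2 bytes for a float16 zero
}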
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); @@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("dimensions", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, webnn_supported_data_types)) { + if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); @@ -353,27 +354,48 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op // https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type // BTW, the spec is discussing if the builer.constant(value, type) should be dropped at // https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision. -const emscripten::val& ModelBuilder::GetZeroConstant(const std::string& data_type) { - std::string name = "webnn_zero_constant_" + data_type; +const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) { + std::string name = "webnn_zero_constant_" + std::to_string(data_type); // If the operand does not exist, create it. if (wnn_operands_.find(name) == wnn_operands_.end()) { emscripten::val desc = emscripten::val::object(); emscripten::val dims = emscripten::val::array(); desc.set("dimensions", dims); emscripten::val zero_buffer = emscripten::val::undefined(); - if (data_type == "uint8") { - if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_UINT8)) { - ORT_THROW("Unsupported data type: " + data_type); - } - zero_buffer = emscripten::val::global("Uint8Array").new_(1); - } else if (data_type == "float32") { - if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { - ORT_THROW("Unsupported data type: " + data_type); - } - zero_buffer = emscripten::val::global("Float32Array").new_(1); - } else { - ORT_THROW("Unsupported data type: " + data_type); + if (!SetWebnnDataType(desc, data_type)) { + ORT_THROW("Unsupported data type: " + std::to_string(data_type)); } + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + zero_buffer = emscripten::val::global("Uint8Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + zero_buffer = emscripten::val::global("Int8Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + zero_buffer = emscripten::val::global("Uint16Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + zero_buffer = emscripten::val::global("Float32Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + zero_buffer = emscripten::val::global("Int32Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + zero_buffer = emscripten::val::global("BigInt64Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + zero_buffer = emscripten::val::global("Uint32Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + zero_buffer = emscripten::val::global("BigUint64Array").new_(1); + break; + default: + break; + } + emscripten::val zero_constant = wnn_builder_.call("constant", desc, zero_buffer); wnn_operands_.insert(std::make_pair(name, zero_constant)); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h 
b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2d686070cdcc1..13937933a0a9c 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -23,7 +23,7 @@ class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type); + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -35,8 +35,10 @@ class ModelBuilder { const emscripten::val& GetBuilder() const { return wnn_builder_; } const emscripten::val& GetContext() const { return wnn_context_; } const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } + void AddOperand(const std::string& name, const emscripten::val& operand); - const emscripten::val& GetZeroConstant(const std::string& data_type); + const emscripten::val& GetZeroConstant(const int32_t& data_type); // Use the buffers to persist WebNN allocated data like transposed weight. // It ensures the validity during inference session. std::vector> mem_persist_buffers_; @@ -66,6 +68,7 @@ class ModelBuilder { emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; + emscripten::val wnn_limits_ = emscripten::val::undefined(); InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h index 6ecc5d1068963..bb69a6a545597 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -29,7 +29,8 @@ class IOpBuilder { public: // Check if an operator is supported. 
virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const = 0; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const = 0; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 862cf5ded15bc..93a2b232a7d51 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Identity", op_registrations); CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); - CreateUnaryOpBuilder("Not", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); @@ -85,9 +84,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateDropoutOpBuilder("Dropout", op_registrations); } - { // Quantize/Dequantize + { // DequantizeLinear/QuantizeLinear/DynamicQuantizeLinear + CreateQDQOpBuilder("DequantizeLinear", op_registrations); + CreateQDQOpBuilder("QuantizeLinear", op_registrations); CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations); - CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations); } { // Expand @@ -108,12 +108,17 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGemmOpBuilder("MatMulInteger", op_registrations); } + { // GRU + CreateGruOpBuilder("GRU", op_registrations); + } + { // Logical CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); + CreateLogicalOpBuilder("Not", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index e11938d8fa406..61fe6d936e9d1 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -28,16 +28,17 @@ void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); -void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); 
void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b918daf838c99..b729623c5d3d8 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { - preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { - preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } + + // Retrieve the level of support for different WebNN operators. + // This varies across implementations and is obtained via the WebNN's opSupportLimits() function. 
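In the browser, MLContext.opSupportLimits() returns a nested object: per-op entries whose inputs list their supported dataTypes, plus context-wide values such as preferredInputLayout and the constant() data types, which is how wnn_limits_ is indexed elsewhere in this change. A plain-C++ approximation of that lookup using ordinary containers (the keys and type names below are examples, not a complete description of the real object):

#include <iostream>
#include <map>
#include <set>
#include <string>

// Simplified stand-in for the opSupportLimits() result:
// WebNN op name -> input name -> supported WebNN data types.
// The real object is a JS value accessed via emscripten::val.
using OpSupportLimits = std::map<std::string, std::map<std::string, std::set<std::string>>>;

bool IsDataTypeSupported(const OpSupportLimits& limits, const std::string& webnn_op,
                         const std::string& webnn_input, const std::string& data_type) {
  auto op_it = limits.find(webnn_op);
  if (op_it == limits.end()) return false;
  auto input_it = op_it->second.find(webnn_input);
  if (input_it == op_it->second.end()) return false;
  return input_it->second.count(data_type) > 0;
}

int main() {
  OpSupportLimits limits{
      {"where", {{"trueValue", {"float32", "float16", "int32"}}}},
      {"constant", {{"dataTypes", {"float32", "float16", "int32", "uint8"}}}},
  };
  std::cout << IsDataTypeSupported(limits, "where", "trueValue", "float32") << "\n";  // 1
  std::cout << IsDataTypeSupported(limits, "where", "trueValue", "uint64") << "\n";   // 0
}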
+ // https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits + wnn_limits_ = wnn_context_.call("opSupportLimits"); + + if (wnn_limits_["preferredInputLayout"].as().compare("nhwc") == 0) { + preferred_layout_ = DataLayout::NHWC; + } else { + preferred_layout_ = DataLayout::NCHW; + } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { @@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); @@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector>& clip_min_max, struct xnn_operator*& p, const OpQuantParam& quant_param, @@ -42,7 +41,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride foutput_min, foutput_max, flags, &p); } else if (avgpool_type == OpComputeType::op_compute_type_qu8) { const float output_scale = quant_param[1].first[0]; @@ -53,7 +51,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride quant_param[0].second, quant_param[0].first[0], quant_param[1].second, @@ -209,7 +206,7 @@ AveragePool::AveragePool(const OpKernelInfo& info) ORT_THROW("unsupported AveragePool in XnnpackEP, we have FLOAT|UINT8, but got ", stype); } struct xnn_operator* p; - auto ret = CreateXnnpackKernel(pool_attrs_, C, clip_min_max_, p, + auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p, quant_param, avgpool_type_); ORT_ENFORCE(ret.IsOK(), ret.ErrorMessage()); op0_.reset(p); @@ -222,6 +219,7 @@ Status AveragePool::Compute(OpKernelContext* context) const { int64_t N = X_shape[0]; int64_t H = X_shape[1]; int64_t W = X_shape[2]; + int64_t C = X_shape[3]; // set the N dim to the correct value TensorShapeVector output_dims{output_dims_}; @@ -247,7 +245,7 @@ Status AveragePool::Compute(OpKernelContext* context) const { ? 
xnn_reshape_average_pooling2d_nhwc_f32 : xnn_reshape_average_pooling2d_nhwc_qu8; - auto status = reshape_fn(op0_.get(), N, H, W, + auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc index 2ef9f97f77b14..0f0b827974f66 100644 --- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc @@ -172,7 +172,6 @@ MaxPool::MaxPool(const OpKernelInfo& info) pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride foutput_min, foutput_max, flags, &p); } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { maxpool_type_ = OpComputeType::op_compute_type_qu8; @@ -183,7 +182,6 @@ MaxPool::MaxPool(const OpKernelInfo& info) pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride output_min, output_max, flags, &p); } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT8) { maxpool_type_ = OpComputeType::op_compute_type_qs8; @@ -194,7 +192,6 @@ MaxPool::MaxPool(const OpKernelInfo& info) pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride output_min, output_max, flags, &p); } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto())); @@ -213,6 +210,7 @@ Status MaxPool::Compute(OpKernelContext* context) const { int64_t N = X_shape[0]; int64_t H = X_shape[1]; int64_t W = X_shape[2]; + int64_t C = X_shape[3]; // set the N dim to the correct value TensorShapeVector output_dims{output_dims_}; @@ -234,6 +232,7 @@ Status MaxPool::Compute(OpKernelContext* context) const { } auto status = reshape_fn(op0_.get(), N, H, W, + C, C, C, // channels, input_pixel_stride, output_pixel_stride /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index cf874796ba169..db5648d5d6e54 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -214,8 +214,6 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf } } - int64_t channels = x_shape->dim(3).dim_value(); - uint32_t flags = 0; ORT_ENFORCE(mode_ == UpsampleMode::LINEAR, "only support bilinear resize"); if (coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS) { @@ -227,12 +225,14 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf xnn_status xstatus = xnn_status_invalid_state; struct xnn_operator* p = nullptr; + auto out_h = output_dims_[1]; + auto out_w = output_dims_[2]; if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f32(channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_f32(out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - xstatus = xnn_create_resize_bilinear2d_nhwc_u8(channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_u8(out_h, out_w, flags, 
&p); } else { - xstatus = xnn_create_resize_bilinear2d_nhwc_s8(channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_s8(out_h, out_w, flags, &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:", @@ -248,6 +248,7 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, auto N = X_shape[0]; auto H = X_shape[1]; auto W = X_shape[2]; + auto C = X_shape[3]; Tensor* output = ctx->Output(0, TensorShape(output_dims)); pthreadpool_t threadpool = GetThreadPool(); @@ -266,7 +267,7 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; } - auto status = reshape_fn(op0_.get(), N, H, W, output_dims[1], output_dims[2], + auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, &workspace_size, &workspace_alignment, threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h index 0978a88288114..31512586be19d 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h @@ -48,7 +48,7 @@ class XnnpackKernel : public OpKernel { // auto_code_cache.reset(&code_cache_); #endif // status = xnn_init_weights_cache(&weights_cache_); - xnn_weights_cache_t weights_cache = nullptr; + xnn_weights_cache_t weights_cache_provider = nullptr; status = xnn_create_weights_cache(&weights_cache, 0); ORT_ENFORCE(status == xnn_status_success, "Failed to create XNNPACK weights cache"); auto_weights_cache.reset(weights_cache); @@ -57,7 +57,7 @@ class XnnpackKernel : public OpKernel { } // std::unique_ptr auto_code_cache; - std::unique_ptr auto_weights_cache; + std::unique_ptr auto_weights_cache; // private: // #if defined(XNN_CACHE_ENABLE) && XNN_PLATFORM_JIT diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index c5757095e2e1e..1319e8f6fe959 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "python/onnxruntime_pybind_state_common.h" @@ -54,7 +55,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::VitisAIProviderFactoryCreator::Create(ProviderOptions{}), #endif #ifdef USE_ACL - onnxruntime::ACLProviderFactoryCreator::Create(0), + onnxruntime::ACLProviderFactoryCreator::Create(false), #endif #ifdef USE_ARMNN onnxruntime::ArmNNProviderFactoryCreator::Create(0), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 47b8d75f22aea..e8bf61612c89b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
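The xnnpack_kernel.h hunk above keeps the C-style XNNPACK weights-cache handle inside a std::unique_ptr owned by the kernel. A common way to express that kind of ownership is a unique_ptr whose deleter is the library's destroy function; the sketch below demonstrates the idiom with a toy API rather than the real xnn_* calls:

#include <cstdio>
#include <memory>

// Toy C-style API standing in for a create/delete pair such as the XNNPACK weights cache.
struct toy_cache { int id; };
int toy_create_cache(toy_cache** out) { *out = new toy_cache{42}; return 0; }
void toy_delete_cache(toy_cache* c) { std::puts("cache released"); delete c; }

int main() {
  toy_cache* raw = nullptr;
  if (toy_create_cache(&raw) != 0) return 1;
  // unique_ptr with the library's destroy function as deleter: the handle is
  // freed automatically when the owning object goes out of scope.
  std::unique_ptr<toy_cache, decltype(&toy_delete_cache)> cache(raw, &toy_delete_cache);
  std::printf("cache id %d\n", cache->id);
}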
#include "python/onnxruntime_pybind_exceptions.h" @@ -1141,8 +1142,25 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kAclExecutionProvider) { #ifdef USE_ACL - return onnxruntime::ACLProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) + bool enable_fast_math = false; + auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + for (auto option : it->second) { + if (option.first == "enable_fast_math") { + std::set supported_values = {"true", "True", "false", "False"}; + if (supported_values.find(option.second) != supported_values.end()) { + enable_fast_math = (option.second == "true") || (option.second == "True"); + } else { + ORT_THROW( + "Invalid value for enable_fast_math. " + "Select from 'true' or 'false'\n"); + } + } else { + ORT_THROW("Unrecognized option: ", option.first); + } + } + } + return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math) ->CreateProvider(); #endif } else if (type == kArmNNExecutionProvider) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 4d6e411defae3..08e5e4f7b18fa 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -439,7 +440,7 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); #endif -std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_ACL(bool enable_fast_math); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr CreateExecutionProviderFactory_Nnapi( diff --git a/onnxruntime/python/tools/transformers/fusion_gelu.py b/onnxruntime/python/tools/transformers/fusion_gelu.py index 8626ca1482104..6be5140c070d0 100644 --- a/onnxruntime/python/tools/transformers/fusion_gelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gelu.py @@ -98,10 +98,13 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = helper.make_node( + "Gelu", inputs=[subgraph_input], outputs=[subgraph_output], name=self.model.create_node_name("Gelu") + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + self.increase_counter("Gelu") return True def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: @@ -172,10 +175,13 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = helper.make_node( + "Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]], name=self.model.create_node_name("Gelu") + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = 
self.this_graph_name + self.increase_counter("Gelu") return True def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: @@ -243,8 +249,11 @@ def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = helper.make_node( + "Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]], name=self.model.create_node_name("Gelu") + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + self.increase_counter("Gelu") return True diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 1ec5edf686c63..a10b61fdc3f08 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -24,14 +24,15 @@ def __init__( model: OnnxModel, fused_op_type: str = "SkipLayerNormalization", search_op_types: str = "LayerNormalization", + shape_infer: bool = True, ): super().__init__(model, fused_op_type, search_op_types) - # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. - self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) - - if self.shape_infer_helper is None: - # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op. - logger.warning("symbolic shape inference disabled or failed.") + if shape_infer: + # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. + self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) + if self.shape_infer_helper is None: + # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op. 
+ logger.warning("symbolic shape inference disabled or failed.") def fuse(self, node, input_name_to_nodes, output_name_to_node): add = self.model.get_parent(node, 0, output_name_to_node) @@ -56,18 +57,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # Root Mean Square Layer Normalization simplified = node.op_type == "SimplifiedLayerNormalization" - if self.shape_infer_helper is not None: - # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) - if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): - logger.debug( - "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", - add.input[0], - add.input[1], - ) + if hasattr(self, "shape_infer_helper"): + if self.shape_infer_helper is not None: + # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) + if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): + logger.debug( + "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", + add.input[0], + add.input[1], + ) + return + else: + logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed") return - else: - logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed") - return gather_path = self.model.match_parent_path(add, ["Gather"], [None]) if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None: diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index a8fc6e661933e..fe80a08829263 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -63,9 +63,10 @@ def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False): # noqa: B return None - def input_name_to_nodes(self): + def input_name_to_nodes(self, exclude_subgraphs=False): input_name_to_nodes = {} - for node in self.nodes(): + nodes_to_search = self.nodes() if not exclude_subgraphs else self.model.graph.node + for node in nodes_to_search: for input_name in node.input: if input_name: # could be empty when it is optional if input_name not in input_name_to_nodes: @@ -74,9 +75,10 @@ def input_name_to_nodes(self): input_name_to_nodes[input_name].append(node) return input_name_to_nodes - def output_name_to_node(self): + def output_name_to_node(self, exclude_subgraphs=False): output_name_to_node = {} - for node in self.nodes(): + nodes_to_search = self.nodes() if not exclude_subgraphs else self.model.graph.node + for node in nodes_to_search: for output_name in node.output: if output_name: # could be empty when it is optional output_name_to_node[output_name] = node @@ -906,6 +908,31 @@ def remove_unused_constant(self): if len(unused_nodes) > 0: logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}") + def _get_subgraph_inputs_of_node(self, node): + """ + Get inputs to all nodes in all subgraphs of a node + """ + # Note: This function only handles one-level subgraphs of child nodes. 
+ subgraph_nodes_inputs = set() + for attr in node.attribute: + if attr.type == AttributeProto.GRAPH: + child_nodes = attr.g.node + for child_node in child_nodes: + subgraph_nodes_inputs.update(child_node.input) + return subgraph_nodes_inputs + + def _get_subgraph_nodes_and_inputs(self, ops_with_graph_attrs): + """ + Get input names to all nodes in all subgraphs where subgraphs are + graph attributes of a node in the main graph + """ + subgraph_nodes = list(filter(lambda node: node.op_type in ops_with_graph_attrs, self.model.graph.node)) + subgraph_nodes_inputs = set() + for parent_node in subgraph_nodes: + subgraph_inputs_of_parent_node = self._get_subgraph_inputs_of_node(parent_node) + subgraph_nodes_inputs.update(subgraph_inputs_of_parent_node) + return subgraph_nodes, subgraph_nodes_inputs + def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): """ Prune graph to keep only required outputs. It removes unnecessary nodes that are not linked @@ -918,13 +945,9 @@ def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): allow_remove_graph_inputs (bool): allow remove graph inputs. """ - if len(self.graphs()) > 1: - # TODO(tianleiwu): handle subgraph - logger.debug("Skip prune_graph since graph has subgraph") - return - keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs + input_name_to_nodes_for_main_graph = self.input_name_to_nodes(exclude_subgraphs=True) output_name_to_node = self.output_name_to_node() def get_first_output(node): @@ -932,6 +955,29 @@ def get_first_output(node): return node.output[0] return next(iter([o for o in node.output if o]), None) + if len(self.graphs()) > 1: + # Get input names for all nodes in all subgraphs + subgraph_nodes, subgraph_nodes_inputs = self._get_subgraph_nodes_and_inputs( + ops_with_graph_attrs={"Loop", "Scan", "If"} + ) + if len(subgraph_nodes) == 0: + # TODO: support other ops such as `BeamSearch` that have subgraphs as op attributes + logger.debug("Skip prune_graph since graph has subgraph") + return + + # For graphs with subgraphs, add dangling outputs from parent graph nodes to list of outputs to keep + for node in self.model.graph.node: + # TODO: This for-loop logic currently assumes that Loop/Scan/If nodes will not be + # pruned because their subgraphs are needed for computations. This might not be + # true in all cases. + if node in subgraph_nodes: + continue + + # Check if node output is an input of a subgraph node and not an input to a node in the main graph + for output in node.output: + if output in subgraph_nodes_inputs and output not in input_name_to_nodes_for_main_graph: + keep_outputs += [output] + # Keep track of nodes to keep. The key is first output of node, and the value is the node. output_to_node = {} @@ -956,7 +1002,7 @@ def get_first_output(node): first_output = get_first_output(node) kept_node = output_to_node.get(first_output) - # Need double check the node since fused node might reuse output name of some nodes to be removed. + # Need to double check the node since fused node might reuse output name of some nodes to be removed. # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. 
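The prune_graph additions above keep a node alive when one of its outputs is consumed only inside a Loop/Scan/If subgraph, i.e. the output name appears in the collected subgraph-input set but not in the main graph's input map. The same membership test, reduced to a standalone snippet with made-up names (C++ here only for illustration; the real code is the Python above):

#include <iostream>
#include <set>
#include <string>

int main() {
  // Names consumed by main-graph nodes, names consumed inside Loop/Scan/If
  // subgraphs, and names produced by main-graph nodes (illustrative values).
  std::set<std::string> main_graph_inputs = {"x", "cond"};
  std::set<std::string> subgraph_inputs = {"state", "cond"};
  std::set<std::string> node_outputs = {"state", "y"};

  std::set<std::string> keep_outputs;
  for (const auto& out : node_outputs) {
    // Keep an output that feeds a subgraph but no main-graph node, so the
    // node producing it is not pruned away.
    if (subgraph_inputs.count(out) && !main_graph_inputs.count(out)) {
      keep_outputs.insert(out);
    }
  }
  for (const auto& name : keep_outputs) std::cout << "keep: " << name << "\n";  // keep: state
}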
if kept_node and kept_node.op_type == node.op_type and kept_node == node: nodes_to_keep.append(node) @@ -997,16 +1043,15 @@ def get_first_output(node): def update_graph(self, verbose=False, allow_remove_graph_inputs=False): graph = self.model.graph - remaining_input_names = [] + remaining_input_names = set() for node in graph.node: if node.op_type in ["Loop", "Scan", "If"]: - # TODO: handle inner graph - logger.debug(f"Skip update_graph since graph has operator: {node.op_type}") - return + # Add input names of nodes in subgraphs + subgraph_inputs_of_node = self._get_subgraph_inputs_of_node(node) + remaining_input_names.update(subgraph_inputs_of_node) + if node.op_type != "Constant": - for input_name in node.input: - if input_name not in remaining_input_names: - remaining_input_names.append(input_name) + remaining_input_names.update(node.input) if verbose: logger.debug(f"remaining input names: {remaining_input_names}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index ad51c1cce0ec4..26464fc32817d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -115,8 +115,8 @@ def fuse_simplified_layer_norm(self): fusion = FusionSimplifiedLayerNormalization(self) fusion.apply() - def fuse_skip_layer_norm(self): - fusion = FusionSkipLayerNormalization(self) + def fuse_skip_layer_norm(self, shape_infer=True): + fusion = FusionSkipLayerNormalization(self, shape_infer=shape_infer) fusion.apply() def fuse_skip_simplified_layer_norm(self): @@ -344,7 +344,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.fuse_reshape() if (options is None) or options.enable_skip_layer_norm: - self.fuse_skip_layer_norm() + self.fuse_skip_layer_norm(options.enable_shape_inference) self.fuse_skip_simplified_layer_norm() if (options is None) or options.enable_rotary_embeddings: diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 32bddc3ca16a0..388d058c7856c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -24,6 +24,7 @@ def get_fused_operator_statistics(self): op_count = {} ops = [ "Attention", + "Gelu", "LayerNormalization", "QuickGelu", "SkipLayerNormalization", diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 548f24e8ac69e..fa7c6bce7c23e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -262,8 +262,8 @@ void RunTest(const TestOptions& opts, } // namespace -TEST(MatMulNBits, Float32) { - // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); +template +void TestMatMulNBitsTyped() { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */ 1, 2, 32, 288}) { for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { @@ -276,30 +276,53 @@ TEST(MatMulNBits, Float32) { if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; + } else { + if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.01f; + } } { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) { TestOptions opts = base_opts; 
opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); + } + + { + TestOptions opts = base_opts; + opts.has_g_idx = true; + opts.has_bias = true; + if constexpr (std::is_same::value) { + if (opts.accuracy_level == 0 || opts.accuracy_level == 1) { + // CI failure (not able to repro on either local machines): + // M:100, N:288, K:1234, block_size:16, accuracy_level:0, has_zero_point:0, zp_is_4bit:1, has_g_idx:1, has_bias:1 + // The difference between cur_expected[i] and cur_actual[i] is 1.0401010513305664e-05, which exceeds tolerance, + // tolerance evaluates to 1.006456386676291e-05. + opts.output_abs_error = 0.0001f; + } + } + // only enabled for CPU EP for now + std::vector> explicit_eps; + explicit_eps.emplace_back(DefaultCpuExecutionProvider()); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + RunTest(opts); } #endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) @@ -311,7 +334,7 @@ TEST(MatMulNBits, Float32) { std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } } } @@ -320,6 +343,21 @@ TEST(MatMulNBits, Float32) { } } +TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); + TestMatMulNBitsTyped(); +} + +#ifdef MLAS_TARGET_AMD64_IX86 +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +// Actual and expected difference is over 0.01 with DmlExecutionProvider. +// Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16) { + TestMatMulNBitsTyped(); +} +#endif +#endif + #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) namespace { @@ -367,7 +405,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura } } // namespace -TEST(MatMulNBits, Float16) { +TEST(MatMulNBits, Float16Cuda) { #if defined(USE_CUDA) || defined(USE_ROCM) auto has_gidx_options = {true, false}; #else diff --git a/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp new file mode 100644 index 0000000000000..406aca722bb67 --- /dev/null +++ b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "OnnxPrediction.h" +#include "onnxruntime_session_options_config_keys.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "fuzzer/FuzzedDataProvider.h" + +Ort::Env env; + +void predict(onnx::ModelProto& msg, unsigned int seed, Ort::Env& env) { + // Create object for prediction + // + OnnxPrediction predict(msg, env); + + // Give predict a function to generate the data + // to run prediction on. + // + predict.SetupInput(GenerateDataForInputTypeTensor, seed); + + // Run the prediction on the data + // + predict.RunInference(); +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + FuzzedDataProvider data_provider(data, size); + onnx::ModelProto msg; + try { + if (!msg.ParseFromArray(data, static_cast(size))) { + return 0; // Ignore invalid inputs + } + predict(msg, data_provider.ConsumeIntegral(), env); + } catch (const std::exception& e) { + // Optionally log or suppress the exception + // std::cerr << "Caught exception: " << e.what() << std::endl; + } catch (...) { + // Handle any other exceptions + // std::cerr << "Caught unknown exception." 
<< std::endl; + } + return 0; +} diff --git a/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp new file mode 100644 index 0000000000000..607d9cfd9c755 --- /dev/null +++ b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "src/mutator.h" +#include "OnnxPrediction.h" +#include "onnxruntime_session_options_config_keys.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "onnx/onnx_pb.h" + +#include + +Ort::Env env; + +std::string wstring_to_string(const std::wstring& wstr) { + std::wstring_convert> converter; + return converter.to_bytes(wstr); +} + +void predict(onnx::ModelProto& msg, unsigned int seed, Ort::Env& env) { + // Create object for prediction + // + OnnxPrediction predict(msg, env); + + // Give predict a function to generate the data + // to run prediction on. + // + predict.SetupInput(GenerateDataForInputTypeTensor, seed); + + // Run the prediction on the data + // + predict.RunInference(); + + // View the output + // + predict.PrintOutputValues(); +} + +template +using PostProcessor = + protobuf_mutator::libfuzzer::PostProcessorRegistration; + +// Helper function to generate random strings +std::string generate_random_string(size_t length, std::mt19937& rng) { + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::uniform_int_distribution<> dist(0, characters.size() - 1); + std::string result; + for (size_t i = 0; i < length; ++i) { + result += characters[dist(rng)]; + } + return result; +} + +// Helper function to generate random float +float generate_random_float(std::mt19937& rng) { + std::uniform_real_distribution dist(0.0f, 1.0f); + return dist(rng); +} + +// PostProcessor for ONNX ModelProto with random values +static PostProcessor reg1 = { + [](onnx::ModelProto* model_proto, unsigned int seed) { + std::mt19937 rng(seed); + + // Set model's IR version + model_proto->set_ir_version(7); + + model_proto->set_producer_name("onnx"); + model_proto->set_producer_version("7.0"); + model_proto->set_domain("example.com"); + + // Add a dummy opset import + auto* opset_import = model_proto->add_opset_import(); + opset_import->set_version(10); + + // Access the graph from the model + auto* graph = model_proto->mutable_graph(); + + // Set a random name for the graph + graph->set_name(generate_random_string(10, rng)); + }}; + +DEFINE_PROTO_FUZZER(const onnx::ModelProto& msg) { + try { + auto seed = static_cast(std::chrono::system_clock::now().time_since_epoch().count()); + onnx::ModelProto msg_proto = msg; + predict(msg_proto, seed, env); + } catch (const std::exception& e) { + // Optionally log or suppress the exception + // std::cerr << "Caught exception: " << e.what() << std::endl; + } catch (...) { + // Handle any other exceptions + // std::cerr << "Caught unknown exception." << std::endl; + } +} diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0397bba90438b..924616f49ab25 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -655,7 +656,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); } if (enable_acl) { #ifdef USE_ACL - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(sf, enable_cpu_mem_arena ? 1 : 0)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(sf, false)); #else fprintf(stderr, "ACL is not supported in this build"); return -1; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6ae66e35e7853..3aec0d5a67e94 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1469,6 +1469,54 @@ TEST_F(GraphTransformationTests, FusePadWithConv) { } } +TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-nopadsconv.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::vector expected_pads; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Pad") { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + gsl::span pads_values = pads.DataAsSpan(); + expected_pads.resize(pads_values.size() - 4); + + for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) { + expected_pads[index] = pads_values[pads_index]; + expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)]; + } + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Pad"], 0); + ASSERT_EQ(op_to_count["Conv"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Conv") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + ASSERT_EQ(child_pads->size(), static_cast(expected_pads.size())) + << "fusion should produce the same size of pads integer as the Conv node"; + for (uint32_t index = 0; index < expected_pads.size(); index++) { + ASSERT_EQ(expected_pads[index], child_pads->Get(index)) + << "fusion does not produce correct padding value"; + } + } + } +} + TEST_F(GraphTransformationTests, FusePadWithMaxPool) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool.onnx"; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7d06bbadbd645..c1c48d4945a4d 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "command_args_parser.h" @@ -66,6 +67,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. 
Different runtime options available are: \n" "\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>'\n" "\n" + "\t [ACL only] [enable_fast_math]: Options: 'true', 'false', default: 'false', \n" "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n"
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 837aeb3c37acd..3ed5eaee5a5f7 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "ort_test_session.h"
@@ -34,10 +35,18 @@ std::chrono::duration<double> OnnxRuntimeTestSession::Run() { // Randomly pick one OrtValueArray from test_inputs_. (NOT ThreadSafe) const std::uniform_int_distribution<int>::param_type p(0, static_cast<int>(test_inputs_.size() - 1)); const size_t id = static_cast<size_t>(dist_(rand_engine_, p)); + auto& input = test_inputs_.at(id); auto start = std::chrono::high_resolution_clock::now(); - auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), - output_names_raw_ptr.data(), output_names_raw_ptr.size()); + + if (!use_device_mem) { + auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), output_names_raw_ptr.size()); + } else { + session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), outputs_.data(), output_names_raw_ptr.size()); + } + auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration<double> duration_seconds = end - start; return duration_seconds;
@@ -511,9 +520,42 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL +#if defined(_MSC_VER) + std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); +#else + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; +#endif // defined(_MSC_VER) + std::istringstream ss(ov_string); + std::string token; + bool enable_fast_math = false; + while (ss >> token) { + if (token == "") { + continue; + } + auto pos = token.find("|"); + if (pos == std::string::npos || pos == 0 || pos == token.length()) { + ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + auto key = token.substr(0, pos); + auto value = token.substr(pos + 1); + + if (key == "enable_fast_math") { + std::set<std::string> ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + enable_fast_math = (value == "true") || (value == "True"); + } else { + ORT_THROW( + "[ERROR] [ACL] You have selected an invalid value for the key 'enable_fast_math'. " + "Select from 'true' or 'false' \n"); + } + } else { + ORT_THROW( + "[ERROR] [ACL] Unrecognized option: ", key); + } + } Ort::ThrowOnError( - OrtSessionOptionsAppendExecutionProvider_ACL(session_options, - performance_test_config.run_config.enable_cpu_mem_arena ?
1 : 0)); + OrtSessionOptionsAppendExecutionProvider_ACL(session_options, enable_fast_math)); #else ORT_THROW("Acl is not supported in this build\n"); #endif @@ -815,6 +857,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' " "should be a boolean i.e. true or false. Default value is false.\n"); } + } else if (key == "use_device_mem") { + if (value == "true" || value == "True") { + use_device_mem = true; + } } else { ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); } @@ -858,6 +904,27 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); input_names_str_[i] = m.GetInputName(i); input_names_[i] = input_names_str_[i].c_str(); } + + if (use_device_mem) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo("OpenVINO_RT_NPU", OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + custom_allocator_ = std::make_unique(session_, memory_info); + for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { + Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + + std::vector output_shape = tensor_info.GetShape(); + + // free dimensions are treated as 1 if not overridden + for (int64_t& dim : output_shape) { + if (dim == -1) { + dim = 1; + } + } + + outputs_.push_back(Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)output_shape.data(), + output_shape.size(), tensor_info.GetElementType())); + } + } } template @@ -944,9 +1011,11 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { // iterate over all input nodes for (size_t i = 0; i < static_cast(input_length_); i++) { Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + if (!use_device_mem) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + } std::vector input_node_dim = tensor_info.GetShape(); // free dimensions are treated as 1 if not overridden @@ -955,12 +1024,18 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { dim = 1; } } - - auto allocator = Ort::AllocatorWithDefaultOptions(); - Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); + if (use_device_mem) { + Ort::Value input_tensor = Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } else { + auto allocator = Ort::AllocatorWithDefaultOptions(); + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } } } return true; diff --git a/onnxruntime/test/perftest/ort_test_session.h 
b/onnxruntime/test/perftest/ort_test_session.h index f1a4220ab325e..e33041a2a0958 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -38,6 +38,8 @@ class OnnxRuntimeTestSession : public TestSession { std::mt19937 rand_engine_; std::uniform_int_distribution dist_; std::vector> test_inputs_; + std::unique_ptr custom_allocator_; + std::vector outputs_; std::vector output_names_; // The same size with output_names_. // TODO: implement a customized allocator, then we can remove output_names_ to simplify this code @@ -46,6 +48,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_str_; const int input_length_; std::string provider_name_; + bool use_device_mem = false; }; } // namespace perftest diff --git a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj index eb7345be3770b..e9536bb59bbae 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj +++ b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj @@ -456,7 +456,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 13.0; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; @@ -510,7 +510,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 13.0; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; SDKROOT = iphoneos; @@ -527,7 +527,7 @@ CODE_SIGNING_STYLE = Automatic; CODE_SIGN_ENTITLEMENTS = ios_package_test/ios_package_test.entitlements; INFOPLIST_FILE = ios_package_test/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -550,7 +550,7 @@ CODE_SIGNING_STYLE = Automatic; CODE_SIGN_ENTITLEMENTS = ios_package_test/ios_package_test.entitlements; INFOPLIST_FILE = ios_package_test/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -571,7 +571,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -593,7 +593,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -631,7 +631,7 @@ "@executable_path/../Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-test"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -663,7 +663,7 @@ "@executable_path/../Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-test"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -684,7 +684,7 @@ 
GCC_C_LANGUAGE_STANDARD = gnu17; GENERATE_INFOPLIST_FILE = YES; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-testUITests"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -706,7 +706,7 @@ GCC_C_LANGUAGE_STANDARD = gnu17; GENERATE_INFOPLIST_FILE = YES; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-testUITests"; PRODUCT_NAME = "$(TARGET_NAME)"; diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index a5015c18cee63..177647ab5be6b 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -254,7 +255,7 @@ TEST_P(ModelTest, Run) { #endif #ifdef USE_ACL else if (provider_name == "acl") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, 0)); + ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); } #endif #ifdef USE_ARMNN diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 9f7be524daa34..67fb35d26e6dc 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -196,7 +196,7 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { // Error message should come from the Conv implementation with the statically registered kernel ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); } @@ -242,7 +242,7 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { std::vector fetches; ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); }; diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9d19c36dc94b2..c4367aeb52edc 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -948,6 +948,63 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } +// Test that QNN EP only handles nodes with static shapes and rejects nodes with dynamic shape I/O. +TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { + // Local function that builds a model in which the last two nodes use dynamic shapes. + auto model_build_fn = [](ModelTestBuilder& builder) { + NodeArg* input1 = builder.MakeInput(std::vector{1, 2, 8, 8}, + GetFloatDataInRange(0.0f, 1.0f, 128)); + NodeArg* input2 = builder.MakeInput(std::vector{3}, std::vector{1, 2, 49}); + + // Add a Conv with known shapes. QNN EP should support it. 
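// For reference: with the 1x2x8x8 input and the 2x2x2x2 weight below (no pads,
// default strides), the Conv output is 1x2x7x7 = 98 elements, which matches the
// Reshape target supplied at run time through input2 = {1, 2, 49}. Because that
// target arrives as a graph input rather than an initializer, shape inference
// cannot fix the Reshape output, so it and the downstream Softmax carry dynamic
// shapes and are expected to be assigned to the CPU EP by the checker below.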
+ NodeArg* weight = builder.MakeInitializer(std::vector{2, 2, 2, 2}, + GetFloatDataInRange(-0.3f, 0.3f, 16)); + NodeArg* bias = builder.MakeInitializer(std::vector{2}, {0.0f, 1.0f}); + + auto* conv_output = builder.MakeIntermediate(); + builder.AddNode("Conv", {input1, weight, bias}, {conv_output}); + + // Add a Reshape to a dynamic shape. QNN EP should reject this node. + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {conv_output, input2}, {reshape_output}); + + // Add a Softmax. QNN EP should reject this node because its input has a dynamic shape. + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Softmax", {reshape_output}, {output}); + }; + + // Local function that checks that the nodes with dynamic shape I/O were assigned to CPU EP. + std::function ep_graph_checker = [](const Graph& graph) { + for (const Node& node : graph.Nodes()) { + const std::string& ep_name = node.GetExecutionProviderType(); + const std::string& op_type = node.OpType(); + if (op_type == "Reshape" || op_type == "Softmax") { + EXPECT_EQ(ep_name, kCpuExecutionProvider); + } else { + EXPECT_EQ(ep_name, kQnnExecutionProvider); + } + } + }; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_fp16_precision"] = "1"; // QNN EP will use fp16 precision. + // CPU EP will use fp32, so we can relax accuracy requirements. + + RunQnnModelTest(model_build_fn, + provider_options, + /*opset*/ 19, + ExpectedEPNodeAssignment::Some, + /*abs_err*/ 1e-4f, + logging::Severity::kERROR, + /*verify_output*/ true, + &ep_graph_checker); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index afaa5a341d5e9..8a4f7f2a1f6b5 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -98,10 +98,12 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) { void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, - float fp32_abs_err, logging::Severity log_severity, bool verify_outputs) { + float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, + std::function* ep_graph_checker) { EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; + verification_params.graph_verifier = ep_graph_checker; // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 3a6753e9b6131..bb77c92668853 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -1033,12 +1033,16 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). * \param fp32_abs_err The acceptable error between CPU EP and QNN EP. * \param log_severity The logger's minimum severity level. + * \param verify_outputs True to verify that the outputs match (within tolerance). 
+ * \param ep_graph_checker Function called on the Graph generated for the EP's session. Used to check node + * EP assignment. */ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR, - bool verify_outputs = true); + bool verify_outputs = true, + std::function* ep_graph_checker = nullptr); enum class BackendSupport { SUPPORT_UNKNOWN, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index feabd648f8385..24151932a6681 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1694,7 +1694,7 @@ def test_register_custom_e_ps_library(self): available_eps = C.get_available_providers() # skip amd gpu build - if "RocmExecutionProvider" in available_eps: + if "ROCMExecutionProvider" in available_eps: return if sys.platform.startswith("win"): diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index c04929a3b603e..46ab905977f48 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -223,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + interactive=False, softcap=0.0, use_smooth_softmax=False, ): @@ -1224,7 +1225,7 @@ def parity_check_gqa_prompt( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1422,7 +1423,7 @@ def parity_check_gqa_prompt_no_buff( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1597,7 +1598,7 @@ def parity_check_gqa_past( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1667,7 +1668,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1696,7 +1696,6 @@ def parity_check_gqa_past( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1730,6 +1729,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1783,15 +1784,14 @@ def parity_check_gqa_past( numpy.testing.assert_allclose( present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) def parity_check_gqa_past_no_buff( config, - causal=False, + causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1864,7 +1864,6 @@ def 
parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1896,7 +1895,6 @@ def parity_check_gqa_past_no_buff( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1930,6 +1928,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1976,6 +1976,23 @@ def parity_check_gqa_past_no_buff( f" with {config}, causal={causal}, local={local}, past_format={past_format}," f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) + for b in range(config.batch_size): + numpy.testing.assert_allclose( + present_k[b, :, : (cache_seqlens + 1)[b]], + k_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + numpy.testing.assert_allclose( + present_v[b, :, : (cache_seqlens + 1)[b]], + v_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) @@ -2229,6 +2246,86 @@ def gqa_past_flash_attention_test_cases(): ) +def gqa_interactive_one_batch_flash_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + ) + + +def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + 
rotary_interleaved, + packed, + ) + + class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): @@ -2350,6 +2447,60 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle use_smooth_softmax=True, ) + @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) + def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + if not has_flash_attention(): + return + print("------- FLASH ATTENTION (INTERACTIVE) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + + @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) + def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): + if not has_memory_efficient(): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") + + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index cc9d7ff51a5c6..dc21d4e4a5890 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -121,8 +121,12 @@ def rotate_tensor( else: x_rot = torch.cat((real, imag), dim=-1) else: - cos_x = cos[:, 0:seq_len, :, :] - sin_x = sin[:, 0:seq_len, :, :] + batch_size = x.shape[0] + cos_x = torch.zeros((batch_size, seq_len, 1, cos.shape[3]), device=x.device) + sin_x = torch.zeros((batch_size, seq_len, 1, sin.shape[3]), device=x.device) + for b in range(x.shape[0]): + cos_x[b] = cos[0, pos[b] : pos[b] + seq_len, :, :] + sin_x[b] = sin[0, pos[b] : pos[b] + seq_len, :, :] real = cos_x * x1 - sin_x * x2 imag = sin_x * x1 + cos_x * x2 if interleaved: @@ -716,7 +720,6 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - # TODO: do we need io binding for cpu input? 
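The rotate_tensor change above replaces the fixed cos/sin slice that always started at position 0 with a per-batch slice starting at each sequence's own past length. A standalone sketch of that indexing, with the cache layout (1, max_position, 1, rotary_dim/2) assumed from the surrounding test code:

    import torch

    def gather_rotary_cache(cos, sin, pos, seq_len):
        # cos/sin: (1, max_position, 1, rotary_dim/2) caches; pos: per-batch start offsets.
        batch_size = len(pos)
        cos_x = torch.zeros((batch_size, seq_len, 1, cos.shape[3]), dtype=cos.dtype)
        sin_x = torch.zeros((batch_size, seq_len, 1, sin.shape[3]), dtype=sin.dtype)
        for b in range(batch_size):
            start = int(pos[b])
            cos_x[b] = cos[0, start : start + seq_len, :, :]
            sin_x[b] = sin[0, start : start + seq_len, :, :]
        return cos_x, sin_x

Slicing at pos[b] keeps the rotary angles aligned with each sequence's absolute token positions when the batches carry KV caches of different lengths.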
io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -788,6 +791,7 @@ def gqa_past_func( softcap=0.0, use_smooth_softmax=False, ): + assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( config, past_kv_format, @@ -819,12 +823,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -867,12 +871,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -1518,7 +1522,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1576,6 +1579,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1739,7 +1744,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1800,6 +1804,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -2000,6 +2006,61 @@ def test_gqa_past(self): ) self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): + print("-------- TEST GQA INTERACTIVE ---------") + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if 
pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 6a08d2101b100..5dbb9a277e45a 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -890,7 +890,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, ) yield config @@ -929,7 +929,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer. ) yield config @@ -940,7 +940,6 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot class TestSparseAttention(unittest.TestCase): - @unittest.skipUnless(has_cuda_support(), "cuda not available") def test_sparse_attention_cuda(self): major, minor = torch.cuda.get_device_capability() @@ -1056,7 +1055,7 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool): vert_stride=4, softmax_scale=None, do_rotary=do_rotary, - rotary_interleaved=(past_seq_len % 2 == 1), + rotary_interleaved=do_rotary and (past_seq_len % 2 == 1), device=device, is_packed_qkv=packed_qkv, max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer. diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx new file mode 100644 index 0000000000000..145847cdc47fa Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx differ diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..6451f8ec6dce8 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. 
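Across this PR the second argument of OrtSessionOptionsAppendExecutionProvider_ACL changes meaning from the old CPU-arena flag to enable_fast_math (see the test call sites updated above and the DefaultAclExecutionProvider factory change below). A minimal usage sketch; the header that declares the helper is an assumption here, not taken from the diff:

    // Sketch only: the include path for the ACL factory declaration is assumed.
    #include "core/providers/acl/acl_provider_factory.h"
    #include <onnxruntime_cxx_api.h>

    Ort::SessionOptions MakeAclSessionOptions(bool enable_fast_math) {
      Ort::SessionOptions so;
      // The second argument now selects ACL fast math; it no longer controls the memory arena.
      Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(so, enable_fast_math));
      return so;
    }

With the option parser added to the perf test earlier in this PR, the same switch is reachable from the command line when the ACL EP is selected, e.g. -e acl -i 'enable_fast_math|true'.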
#include @@ -207,11 +208,11 @@ std::unique_ptr DefaultRknpuExecutionProvider() { #endif } -std::unique_ptr DefaultAclExecutionProvider(bool enable_arena) { +std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math) { #ifdef USE_ACL - return ACLProviderFactoryCreator::Create(enable_arena)->CreateProvider(); + return ACLProviderFactoryCreator::Create(enable_fast_math)->CreateProvider(); #else - ORT_UNUSED_PARAMETER(enable_arena); + ORT_UNUSED_PARAMETER(enable_fast_math); return nullptr; #endif } diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 606dfc068d399..b3a619022f79b 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/common/optional.h" @@ -53,7 +54,7 @@ std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); std::unique_ptr DefaultVSINPUExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); -std::unique_ptr DefaultAclExecutionProvider(bool enable_arena = true); +std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math = false); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultRocmExecutionProvider(bool test_tunable_op = false); std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram = false); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 587d035541c45..8535f1e8c85a0 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. import argparse @@ -460,6 +461,12 @@ def convert_arg_line_to_args(self, arg_line): action="store_true", help="Disable memory leak checker from Windows build. By default it is enabled in Windows Debug build. This option is Windows only.", ) + # Dependency search with vcpkg + parser.add_argument( + "--use_vcpkg", + action="store_true", + help="Use vcpkg to search dependencies. Requires CMAKE_TOOLCHAIN_FILE for vcpkg.cmake", + ) # WebAssembly build parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly") @@ -645,9 +652,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--enable_transformers_tool_test", action="store_true", help="Enable transformers tool test") parser.add_argument( "--use_acl", - nargs="?", - const="ACL_1905", - choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002", "ACL_2308"], + action="store_true", help="Build with ACL for ARM architectures.", ) parser.add_argument("--acl_home", help="Path to ACL home dir") @@ -999,6 +1004,7 @@ def generate_build_tree( # of them to get the best compatibility. 
"-DPython_EXECUTABLE=" + sys.executable, "-DPYTHON_EXECUTABLE=" + sys.executable, + "-Donnxruntime_USE_VCPKG=" + ("ON" if args.use_vcpkg else "OFF"), "-Donnxruntime_USE_MIMALLOC=" + ("ON" if args.use_mimalloc else "OFF"), "-Donnxruntime_ENABLE_PYTHON=" + ("ON" if args.enable_pybind else "OFF"), "-Donnxruntime_BUILD_CSHARP=" + ("ON" if args.build_csharp else "OFF"), @@ -1045,11 +1051,6 @@ def generate_build_tree( "-Donnxruntime_USE_TELEMETRY=" + ("ON" if args.use_telemetry else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), - "-Donnxruntime_USE_ACL_1902=" + ("ON" if args.use_acl == "ACL_1902" else "OFF"), - "-Donnxruntime_USE_ACL_1905=" + ("ON" if args.use_acl == "ACL_1905" else "OFF"), - "-Donnxruntime_USE_ACL_1908=" + ("ON" if args.use_acl == "ACL_1908" else "OFF"), - "-Donnxruntime_USE_ACL_2002=" + ("ON" if args.use_acl == "ACL_2002" else "OFF"), - "-Donnxruntime_USE_ACL_2308=" + ("ON" if args.use_acl == "ACL_2308" else "OFF"), "-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"), "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json index 84d7e355ed5b4..f2d034dbfc000 100644 --- a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json @@ -24,7 +24,7 @@ ], "macosx": [ "--macos=MacOSX", - "--apple_deploy_target=11.0" + "--apple_deploy_target=13.3" ], "iphoneos": [ "--ios", diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index cc4727a889d44..7bc1cd669bbff 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 41ff365a65d49..a0ae0441f2824 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -70,9 +70,9 @@ stages: versionSpec: $(pythonVersion) - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' @@ -131,9 +131,9 @@ stages: versionSpec: $(pythonVersion) - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' @@ -195,9 +195,9 @@ stages: versionSpec: $(pythonVersion) - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' diff --git 
a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 24342ee977481..4bcbc12574b4d 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index ddfb4297afc4a..e172611d898bf 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,7 +8,7 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 10.2.cuda_12_5_cudnn_9 + default: 10.3.cuda_12_5_cudnn_9 values: - 8.6.cuda_11_8_cudnn_8 - 8.6.cuda_12_3_cudnn_9 diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 686f91526023c..9532ed9bea2f3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index 7b77281b0efe2..50f3862761320 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -1,4 +1,5 @@ ##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +### please do rerun set-trigger-rules.py ### trigger: branches: include: diff --git a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml index c16adc6221ed0..a1059bb3c314e 100644 --- a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml @@ -33,9 +33,9 @@ jobs: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - MACOSX_DEPLOYMENT_TARGET: '11.0' + MACOSX_DEPLOYMENT_TARGET: '13.3' TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] CCACHE_DIR: '$(Pipeline.Workspace)/ccache' timeoutInMinutes: 120 @@ -44,8 +44,6 @@ jobs: displayName: Install coreutils and ninja - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - template: templates/mac-build-step-with-cache.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index 48d48156fe913..c61beb63b8b40 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -31,7 +31,7 @@ pr: jobs: - job: iOS_CI_on_Mac pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: PROTO_CACHE_DIR: $(Pipeline.Workspace)/proto_ccache ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache @@ -39,8 +39,7 @@ jobs: timeoutInMinutes: 150 steps: 
- template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 + - template: templates/mac-build-step-with-cache.yml parameters: WithCache: true @@ -53,7 +52,7 @@ jobs: python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ --skip_submodule_sync \ --build_dir $(Build.BinariesDirectory)/iOS \ - --build_shared \ + --build_shared_lib \ --use_coreml \ --use_xnnpack \ --ios \ diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml index abd1004a830e0..c3176ec54fd79 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml @@ -54,7 +54,7 @@ stages: displayName: "Set common variables" pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" timeoutInMinutes: 5 diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 53923e0b4432a..4518a168879a2 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -11,7 +11,7 @@ stages: clean: all timeoutInMinutes: 120 pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml index c977e17aada9d..07d21333270a8 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml @@ -11,7 +11,7 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index e13f1c20b37ce..3853bdbd1eb88 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -425,7 +425,7 @@ stages: - job: IosDynamicFramework timeoutInMinutes: 120 pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" steps: - task: UsePythonVersion@0 @@ -435,8 +435,6 @@ stages: architecture: "x64" - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt @@ -463,7 +461,7 @@ stages: - job: IosMinimalTrainingBuild timeoutInMinutes: 120 pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" steps: - task: UsePythonVersion@0 @@ -473,8 +471,6 @@ stages: architecture: "x64" - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index dced61670e34a..de2677ebc6594 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -29,7 +29,7 @@ stages: parameters: job_name: Test_MAC_Wheels machine_pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - template: templates/py-package-smoking-test.yml parameters: diff --git 
a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 2b970688c5fac..8107c1a890973 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -64,7 +64,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.25.0.240728 + default: 2.26.0.240828 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 25d50f4255cb1..98b5e47c0e2d7 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 6971488478ed0..7fa548240f865 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,9 +19,7 @@ jobs: workspace: clean: all pool: - # We need macOS-12 to run the Android emulator for now. - # https://github.com/actions/runner-images/issues/7671 - vmImage: 'macOS-12' + vmImage: 'macOS-13' variables: - name: runCodesignValidationInjection value: false @@ -39,9 +37,9 @@ jobs: targetPath: '$(Build.BinariesDirectory)/final-android-aar' - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 8b4fe66465bb1..3e90a401d4deb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -83,14 +83,12 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' timeoutInMinutes: 300 steps: - template: set-version-number-variables-step.yml - template: use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - template: download-deps.yml @@ -782,7 +780,7 @@ stages: - template: ../nuget/templates/test_macos.yml parameters: - AgentPool : macOS-latest + AgentPool : macOS-13 ArtifactSuffix: 'CPU' - template: ../nodejs/templates/test_win.yml @@ -818,4 +816,4 @@ stages: OS: MacOS BuildId: ${{ parameters.BuildId }} SpecificArtifact: ${{ parameters.SpecificArtifact }} - PoolName: 'macOS-latest' + PoolName: 'macOS-13' diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 3c74c70ed102d..cbba1cb8ba8bd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.181 + version: 1.0.184 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 
+22,7 @@ steps:
     packageType: upack
     feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
     definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
-    version: 1.0.181
+    version: 1.0.184
     downloadPath: $(Build.BinariesDirectory)/deps
 
 # You can add more ADO accounts at here.
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml
index e727ec4f7ef5c..4aedd2f8564c1 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml
@@ -1,7 +1,7 @@
 parameters:
 - name: QnnSDKVersion
   type: string
-  default: '2.25.0.240728'
+  default: '2.26.0.240828'
 
 steps:
 - script: |
diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml
index 912cac6fbb99e..eff49302eb33d 100644
--- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml
@@ -1,7 +1,7 @@
 parameters:
 - name: QnnSDKVersion
   type: string
-  default: '2.25.0.240728'
+  default: '2.26.0.240828'
 
 steps:
 - powershell: |
diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml
index a56eb37faef84..2ab432e94fcbd 100644
--- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml
@@ -31,10 +31,6 @@ parameters:
   type: boolean
   default: false
 
-- name: BuildTraining
-  type: boolean
-  default: true
-
 - name: WithCache
   type: boolean
   default: false
@@ -116,19 +112,6 @@ jobs:
         DisplayName: 'Build and test (browser) (simd + threads)'
         WithCache: ${{ parameters.WithCache }}
 
-  - ${{ if eq(parameters.BuildTraining, true) }}:
-    - template: build-linux-wasm-step.yml
-      parameters:
-        Today: $(Today)
-        ${{ if eq(parameters.BuildStaticLib, true)}}:
-          AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} | static
-        ${{ else }}:
-          AdditionalKey: wasm_training | ${{ parameters.BuildConfig }}
-        CacheDir: $(ORT_CACHE_DIR)/wasm_training
-        Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_training --enable_training_apis --target onnxruntime_webassembly --skip_tests'
-        DisplayName: 'Build (training + simd + threads)'
-        WithCache: ${{ parameters.WithCache }}
-
   - ${{ if eq(parameters.BuildJsep, true) }}:
     - template: build-linux-wasm-step.yml
       parameters:
@@ -150,10 +133,6 @@ jobs:
           cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.wasm $(Build.ArtifactStagingDirectory)
           cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.mjs $(Build.ArtifactStagingDirectory)
         fi
-        if [ -d $(Build.BinariesDirectory)/wasm_training ]; then
-          cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.wasm $(Build.ArtifactStagingDirectory)
-          cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.mjs $(Build.ArtifactStagingDirectory)
-        fi
       displayName: 'Create Artifacts'
   - ${{ if eq(parameters.SkipPublish, false) }}:
     - task: PublishPipelineArtifact@0
diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml
index 945fbb7c4a094..1b1080034948e 100644
--- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml
@@ -69,9 +69,9 @@ stages:
   - job: MacOS_C_API_Package_Publish
     pool:
       ${{ if eq(parameters.DoESRP, true)}}:
-        vmImage: 'macOS-12'
+        vmImage: 'macOS-13'
       ${{ else }}:
-        vmImage: 'macOS-latest'
+        vmImage: 'macOS-13'
     steps:
     - checkout: none
     - template: flex-downloadPipelineArtifact.yml
diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml
index 01ec3b5a2f8ca..3b661d9eb2dc6 100644
--- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml
@@ -31,13 +31,13 @@ jobs:
   workspace:
     clean: all
   variables:
-    MACOSX_DEPLOYMENT_TARGET: '11.0'
+    MACOSX_DEPLOYMENT_TARGET: '13.3'
     ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }}
     TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)]
     PROTO_CACHE_DIR: $(Pipeline.Workspace)/ccache_proto
     ORT_CACHE_DIR: $(Pipeline.Workspace)/ccache_ort
   pool:
-    vmImage: 'macOS-latest'
+    vmImage: 'macOS-13'
   timeoutInMinutes: 300
   steps:
   - checkout: self
@@ -61,8 +61,6 @@ jobs:
   - template: set-version-number-variables-step.yml
 
   - template: use-xcode-version.yml
-    parameters:
-      xcodeVersion: 14.2
 
   - template: download-deps.yml
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml
index 86ae1a80bd4f0..451e5ad2d44c7 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml
@@ -68,7 +68,7 @@ parameters:
 - name: qnn_sdk_version
   type: string
   displayName: 'QNN SDK version. Only for QNN packages.'
-  default: 2.25.0.240728
+  default: 2.26.0.240828
 
 stages:
 - ${{ if eq(parameters.enable_windows_cpu, true) }}:
@@ -398,9 +398,9 @@ stages:
       workspace:
         clean: all
       pool:
-        vmImage: 'macOS-latest'
+        vmImage: 'macOS-13'
       variables:
-        MACOSX_DEPLOYMENT_TARGET: '11.0'
+        MACOSX_DEPLOYMENT_TARGET: '13.3'
       strategy:
         matrix:
           Python38:
@@ -424,8 +424,6 @@ stages:
           versionSpec: $(PythonVersion)
 
       - template: use-xcode-version.yml
-        parameters:
-          xcodeVersion: 14.2
 
       - template: download-deps.yml
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
index c3a2b7be7ebd2..837aa2760bd2c 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.25.0.240728
+  default: 2.26.0.240828
 
 - name: PYTHON_VERSION
   type: string
@@ -129,7 +129,7 @@ jobs:
   - task: PublishBuildArtifacts@1
     displayName: 'Publish Artifact: ONNXRuntime python wheel'
     inputs:
-      ArtifactName: onnxruntime_qnn
+      ArtifactName: onnxruntime_qnn_arm64
 
   - script: |
       7z x *.whl
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
index 775244943484c..419bc5a2024d3 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.24.0.240626
+  default: 2.26.0.240828
 
 - name: ENV_SETUP_SCRIPT
   type: string
@@ -135,7 +135,7 @@ jobs:
   - task: PublishBuildArtifacts@1
     displayName: 'Publish Artifact: ONNXRuntime python wheel'
     inputs:
-      ArtifactName: onnxruntime_qnn
+      ArtifactName: onnxruntime_qnn_arm64ec
 
   - script: |
       7z x *.whl
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
index c8431d4fa3497..b6bcad80a556e 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml
@@ -7,7 +7,7 @@ parameters:
 - name: QNN_SDK
   displayName: QNN SDK Version
   type: string
-  default: 2.25.0.240728
+  default: 2.26.0.240828
 
 - name: ENV_SETUP_SCRIPT
   type: string
diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
index c7fd26712329c..13ecd356d53e3 100644
--- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml
@@ -1,5 +1,5 @@
 parameters:
-  QnnSdk: '2.25.0.240728'
+  QnnSdk: '2.26.0.240828'
   build_config: 'RelWithDebInfo'
   IsReleaseBuild: false
   DoEsrp: false
diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml
index 46dc867113a2e..5fea265d59392 100644
--- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml
@@ -41,7 +41,7 @@ stages:
   - job: Build_Ios_Pod_For_React_Native
 
     pool:
-      vmImage: 'macOS-12'
+      vmImage: 'macOS-13'
 
     timeoutInMinutes: 90
 
@@ -52,7 +52,6 @@ stages:
 
     steps:
     - template: use-xcode-version.yml
-
     - task: UsePythonVersion@0
       displayName: Use python 3.9
       inputs:
@@ -99,9 +98,7 @@ stages:
     jobs:
     - job: ReactNative_CI
       pool:
-        # We need macOS-12 to run the Android emulator for now.
-        # https://github.com/actions/runner-images/issues/7671
-        vmImage: 'macOS-12'
+        vmImage: 'macOS-13'
       variables:
         runCodesignValidationInjection: false
       timeoutInMinutes: 90
@@ -109,7 +106,7 @@ stages:
      - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
        displayName: Clean Agent Directories
        condition: always()
-
+      - template: use-xcode-version.yml
      - task: UsePythonVersion@0
        displayName: Use python 3.9
        inputs:
@@ -118,9 +115,9 @@ stages:
          architecture: "x64"
 
      - task: JavaToolInstaller@0
-       displayName: Use jdk 11
+       displayName: Use jdk 17
        inputs:
-         versionSpec: '11'
+         versionSpec: '17'
          jdkArchitectureOption: 'x64'
          jdkSourceOption: 'PreInstalled'
@@ -304,7 +301,7 @@ stages:
            scheme: 'OnnxruntimeModuleTest'
            packageApp: false
            destinationPlatformOption: 'iOS'
-           destinationSimulators: 'iPhone 13,OS=latest'
+           destinationSimulators: 'iPhone 14,OS=16.4'
            workingDirectory: '$(Build.SourcesDirectory)/js/react_native/ios'
            xcprettyArgs: '--output build/reports/test-results.xml'
            publishJUnitResults: true
@@ -319,6 +316,11 @@ stages:
          condition: succeededOrFailed()
          displayName: Publish React Native iOS Instrumented Test Results
 
+       - script: |
+           xcrun simctl list devices
+         displayName: List iOS Simulators
+         continueOnError: true
+
        - script: |
            JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/ios-test-results.xml \
            detox test --record-logs all \
diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml
index d43f277739d99..e27de27036130 100644
--- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml
@@ -15,10 +15,10 @@ stages:
     displayName: "Build iOS package for variant: ${{ parameters.packageVariant}}"
 
     pool:
-      vmImage: "macOS-latest"
+      vmImage: "macOS-13"
 
     variables:
-      xcodeVersion: "14.2"
+      xcodeVersion: "14.3.1"
       ortPodVersion: $[stageDependencies.IosPackaging_SetCommonVariables.j.outputs['SetCommonVariables.ORT_POD_VERSION']]
 
       ${{ if eq(parameters.packageVariant, 'Full') }}:
@@ -62,8 +62,6 @@ stages:
         architecture: "x64"
 
       - template: ../use-xcode-version.yml
-        parameters:
-          xcodeVersion: ${{ variables.xcodeVersion }}
 
      - template: ../install-appcenter.yml
diff --git a/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml b/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml
index ec4398fe31fc5..2cf698aefa8bd 100644
--- a/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml
@@ -3,7 +3,7 @@ parameters:
 - name: xcodeVersion
   type: string
-  default: "14.2"
+  default: "14.3.1"
 
 steps:
 - bash: |
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
index c1fde93d8e640..0e8a7eb94379b 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml
@@ -214,11 +214,7 @@ jobs:
       workingDirectory: '$(Build.SourcesDirectory)\js\web'
     displayName: 'E2E package consuming test'
     condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release'))
-  - script: |
-      npm run test:training:e2e
-    workingDirectory: '$(Build.SourcesDirectory)\js\web'
-    displayName: 'E2E training package test'
-    condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release'))
+
   - task: CopyFiles@2
     inputs:
      sourceFolder: $(Build.SourcesDirectory)\js\common
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
index 948b18802beb1..0829b33684372 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.25.0.240728
+  default: 2.26.0.240828
 
 jobs:
 - job: 'build'
diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
index 9e08603992260..100175c1c15e3 100644
--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml
@@ -33,7 +33,7 @@ parameters:
 - name: QnnSdk
   displayName: QNN SDK version
   type: string
-  default: 2.25.0.240728
+  default: 2.26.0.240828
 
 jobs:
 - job: 'build'
diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py
index 583e5b05ed6d8..fb6aa44cdf31a 100644
--- a/tools/ci_build/set-trigger-rules.py
+++ b/tools/ci_build/set-trigger-rules.py
@@ -24,6 +24,7 @@
     "linux-migraphx-ci-pipeline.yml",
     "linux-openvino-ci-pipeline.yml",
     "linux-qnn-ci-pipeline.yml",
+    "linux-rocm-ci-pipeline.yml",
     "mac-ci-pipeline.yml",
     "mac-coreml-ci-pipeline.yml",
    "mac-ios-ci-pipeline.yml",