From e7107f41de9cbd4325150b6bafcbc7485059afbb Mon Sep 17 00:00:00 2001 From: Sophie Schoenmeyer <107952697+sophies927@users.noreply.github.com> Date: Tue, 10 Sep 2024 10:44:08 -0700 Subject: [PATCH 01/39] Decrease API docs artifact retention days (#22003) ### Description When API docs workflows fail, we typically don't catch the issue until the most recently generated artifact expires. The current artifact retention is 60 days, so by decreasing to 30 days, we can ensure that we're resolving the workflow failures more quickly. ### Motivation and Context --- .github/workflows/publish-c-apidocs.yml | 2 +- .github/workflows/publish-csharp-apidocs.yml | 2 +- .github/workflows/publish-java-apidocs.yml | 2 +- .github/workflows/publish-js-apidocs.yml | 2 +- .github/workflows/publish-objectivec-apidocs.yml | 2 +- .github/workflows/publish-python-apidocs.yml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/publish-c-apidocs.yml b/.github/workflows/publish-c-apidocs.yml index b097cdbd9a55c..6c4dc43847d0b 100644 --- a/.github/workflows/publish-c-apidocs.yml +++ b/.github/workflows/publish-c-apidocs.yml @@ -49,4 +49,4 @@ jobs: with: name: onnxruntime-c-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index 02aa053365790..862a7a70e33a2 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -59,4 +59,4 @@ jobs: with: name: onnxruntime-csharp-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 3e553049a186e..9e42dca708a17 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -47,4 +47,4 @@ jobs: with: name: onnxruntime-java-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index db021106a6554..cec4a52d39c93 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -47,4 +47,4 @@ jobs: with: name: onnxruntime-node-apidocs path: _site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-objectivec-apidocs.yml b/.github/workflows/publish-objectivec-apidocs.yml index ebacd38f1f882..a8b81c8d5cf84 100644 --- a/.github/workflows/publish-objectivec-apidocs.yml +++ b/.github/workflows/publish-objectivec-apidocs.yml @@ -48,4 +48,4 @@ jobs: with: name: onnxruntime-objectivec-apidocs path: ./_site - retention-days: 60 + retention-days: 30 diff --git a/.github/workflows/publish-python-apidocs.yml b/.github/workflows/publish-python-apidocs.yml index e98d22450c5b0..8b2f72d80bacf 100644 --- a/.github/workflows/publish-python-apidocs.yml +++ b/.github/workflows/publish-python-apidocs.yml @@ -53,4 +53,4 @@ jobs: with: name: onnxruntime-python-apidocs path: _site - retention-days: 60 + retention-days: 30 From 31ae11788aad76b85333e9880446f8ac5e0a6fd4 Mon Sep 17 00:00:00 2001 From: George Wu Date: Tue, 10 Sep 2024 14:03:06 -0700 Subject: [PATCH 02/39] [QNN EP] Update QNN SDK to 2.26 (#22037) * update default QNN SDK version to 2.26 * enable layernorm implicit bias workaround for QNN 2.26 * update artifact names for py win arm64 and arm64ec to re-enable ort-qnn-nightly arm64 python packages --- .../providers/qnn/builder/opbuilder/layer_norm_op_builder.cc | 4 ++-- .../android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../azure-pipelines/c-api-noopenmp-packaging-pipelines.yml | 2 +- .../ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- .../ci_build/github/azure-pipelines/py-packaging-pipeline.yml | 2 +- .../azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../azure-pipelines/templates/jobs/download_win_qnn_sdk.yml | 2 +- .../github/azure-pipelines/templates/py-packaging-stage.yml | 2 +- .../github/azure-pipelines/templates/py-win-arm64-qnn.yml | 4 ++-- .../github/azure-pipelines/templates/py-win-arm64ec-qnn.yml | 4 ++-- .../github/azure-pipelines/templates/py-win-x64-qnn.yml | 2 +- .../ci_build/github/azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 15 files changed, 18 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index 35a6b7bf40637..5c4608dff9bb1 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,9 +87,9 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18) +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR == 17 || QNN_API_VERSION_MINOR == 18 || QNN_API_VERSION_MINOR == 19) if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24/2.25 (QNN API version 2.17/2.18) has a validation bug for implicit bias inputs, + // Bias is implicit. QNN SDK 2.24/2.25/2.26 (QNN API version 2.17/2.18/2.19) has a validation bug for implicit bias inputs, // so provide an explicit bias of all 0 (quantized int32). TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index cc4727a889d44..7bc1cd669bbff 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 24342ee977481..4bcbc12574b4d 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 686f91526023c..9532ed9bea2f3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 2b970688c5fac..8107c1a890973 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -64,7 +64,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.25.0.240728 + default: 2.26.0.240828 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 25d50f4255cb1..98b5e47c0e2d7 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index e727ec4f7ef5c..4aedd2f8564c1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.25.0.240728' + default: '2.26.0.240828' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 912cac6fbb99e..eff49302eb33d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.25.0.240728' + default: '2.26.0.240828' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 86ae1a80bd4f0..2701852f4601d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -68,7 +68,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.25.0.240728 + default: 2.26.0.240828 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index c3a2b7be7ebd2..837aa2760bd2c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 - name: PYTHON_VERSION type: string @@ -129,7 +129,7 @@ jobs: - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - ArtifactName: onnxruntime_qnn + ArtifactName: onnxruntime_qnn_arm64 - script: | 7z x *.whl diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 775244943484c..419bc5a2024d3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.24.0.240626 + default: 2.26.0.240828 - name: ENV_SETUP_SCRIPT type: string @@ -135,7 +135,7 @@ jobs: - task: PublishBuildArtifacts@1 displayName: 'Publish Artifact: ONNXRuntime python wheel' inputs: - ArtifactName: onnxruntime_qnn + ArtifactName: onnxruntime_qnn_arm64ec - script: | 7z x *.whl diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index c8431d4fa3497..b6bcad80a556e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index c7fd26712329c..13ecd356d53e3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.25.0.240728' + QnnSdk: '2.26.0.240828' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 948b18802beb1..0829b33684372 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 9e08603992260..100175c1c15e3 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.25.0.240728 + default: 2.26.0.240828 jobs: - job: 'build' From 4a5d66c15f2daf2f0d654fa98dfd0334322b5341 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Tue, 10 Sep 2024 15:26:16 -0700 Subject: [PATCH 03/39] Default value 10.2->10.3 in linux-gpu-tensorrt-daily-perf-pipeline.yml (#21823) ### Description Fix default value 10.2->10.3 in linux-gpu-tensorrt-daily-perf-pipeline.yml ### Motivation and Context --- .../azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index ddfb4297afc4a..e172611d898bf 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -8,7 +8,7 @@ parameters: - name: TrtVersion displayName: TensorRT Version type: string - default: 10.2.cuda_12_5_cudnn_9 + default: 10.3.cuda_12_5_cudnn_9 values: - 8.6.cuda_11_8_cudnn_8 - 8.6.cuda_12_3_cudnn_9 From 19954decafe6e66d73d9dd276a9dfdb1392d1a4b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 23:05:44 +0000 Subject: [PATCH 04/39] Bump body-parser from 1.20.2 to 1.20.3 in /js/web (#22044) --- js/web/package-lock.json | 330 ++++++++++++++++++++++++++++----------- 1 file changed, 243 insertions(+), 87 deletions(-) diff --git a/js/web/package-lock.json b/js/web/package-lock.json index d37cf6bd90887..9db48f74a94a4 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -456,9 +456,9 @@ } }, "node_modules/body-parser": { - "version": "1.20.2", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", - "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", + "version": "1.20.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, "dependencies": { "bytes": "3.1.2", @@ -469,7 +469,7 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "on-finished": "2.4.1", - "qs": "6.11.0", + "qs": "6.13.0", "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" @@ -603,13 +603,19 @@ } }, "node_modules/call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -959,6 +965,23 @@ "node": ">=10" } }, + "node_modules/define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "dependencies": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + }, + "engines": { + "node": ">= 0.4" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/define-properties": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.0.tgz", @@ -1137,6 +1160,27 @@ "node": ">=6" } }, + "node_modules/es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.2.4" + }, + "engines": { + "node": ">= 0.4" + } + }, + "node_modules/es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/es6-error": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", @@ -1447,14 +1491,19 @@ } }, "node_modules/get-intrinsic": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.0.tgz", - "integrity": "sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "dependencies": { - "function-bind": "^1.1.1", - "has": "^1.0.3", - "has-symbols": "^1.0.3" + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "has-proto": "^1.0.1", + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1598,6 +1647,18 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "dependencies": { + "get-intrinsic": "^1.1.3" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, "node_modules/got": { "version": "11.8.6", "resolved": "https://registry.npmjs.org/got/-/got-11.8.6.tgz", @@ -1640,18 +1701,6 @@ "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==" }, - "node_modules/has": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", - "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", - "dev": true, - "dependencies": { - "function-bind": "^1.1.1" - }, - "engines": { - "node": ">= 0.4.0" - } - }, "node_modules/has-flag": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", @@ -1662,13 +1711,24 @@ } }, "node_modules/has-property-descriptors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.0.tgz", - "integrity": "sha512-62DVLZGoiEBDHQyqG4w9xCuZ7eJEwNmJRWw2VY84Oedb7WFcA27fiEVe8oUQx9hAUJ4ekurquucTGwsyO1XGdQ==", + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", "dev": true, - "optional": true, "dependencies": { - "get-intrinsic": "^1.1.1" + "es-define-property": "^1.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/ljharb" + } + }, + "node_modules/has-proto": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", + "dev": true, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -1686,6 +1746,18 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "dependencies": { + "function-bind": "^1.1.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/http-cache-semantics": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", @@ -2502,10 +2574,13 @@ } }, "node_modules/object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", + "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", "dev": true, + "engines": { + "node": ">= 0.4" + }, "funding": { "url": "https://github.com/sponsors/ljharb" } @@ -2717,12 +2792,12 @@ } }, "node_modules/qs": { - "version": "6.11.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", - "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "version": "6.13.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, "dependencies": { - "side-channel": "^1.0.4" + "side-channel": "^1.0.6" }, "engines": { "node": ">=0.6" @@ -2967,6 +3042,23 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "dependencies": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + }, + "engines": { + "node": ">= 0.4" + } + }, "node_modules/setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -2995,14 +3087,18 @@ } }, "node_modules/side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "dependencies": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" + }, + "engines": { + "node": ">= 0.4" }, "funding": { "url": "https://github.com/sponsors/ljharb" @@ -3876,9 +3972,9 @@ "dev": true }, "body-parser": { - "version": "1.20.2", - "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.2.tgz", - "integrity": "sha512-ml9pReCu3M61kGlqoTm2umSXTlRTuGTx0bfYj+uIUKKYycG5NtSbeetV3faSU6R7ajOPw0g/J1PvK4qNy7s5bA==", + "version": "1.20.3", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-1.20.3.tgz", + "integrity": "sha512-7rAxByjUMqQ3/bHJy7D6OGXvx/MMc4IqBn/X0fcM1QUcAItpZrBEYhWGem+tzXH90c+G01ypMcYJBO9Y30203g==", "dev": true, "requires": { "bytes": "3.1.2", @@ -3889,7 +3985,7 @@ "http-errors": "2.0.0", "iconv-lite": "0.4.24", "on-finished": "2.4.1", - "qs": "6.11.0", + "qs": "6.13.0", "raw-body": "2.5.2", "type-is": "~1.6.18", "unpipe": "1.0.0" @@ -4005,13 +4101,16 @@ } }, "call-bind": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.2.tgz", - "integrity": "sha512-7O+FbCihrB5WGbFYesctwmTKae6rOiIzmz1icreWJ+0aA7LJfuqhEso2T9ncpcFtzMQtzXf2QGGueWJGTYsqrA==", + "version": "1.0.7", + "resolved": "https://registry.npmjs.org/call-bind/-/call-bind-1.0.7.tgz", + "integrity": "sha512-GHTSNSYICQ7scH7sZ+M2rFopRoLh8t2bLSW6BbgrtLsahOIB5iyAVJf9GjWK3cYTDaMj4XdBpM1cA6pIS0Kv2w==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "get-intrinsic": "^1.0.2" + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "set-function-length": "^1.2.1" } }, "chai": { @@ -4286,6 +4385,17 @@ "integrity": "sha512-4tvttepXG1VaYGrRibk5EwJd1t4udunSOVMdLSAL6mId1ix438oPwPZMALY41FCijukO1L0twNcGsdzS7dHgDg==", "dev": true }, + "define-data-property": { + "version": "1.1.4", + "resolved": "https://registry.npmjs.org/define-data-property/-/define-data-property-1.1.4.tgz", + "integrity": "sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==", + "dev": true, + "requires": { + "es-define-property": "^1.0.0", + "es-errors": "^1.3.0", + "gopd": "^1.0.1" + } + }, "define-properties": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/define-properties/-/define-properties-1.2.0.tgz", @@ -4429,6 +4539,21 @@ "integrity": "sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==", "dev": true }, + "es-define-property": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.0.tgz", + "integrity": "sha512-jxayLKShrEqqzJ0eumQbVhTYQM27CfT1T35+gCgDFoL82JLsXqTJ76zv6A0YLOgEnLUMvLzsDsGIrl8NFpT2gQ==", + "dev": true, + "requires": { + "get-intrinsic": "^1.2.4" + } + }, + "es-errors": { + "version": "1.3.0", + "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", + "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", + "dev": true + }, "es6-error": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/es6-error/-/es6-error-4.1.1.tgz", @@ -4678,14 +4803,16 @@ "dev": true }, "get-intrinsic": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.0.tgz", - "integrity": "sha512-L049y6nFOuom5wGyRc3/gdTLO94dySVKRACj1RmJZBQXlbTMhtNIgkWkUHq+jYmZvKf14EW1EoJnnjbmoHij0Q==", + "version": "1.2.4", + "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.4.tgz", + "integrity": "sha512-5uYhsJH8VJBTv7oslg4BznJYhDoRI6waYCxMmCdnTrcCrHA/fCFKoTFz2JKKE0HdDFUF7/oQuhzumXJK7paBRQ==", "dev": true, "requires": { - "function-bind": "^1.1.1", - "has": "^1.0.3", - "has-symbols": "^1.0.3" + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "has-proto": "^1.0.1", + "has-symbols": "^1.0.3", + "hasown": "^2.0.0" } }, "get-stream": { @@ -4791,6 +4918,15 @@ "slash": "^4.0.0" } }, + "gopd": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.0.1.tgz", + "integrity": "sha512-d65bNlIadxvpb/A2abVdlqKqV563juRnZ1Wtk6s1sIR8uNsXR70xqIzVqxVf1eTqDunwT2MkczEeaezCKTZhwA==", + "dev": true, + "requires": { + "get-intrinsic": "^1.1.3" + } + }, "got": { "version": "11.8.6", "resolved": "https://registry.npmjs.org/got/-/got-11.8.6.tgz", @@ -4827,15 +4963,6 @@ "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", "integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==" }, - "has": { - "version": "1.0.3", - "resolved": "https://registry.npmjs.org/has/-/has-1.0.3.tgz", - "integrity": "sha512-f2dvO0VU6Oej7RkWJGrehjbzMAjFp5/VKPp5tTpWIV4JHHZK1/BxbFRtf/siA2SWTe09caDmVtYYzWEIbBS4zw==", - "dev": true, - "requires": { - "function-bind": "^1.1.1" - } - }, "has-flag": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-3.0.0.tgz", @@ -4843,21 +4970,35 @@ "dev": true }, "has-property-descriptors": { - "version": "1.0.0", - "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.0.tgz", - "integrity": "sha512-62DVLZGoiEBDHQyqG4w9xCuZ7eJEwNmJRWw2VY84Oedb7WFcA27fiEVe8oUQx9hAUJ4ekurquucTGwsyO1XGdQ==", + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/has-property-descriptors/-/has-property-descriptors-1.0.2.tgz", + "integrity": "sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==", "dev": true, - "optional": true, "requires": { - "get-intrinsic": "^1.1.1" + "es-define-property": "^1.0.0" } }, + "has-proto": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/has-proto/-/has-proto-1.0.3.tgz", + "integrity": "sha512-SJ1amZAJUiZS+PhsVLf5tGydlaVB8EdFpaSO4gmiUKUOxk8qzn5AIy4ZeJUmh22znIdk/uMAUT2pl3FxzVUH+Q==", + "dev": true + }, "has-symbols": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.0.3.tgz", "integrity": "sha512-l3LCuF6MgDNwTDKkdYGEihYjt5pRPbEg46rtlmnSPlUbgmB8LOIrKJbYYFBSbnPaJexMKtiPO8hmeRjRz2Td+A==", "dev": true }, + "hasown": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", + "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", + "dev": true, + "requires": { + "function-bind": "^1.1.2" + } + }, "http-cache-semantics": { "version": "4.1.1", "resolved": "https://registry.npmjs.org/http-cache-semantics/-/http-cache-semantics-4.1.1.tgz", @@ -5505,9 +5646,9 @@ "dev": true }, "object-inspect": { - "version": "1.12.3", - "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.12.3.tgz", - "integrity": "sha512-geUvdk7c+eizMNUDkRpW1wJwgfOiOeHbxBR/hLXK1aT6zmVSO0jsQcs7fj6MGw89jC/cjGfLcNOrtMYtGqm81g==", + "version": "1.13.2", + "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.2.tgz", + "integrity": "sha512-IRZSRuzJiynemAXPYtPe5BoI/RESNYR7TYm50MC5Mqbd3Jmw5y790sErYw3V6SryFJD64b74qQQs9wn5Bg/k3g==", "dev": true }, "object-keys": { @@ -5666,12 +5807,12 @@ "dev": true }, "qs": { - "version": "6.11.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.11.0.tgz", - "integrity": "sha512-MvjoMCJwEarSbUYk5O+nmoSzSutSsTwF85zcHPQ9OrlFoZOYIjaqBAJIqIXjptyD5vThxGq52Xu/MaJzRkIk4Q==", + "version": "6.13.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.13.0.tgz", + "integrity": "sha512-+38qI9SOr8tfZ4QmJNplMUxqjbe7LKvvZgWdExBOmd+egZTtjLB67Gu0HRX3u/XOq7UU2Nx6nsjvS16Z9uwfpg==", "dev": true, "requires": { - "side-channel": "^1.0.4" + "side-channel": "^1.0.6" } }, "queue-microtask": { @@ -5832,6 +5973,20 @@ } } }, + "set-function-length": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/set-function-length/-/set-function-length-1.2.2.tgz", + "integrity": "sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==", + "dev": true, + "requires": { + "define-data-property": "^1.1.4", + "es-errors": "^1.3.0", + "function-bind": "^1.1.2", + "get-intrinsic": "^1.2.4", + "gopd": "^1.0.1", + "has-property-descriptors": "^1.0.2" + } + }, "setprototypeof": { "version": "1.2.0", "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", @@ -5854,14 +6009,15 @@ "dev": true }, "side-channel": { - "version": "1.0.4", - "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.4.tgz", - "integrity": "sha512-q5XPytqFEIKHkGdiMIrY10mvLRvnQh42/+GoBlFW3b2LXLE2xxJpZFdm94we0BaoV3RwJyGqg5wS7epxTv0Zvw==", + "version": "1.0.6", + "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.0.6.tgz", + "integrity": "sha512-fDW/EZ6Q9RiO8eFG8Hj+7u/oW+XrPTIChwCOM2+th2A6OblDtYYIpve9m+KvI9Z4C9qSEXlaGR6bTEYHReuglA==", "dev": true, "requires": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" + "call-bind": "^1.0.7", + "es-errors": "^1.3.0", + "get-intrinsic": "^1.2.4", + "object-inspect": "^1.13.1" } }, "signal-exit": { From c5418f35d41ba4c1e189ae338e286294c6de6b60 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Tue, 10 Sep 2024 16:18:05 -0700 Subject: [PATCH 05/39] Add fusions for re-designed Phi-3 vision and Phi-3.5 vision ONNX models (#22026) ### Description This PR adds the optimizer logic to fuse the newly designed exported ONNX models for Phi-3 vision and Phi-3.5 vision. ### Motivation and Context After the re-designed export of Phi-3 vision and Phi-3.5 vision, the ONNX models for the vision component and embedding component contain `If` and `Loop` ops to handle multi-image support. --- .../python/tools/transformers/fusion_gelu.py | 15 +++- .../transformers/fusion_skiplayernorm.py | 36 +++++---- .../python/tools/transformers/onnx_model.py | 79 +++++++++++++++---- .../tools/transformers/onnx_model_bert.py | 6 +- .../tools/transformers/onnx_model_clip.py | 1 + 5 files changed, 97 insertions(+), 40 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_gelu.py b/onnxruntime/python/tools/transformers/fusion_gelu.py index 8626ca1482104..6be5140c070d0 100644 --- a/onnxruntime/python/tools/transformers/fusion_gelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gelu.py @@ -98,10 +98,13 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = helper.make_node( + "Gelu", inputs=[subgraph_input], outputs=[subgraph_output], name=self.model.create_node_name("Gelu") + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + self.increase_counter("Gelu") return True def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: @@ -172,10 +175,13 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = helper.make_node( + "Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]], name=self.model.create_node_name("Gelu") + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + self.increase_counter("Gelu") return True def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: @@ -243,8 +249,11 @@ def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = helper.make_node( + "Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]], name=self.model.create_node_name("Gelu") + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name + self.increase_counter("Gelu") return True diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 1ec5edf686c63..a10b61fdc3f08 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -24,14 +24,15 @@ def __init__( model: OnnxModel, fused_op_type: str = "SkipLayerNormalization", search_op_types: str = "LayerNormalization", + shape_infer: bool = True, ): super().__init__(model, fused_op_type, search_op_types) - # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. - self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) - - if self.shape_infer_helper is None: - # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op. - logger.warning("symbolic shape inference disabled or failed.") + if shape_infer: + # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. + self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True) + if self.shape_infer_helper is None: + # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op. + logger.warning("symbolic shape inference disabled or failed.") def fuse(self, node, input_name_to_nodes, output_name_to_node): add = self.model.get_parent(node, 0, output_name_to_node) @@ -56,18 +57,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # Root Mean Square Layer Normalization simplified = node.op_type == "SimplifiedLayerNormalization" - if self.shape_infer_helper is not None: - # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) - if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): - logger.debug( - "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", - add.input[0], - add.input[1], - ) + if hasattr(self, "shape_infer_helper"): + if self.shape_infer_helper is not None: + # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) + if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): + logger.debug( + "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", + add.input[0], + add.input[1], + ) + return + else: + logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed") return - else: - logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed") - return gather_path = self.model.match_parent_path(add, ["Gather"], [None]) if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None: diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index a8fc6e661933e..fe80a08829263 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -63,9 +63,10 @@ def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False): # noqa: B return None - def input_name_to_nodes(self): + def input_name_to_nodes(self, exclude_subgraphs=False): input_name_to_nodes = {} - for node in self.nodes(): + nodes_to_search = self.nodes() if not exclude_subgraphs else self.model.graph.node + for node in nodes_to_search: for input_name in node.input: if input_name: # could be empty when it is optional if input_name not in input_name_to_nodes: @@ -74,9 +75,10 @@ def input_name_to_nodes(self): input_name_to_nodes[input_name].append(node) return input_name_to_nodes - def output_name_to_node(self): + def output_name_to_node(self, exclude_subgraphs=False): output_name_to_node = {} - for node in self.nodes(): + nodes_to_search = self.nodes() if not exclude_subgraphs else self.model.graph.node + for node in nodes_to_search: for output_name in node.output: if output_name: # could be empty when it is optional output_name_to_node[output_name] = node @@ -906,6 +908,31 @@ def remove_unused_constant(self): if len(unused_nodes) > 0: logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}") + def _get_subgraph_inputs_of_node(self, node): + """ + Get inputs to all nodes in all subgraphs of a node + """ + # Note: This function only handles one-level subgraphs of child nodes. + subgraph_nodes_inputs = set() + for attr in node.attribute: + if attr.type == AttributeProto.GRAPH: + child_nodes = attr.g.node + for child_node in child_nodes: + subgraph_nodes_inputs.update(child_node.input) + return subgraph_nodes_inputs + + def _get_subgraph_nodes_and_inputs(self, ops_with_graph_attrs): + """ + Get input names to all nodes in all subgraphs where subgraphs are + graph attributes of a node in the main graph + """ + subgraph_nodes = list(filter(lambda node: node.op_type in ops_with_graph_attrs, self.model.graph.node)) + subgraph_nodes_inputs = set() + for parent_node in subgraph_nodes: + subgraph_inputs_of_parent_node = self._get_subgraph_inputs_of_node(parent_node) + subgraph_nodes_inputs.update(subgraph_inputs_of_parent_node) + return subgraph_nodes, subgraph_nodes_inputs + def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): """ Prune graph to keep only required outputs. It removes unnecessary nodes that are not linked @@ -918,13 +945,9 @@ def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): allow_remove_graph_inputs (bool): allow remove graph inputs. """ - if len(self.graphs()) > 1: - # TODO(tianleiwu): handle subgraph - logger.debug("Skip prune_graph since graph has subgraph") - return - keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs + input_name_to_nodes_for_main_graph = self.input_name_to_nodes(exclude_subgraphs=True) output_name_to_node = self.output_name_to_node() def get_first_output(node): @@ -932,6 +955,29 @@ def get_first_output(node): return node.output[0] return next(iter([o for o in node.output if o]), None) + if len(self.graphs()) > 1: + # Get input names for all nodes in all subgraphs + subgraph_nodes, subgraph_nodes_inputs = self._get_subgraph_nodes_and_inputs( + ops_with_graph_attrs={"Loop", "Scan", "If"} + ) + if len(subgraph_nodes) == 0: + # TODO: support other ops such as `BeamSearch` that have subgraphs as op attributes + logger.debug("Skip prune_graph since graph has subgraph") + return + + # For graphs with subgraphs, add dangling outputs from parent graph nodes to list of outputs to keep + for node in self.model.graph.node: + # TODO: This for-loop logic currently assumes that Loop/Scan/If nodes will not be + # pruned because their subgraphs are needed for computations. This might not be + # true in all cases. + if node in subgraph_nodes: + continue + + # Check if node output is an input of a subgraph node and not an input to a node in the main graph + for output in node.output: + if output in subgraph_nodes_inputs and output not in input_name_to_nodes_for_main_graph: + keep_outputs += [output] + # Keep track of nodes to keep. The key is first output of node, and the value is the node. output_to_node = {} @@ -956,7 +1002,7 @@ def get_first_output(node): first_output = get_first_output(node) kept_node = output_to_node.get(first_output) - # Need double check the node since fused node might reuse output name of some nodes to be removed. + # Need to double check the node since fused node might reuse output name of some nodes to be removed. # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. if kept_node and kept_node.op_type == node.op_type and kept_node == node: nodes_to_keep.append(node) @@ -997,16 +1043,15 @@ def get_first_output(node): def update_graph(self, verbose=False, allow_remove_graph_inputs=False): graph = self.model.graph - remaining_input_names = [] + remaining_input_names = set() for node in graph.node: if node.op_type in ["Loop", "Scan", "If"]: - # TODO: handle inner graph - logger.debug(f"Skip update_graph since graph has operator: {node.op_type}") - return + # Add input names of nodes in subgraphs + subgraph_inputs_of_node = self._get_subgraph_inputs_of_node(node) + remaining_input_names.update(subgraph_inputs_of_node) + if node.op_type != "Constant": - for input_name in node.input: - if input_name not in remaining_input_names: - remaining_input_names.append(input_name) + remaining_input_names.update(node.input) if verbose: logger.debug(f"remaining input names: {remaining_input_names}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index ad51c1cce0ec4..26464fc32817d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -115,8 +115,8 @@ def fuse_simplified_layer_norm(self): fusion = FusionSimplifiedLayerNormalization(self) fusion.apply() - def fuse_skip_layer_norm(self): - fusion = FusionSkipLayerNormalization(self) + def fuse_skip_layer_norm(self, shape_infer=True): + fusion = FusionSkipLayerNormalization(self, shape_infer=shape_infer) fusion.apply() def fuse_skip_simplified_layer_norm(self): @@ -344,7 +344,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.fuse_reshape() if (options is None) or options.enable_skip_layer_norm: - self.fuse_skip_layer_norm() + self.fuse_skip_layer_norm(options.enable_shape_inference) self.fuse_skip_simplified_layer_norm() if (options is None) or options.enable_rotary_embeddings: diff --git a/onnxruntime/python/tools/transformers/onnx_model_clip.py b/onnxruntime/python/tools/transformers/onnx_model_clip.py index 32bddc3ca16a0..388d058c7856c 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_clip.py +++ b/onnxruntime/python/tools/transformers/onnx_model_clip.py @@ -24,6 +24,7 @@ def get_fused_operator_statistics(self): op_count = {} ops = [ "Attention", + "Gelu", "LayerNormalization", "QuickGelu", "SkipLayerNormalization", From f633caa0b1c8fdeafd2545c1de80f6572f9d5366 Mon Sep 17 00:00:00 2001 From: PARK DongHa Date: Wed, 11 Sep 2024 08:39:27 +0900 Subject: [PATCH 06/39] Create CMake option `onnxruntime_USE_VCPKG` (#21348) ### Changes 1. CMake option `onnxruntime_USE_VCPKG`. It will be used in the vcpkg port * Unit test may fail because this option leads to a mixture of unexpected external library versions. Especially ONNX, Protobuf, and Flatbuffers version can be different 2. Overhaul of `onnxruntime_external_deps.cmake` * Make `FetchContent_Declare` to try `find_package`. See https://cmake.org/cmake/help/latest/guide/using-dependencies/index.html * Relocated `FetchContent_Declare` and `FetchContent_MakeAvailable`(or `onnxruntime_fetchcontent_makeavailable`) to closer lines. It was too hard to navigate the entire file to search related sections... * Alias `IMPORTED` targets like build targets (e.g. `ONNX::onnx` --> `onnx`) ```cmake # The script uses `find_package` with the changes. # In this case, use vcpkg to search dependencies # See https://cmake.org/cmake/help/latest/guide/using-dependencies/index.html include(external/onnxruntime_external_deps.cmake) ``` 3. Create CMakePresets.json and presets to [run vcpkg in manifest mode](https://learn.microsoft.com/en-us/vcpkg/concepts/manifest-mode) * Currently, it's NOT for training build * Main triplets are `x64-windows` and `x64-osx` ```pwsh Push-Location "cmake" cmake --preset "x64-windows-vcpkg" cmake --build --preset "x64-windows-vcpkg-debug" Pop-Location ``` ```bash pushd "cmake" cmake --preset "x64-osx-vcpkg" cmake --build --preset "x64-osx-vcpkg-debug" popd ``` 4. Updated tools/ci_build/build.py * `--use_vcpkg` option: it needs `CMAKE_TOOLCHAIN_FILE` with [vcpkg.cmake toolchain script](https://github.com/microsoft/vcpkg/blob/master/scripts/buildsystems/vcpkg.cmake) * `--compile_no_warning_as_error` is recommended because library version differences will cause unexpected compiler warnings ```bash python ./tools/ci_build/build.py \ --compile_no_warning_as_error \ --use_vcpkg \ --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ --cmake_extra_defines "VCPKG_TARGET_TRIPLET=..." ``` 5. Created Job `Vcpkg` for Windows and macOS * Show how to setup and use vcpkg. Similar to the CMakePresets.json usage ### Motivation and Context * Help #7150 * Help https://github.com/microsoft/vcpkg/pull/36850 * https://github.com/luncliff/vcpkg-registry/pull/212 * https://github.com/microsoft/vcpkg/pull/39881 * https://github.com/luncliff/vcpkg-registry/pull/215 * https://github.com/luncliff/vcpkg-registry/pull/216 * https://github.com/luncliff/vcpkg-registry/pull/227 * https://cmake.org/cmake/help/latest/guide/using-dependencies/index.html * https://github.com/microsoft/vcpkg/blob/master/scripts/buildsystems/vcpkg.cmake ### Future Works? More feature coverage with the vcpkg supported libraries * CUDA feature support * Training feature support --- .github/workflows/mac.yml | 84 +++++ .github/workflows/windows.yml | 87 +++++ cmake/CMakeLists.txt | 3 + cmake/CMakePresets.json | 192 ++++++++++ cmake/external/abseil-cpp.cmake | 4 +- .../external/onnxruntime_external_deps.cmake | 340 ++++++++++-------- cmake/onnxruntime.cmake | 6 + cmake/onnxruntime_unittests.cmake | 4 + cmake/vcpkg-configuration.json | 8 + cmake/vcpkg.json | 78 ++++ .../framework/kernel_type_str_resolver.cc | 4 - .../core/framework/ort_value_name_idx_map.h | 4 - tools/ci_build/build.py | 7 + 13 files changed, 665 insertions(+), 156 deletions(-) create mode 100644 cmake/CMakePresets.json create mode 100644 cmake/vcpkg-configuration.json create mode 100644 cmake/vcpkg.json diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 3d94d30947c76..6efa8a5592337 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -58,6 +58,90 @@ jobs: --use_xnnpack \ --use_binskim_compliant_compile_flags + Vcpkg: + runs-on: macos-13 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ env.python_version }} + + - name: "Run vcpkg(x64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-osx or arm64-osx + export FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/flatbuffers" + export PATH="$FLATC_DIR:$PATH" + flatc --version + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$(which flatc)" + + - name: "Detect protoc" + id: protoc-detect + run: | + export PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-osx/tools/protobuf" + export PATH="$PROTOC_DIR:$PATH" + protoc --version + echo "protoc_path=$(which protoc)" >> "$GITHUB_OUTPUT" + + - name: "Run build.py(x64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/x64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch x86_64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + + - name: "Run vcpkg(arm64-osx)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "${{ runner.temp }}/vcpkg" + vcpkgGitCommitId: "1de2026f28ead93ff1773e6e680387643e914ea1" # 2024.07.12 + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-osx" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-osx)" + run: | + python ./tools/ci_build/build.py \ + --build_dir "build/arm64-osx" \ + --skip_submodule_sync \ + --skip_tests \ + --compile_no_warning_as_error \ + --parallel \ + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" \ + --osx_arch arm64 \ + --use_vcpkg \ + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" \ + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-osx" \ + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" \ + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: bash + Objective-C-StaticAnalysis: runs-on: macos-14 diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index b77e48942ec44..d276877b7ad47 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -42,3 +42,90 @@ jobs: # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - name: Build code run: python tools\ci_build\build.py --windows_sdk_version 10.0.22621.0 --enable_training --build_java --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + + Vcpkg: + runs-on: "windows-latest" + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.11.x' + architecture: 'x64' + + - name: "Run vcpkg(x64-windows)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "C:/vcpkg" # use VCPKG_INSTALLATION_ROOT of the image + doNotUpdateVcpkg: true + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "x64-windows" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run compile_schema.py" + run: | + # Runner's host triplet should be x64-windows or arm64-windows + $FLATC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-windows/tools/flatbuffers" + $env:PATH="$FLATC_DIR;$env:PATH" + flatc --version + $FLATC_PATH = Join-Path "$FLATC_DIR" "flatc.exe" + python onnxruntime/core/flatbuffers/schema/compile_schema.py --flatc "$FLATC_PATH" + shell: pwsh + + - name: "Detect protoc" + id: protoc-detect + run: | + $PROTOC_DIR="${{ github.workspace }}/.build/${{ runner.arch }}-windows/tools/protobuf" + $env:PATH="$PROTOC_DIR;$env:PATH" + protoc --version + $PROTOC_PATH = Join-Path "$PROTOC_DIR" "protoc.exe" + "protoc_path=$PROTOC_PATH" >> $env:GITHUB_OUTPUT + shell: pwsh + + - name: "Run build.py(x64-windows)" + run: | + python tools\ci_build\build.py ` + --build_dir "cmake_build/x64-windows" ` + --skip_submodule_sync ` + --skip_tests ` + --compile_no_warning_as_error ` + --parallel ` + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" ` + --use_vcpkg ` + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=C:/vcpkg/scripts/buildsystems/vcpkg.cmake" ` + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=x64-windows" ` + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" ` + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: pwsh + + - name: "Run vcpkg(arm64-windows)" + uses: lukka/run-vcpkg@v11 + with: + vcpkgDirectory: "C:/vcpkg" # use VCPKG_INSTALLATION_ROOT of the image + doNotUpdateVcpkg: true + runVcpkgInstall: true + vcpkgJsonGlob: "cmake/vcpkg.json" + vcpkgConfigurationJsonGlob: "cmake/vcpkg-configuration.json" + env: + VCPKG_INSTALLED_DIR: "${{ github.workspace }}/.build" + VCPKG_DEFAULT_TRIPLET: "arm64-windows" + # VCPKG_BINARY_SOURCES: "default" # https://learn.microsoft.com/en-us/vcpkg/reference/binarycaching + + - name: "Run build.py(arm64-windows)" + run: | + python tools\ci_build\build.py ` + --build_dir "cmake_build/arm64-windows" --arm64 ` + --skip_submodule_sync ` + --skip_tests ` + --compile_no_warning_as_error ` + --parallel ` + --path_to_protoc_exe "${{ steps.protoc-detect.outputs.protoc_path }}" ` + --use_vcpkg ` + --cmake_extra_defines "CMAKE_TOOLCHAIN_FILE:FILEPATH=C:/vcpkg/scripts/buildsystems/vcpkg.cmake" ` + --cmake_extra_defines "VCPKG_TARGET_TRIPLET=arm64-windows" ` + --cmake_extra_defines "VCPKG_INSTALLED_DIR:PATH=${{ github.workspace }}/.build" ` + --cmake_extra_defines "VCPKG_INSTALL_OPTIONS=--x-feature=tests" + shell: pwsh diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 2c8fb4824d94a..fb3b75fda4eaf 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -38,6 +38,7 @@ include(CheckLanguage) include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +include(GNUInstallDirs) # onnxruntime_providers_* require CMAKE_INSTALL_* variables # TODO: update this once all system adapt c++20 if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") @@ -69,6 +70,7 @@ if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_ endif() # Options +option(onnxruntime_USE_VCPKG "Build with the vcpkg package manager" OFF) option(onnxruntime_RUN_ONNX_TESTS "Enable ONNX Compatibility Testing" OFF) option(onnxruntime_GENERATE_TEST_REPORTS "Enable test report generation" OFF) option(onnxruntime_ENABLE_STATIC_ANALYSIS "Enable static analysis" OFF) @@ -595,6 +597,7 @@ get_filename_component(ORTTRAINING_ROOT "${ORTTRAINING_ROOT}" ABSOLUTE) get_filename_component(REPO_ROOT "${REPO_ROOT}" ABSOLUTE) set(ONNXRUNTIME_INCLUDE_DIR ${REPO_ROOT}/include/onnxruntime) +list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external) include(external/onnxruntime_external_deps.cmake) set(ORT_WARNING_FLAGS) diff --git a/cmake/CMakePresets.json b/cmake/CMakePresets.json new file mode 100644 index 0000000000000..1b7aa11975a3e --- /dev/null +++ b/cmake/CMakePresets.json @@ -0,0 +1,192 @@ +{ + "version": 6, + "cmakeMinimumRequired": { + "major": 3, + "minor": 25, + "patch": 0 + }, + "configurePresets": [ + { + "name": "vcpkg-manifest", + "hidden": true, + "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake", + "cacheVariables": { + "VCPKG_INSTALLED_DIR": "${sourceParentDir}/.build" + }, + "environment": { + "VCPKG_FEATURE_FLGAS": "manifests,versions" + } + }, + { + "name": "msvc-static-runtime", + "hidden": true, + "cacheVariables": { + "ONNX_USE_MSVC_STATIC_RUNTIME": true, + "protobuf_MSVC_STATIC_RUNTIME": true, + "gtest_force_shared_crt": false, + "CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded$<$:Debug>" + } + }, + { + "name": "unit-test", + "hidden": true, + "cacheVariables": { + "onnxruntime_RUN_ONNX_TESTS": true, + "onnxruntime_BUILD_BENCHMARKS": true, + "onnxruntime_BUILD_UNIT_TESTS": true, + "onnxruntime_GENERATE_TEST_REPORTS": true + } + }, + { + "name": "x64-windows", + "inherits": [ + "msvc-static-runtime", + "unit-test" + ], + "generator": "Visual Studio 17 2022", + "architecture": "x64", + "binaryDir": "${sourceParentDir}/cmake_build/x64-windows", + "installDir": "${sourceParentDir}/cmake_build/out", + "cacheVariables": { + "onnxruntime_USE_XNNPACK": true, + "onnxruntime_USE_DML": true, + "onnxruntime_BUILD_SHARED_LIB": true, + "CMAKE_CONFIGURATION_TYPES": "Debug;Release" + }, + "vendor": { + "microsoft.com/VisualStudioSettings/CMake/1.0": { + "intelliSenseMode": "windows-msvc-x64" + } + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "x64-windows-vcpkg", + "inherits": [ + "unit-test", + "vcpkg-manifest" + ], + "generator": "Visual Studio 17 2022", + "architecture": "x64", + "binaryDir": "${sourceParentDir}/cmake_build/x64-windows", + "installDir": "${sourceParentDir}/cmake_build/out", + "cacheVariables": { + "onnxruntime_USE_VCPKG": true, + "onnxruntime_USE_XNNPACK": false, + "onnxruntime_USE_DML": false, + "onnxruntime_BUILD_SHARED_LIB": true, + "CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded$<$:Debug>DLL", + "CMAKE_CONFIGURATION_TYPES": "Debug;Release", + "VCPKG_INSTALL_OPTIONS": "--x-feature=tests", + "VCPKG_TARGET_TRIPLET": "x64-windows" + } + }, + { + "name": "x64-osx", + "inherits": [ + "unit-test" + ], + "generator": "Xcode", + "binaryDir": "${sourceParentDir}/cmake_build/x64-osx", + "installDir": "${sourceParentDir}/cmake_build/out", + "cacheVariables": { + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "onnxruntime_BUILD_SHARED_LIB": true, + "onnxruntime_USE_XNNPACK": false, + "onnxruntime_USE_COREML": true, + "onnxruntime_BUILD_OBJC": true, + "onnxruntime_BUILD_APPLE_FRAMEWORK": true, + "CMAKE_CONFIGURATION_TYPES": "Debug;Release" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } + }, + { + "name": "x64-osx-vcpkg", + "inherits": [ + "x64-osx", + "vcpkg-manifest" + ], + "cacheVariables": { + "onnxruntime_USE_VCPKG": true, + "onnxruntime_USE_XNNPACK": false, + "onnxruntime_USE_COREML": false, + "onnxruntime_BUILD_OBJC": false, + "onnxruntime_BUILD_APPLE_FRAMEWORK": false, + "VCPKG_INSTALL_OPTIONS": "--x-feature=tests", + "VCPKG_TARGET_TRIPLET": "x64-osx" + } + } + ], + "buildPresets": [ + { + "name": "x64-windows-debug", + "configurePreset": "x64-windows", + "configuration": "Debug" + }, + { + "name": "x64-windows-vcpkg-debug", + "configurePreset": "x64-windows-vcpkg", + "configuration": "Debug" + }, + { + "name": "x64-osx-debug", + "configurePreset": "x64-osx", + "configuration": "Debug" + }, + { + "name": "x64-osx-vcpkg-debug", + "configurePreset": "x64-osx-vcpkg", + "configuration": "Debug" + } + ], + "testPresets": [ + { + "name": "x64-windows-debug", + "configurePreset": "x64-windows", + "configuration": "Debug", + "output": { + "verbosity": "default", + "outputJUnitFile": "TEST-x64-windows-debug.xml", + "outputLogFile": "TEST-x64-windows-debug.log", + "outputOnFailure": true + }, + "execution": { + "noTestsAction": "error", + "stopOnFailure": false + } + }, + { + "name": "x64-windows-vcpkg-debug", + "inherits": "x64-windows-debug", + "configurePreset": "x64-windows-vcpkg" + }, + { + "name": "x64-osx-debug", + "configurePreset": "x64-osx", + "configuration": "Debug", + "output": { + "verbosity": "default", + "outputJUnitFile": "TEST-x64-osx-debug.xml", + "outputLogFile": "TEST-x64-osx-debug.log", + "outputOnFailure": true + }, + "execution": { + "noTestsAction": "error", + "stopOnFailure": false + } + }, + { + "name": "x64-osx-vcpkg-debug", + "inherits": "x64-osx-debug", + "configurePreset": "x64-osx-vcpkg" + } + ] +} diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3223724693a49..dda7c5ff19ba4 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -27,7 +27,7 @@ FetchContent_Declare( URL ${DEP_URL_abseil_cpp} URL_HASH SHA1=${DEP_SHA1_abseil_cpp} PATCH_COMMAND ${ABSL_PATCH_COMMAND} - FIND_PACKAGE_ARGS 20240116 NAMES absl + FIND_PACKAGE_ARGS NAMES absl ) onnxruntime_fetchcontent_makeavailable(abseil_cpp) @@ -142,4 +142,4 @@ absl::throw_delegate absl::memory absl::charset absl::endian -absl::config) \ No newline at end of file +absl::config) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 4e52707474052..43f18abbe9522 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -1,4 +1,4 @@ -message("Loading Dependencies URLs ...") +message(STATUS "Loading Dependencies URLs ...") include(external/helper_functions.cmake) @@ -27,7 +27,9 @@ foreach(ONNXRUNTIME_DEP IN LISTS ONNXRUNTIME_DEPS_LIST) endif() endforeach() -message("Loading Dependencies ...") +message(STATUS "Loading Dependencies ...") +include(FetchContent) + # ABSL should be included before protobuf because protobuf may use absl include(external/abseil-cpp.cmake) @@ -39,6 +41,7 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_re2} FIND_PACKAGE_ARGS NAMES re2 ) +onnxruntime_fetchcontent_makeavailable(re2) if (onnxruntime_BUILD_UNIT_TESTS) # WebAssembly threading support in Node.js is still an experimental feature and @@ -65,6 +68,7 @@ if (onnxruntime_BUILD_UNIT_TESTS) URL_HASH SHA1=${DEP_SHA1_googletest} FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) + FetchContent_MakeAvailable(googletest) endif() if (onnxruntime_BUILD_BENCHMARKS) @@ -77,50 +81,41 @@ if (onnxruntime_BUILD_BENCHMARKS) google_benchmark URL ${DEP_URL_google_benchmark} URL_HASH SHA1=${DEP_SHA1_google_benchmark} + FIND_PACKAGE_ARGS NAMES benchmark ) + onnxruntime_fetchcontent_makeavailable(google_benchmark) endif() if (NOT WIN32) - FetchContent_Declare( + FetchContent_Declare( google_nsync URL ${DEP_URL_google_nsync} URL_HASH SHA1=${DEP_SHA1_google_nsync} - FIND_PACKAGE_ARGS NAMES nsync - ) -endif() -list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/external) - -FetchContent_Declare( - mimalloc - URL ${DEP_URL_mimalloc} - URL_HASH SHA1=${DEP_SHA1_mimalloc} -) - + FIND_PACKAGE_ARGS NAMES nsync unofficial-nsync + ) + #nsync tests failed on Mac Build + set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE) + onnxruntime_fetchcontent_makeavailable(google_nsync) -# Flatbuffers -# We do not need to build flatc for iOS or Android Cross Compile -if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) -endif() -set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "FLATBUFFERS_BUILD_TESTS" FORCE) -set(FLATBUFFERS_INSTALL OFF CACHE BOOL "FLATBUFFERS_INSTALL" FORCE) -set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATHASH" FORCE) -set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "FLATBUFFERS_BUILD_FLATLIB" FORCE) -if(Patch_FOUND) - set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/flatbuffers/flatbuffers.patch) -else() - set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND "") + if (google_nsync_SOURCE_DIR) + add_library(nsync::nsync_cpp ALIAS nsync_cpp) + target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public) + endif() + if(TARGET unofficial::nsync::nsync_cpp AND NOT TARGET nsync::nsync_cpp) + message(STATUS "Aliasing unofficial::nsync::nsync_cpp to nsync::nsync_cpp") + add_library(nsync::nsync_cpp ALIAS unofficial::nsync::nsync_cpp) + endif() endif() -#flatbuffers 1.11.0 does not have flatbuffers::IsOutRange, therefore we require 1.12.0+ -FetchContent_Declare( - flatbuffers - URL ${DEP_URL_flatbuffers} - URL_HASH SHA1=${DEP_SHA1_flatbuffers} - PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND} - FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers -) - +if(onnxruntime_USE_MIMALLOC) + FetchContent_Declare( + mimalloc + URL ${DEP_URL_mimalloc} + URL_HASH SHA1=${DEP_SHA1_mimalloc} + FIND_PACKAGE_ARGS NAMES mimalloc + ) + FetchContent_MakeAvailable(mimalloc) +endif() #Protobuf depends on utf8_range FetchContent_Declare( @@ -133,6 +128,10 @@ FetchContent_Declare( set(utf8_range_ENABLE_TESTS OFF CACHE BOOL "Build test suite" FORCE) set(utf8_range_ENABLE_INSTALL OFF CACHE BOOL "Configure installation" FORCE) +# The next line will generate an error message "fatal: not a git repository", but it is ok. It is from flatbuffers +onnxruntime_fetchcontent_makeavailable(utf8_range) +# protobuf's cmake/utf8_range.cmake has the following line +include_directories(${utf8_range_SOURCE_DIR}) # Download a protoc binary from Internet if needed if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) @@ -146,12 +145,12 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) FetchContent_Populate(protoc_binary) if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") + message(STATUS "Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() elseif (CMAKE_CROSSCOMPILING) - message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") + message(STATUS "CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) @@ -162,7 +161,7 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) endif() if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") + message(STATUS "Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() @@ -179,7 +178,7 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) endif() if(protoc_binary_SOURCE_DIR) - message("Use prebuilt protoc") + message(STATUS "Use prebuilt protoc") set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) endif() @@ -217,7 +216,7 @@ FetchContent_Declare( URL ${DEP_URL_protobuf} URL_HASH SHA1=${DEP_SHA1_protobuf} PATCH_COMMAND ${ONNXRUNTIME_PROTOBUF_PATCH_COMMAND} - FIND_PACKAGE_ARGS 3.21.12 NAMES Protobuf + FIND_PACKAGE_ARGS NAMES Protobuf protobuf ) set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) @@ -239,6 +238,51 @@ endif() include(protobuf_function) #protobuf end +onnxruntime_fetchcontent_makeavailable(Protobuf) +if(Protobuf_FOUND) + message(STATUS "Protobuf version: ${Protobuf_VERSION}") +else() + # Adjust warning flags + if (TARGET libprotoc) + if (NOT MSVC) + target_compile_options(libprotoc PRIVATE "-w") + endif() + endif() + if (TARGET protoc) + add_executable(protobuf::protoc ALIAS protoc) + if (UNIX AND onnxruntime_ENABLE_LTO) + #https://github.com/protocolbuffers/protobuf/issues/5923 + target_link_options(protoc PRIVATE "-Wl,--no-as-needed") + endif() + if (NOT MSVC) + target_compile_options(protoc PRIVATE "-w") + endif() + get_target_property(PROTOC_OSX_ARCH protoc OSX_ARCHITECTURES) + if (PROTOC_OSX_ARCH) + if (${CMAKE_HOST_SYSTEM_PROCESSOR} IN_LIST PROTOC_OSX_ARCH) + message(STATUS "protoc can run") + else() + list(APPEND PROTOC_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + set_target_properties(protoc PROPERTIES OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}") + set_target_properties(libprotoc PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") + set_target_properties(libprotobuf PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") + endif() + endif() + endif() + if (TARGET libprotobuf AND NOT MSVC) + target_compile_options(libprotobuf PRIVATE "-w") + endif() + if (TARGET libprotobuf-lite AND NOT MSVC) + target_compile_options(libprotobuf-lite PRIVATE "-w") + endif() +endif() +if (onnxruntime_USE_FULL_PROTOBUF) + set(PROTOBUF_LIB protobuf::libprotobuf) +else() + set(PROTOBUF_LIB protobuf::libprotobuf-lite) +endif() + +# date set(ENABLE_DATE_TESTING OFF CACHE BOOL "" FORCE) set(USE_SYSTEM_TZ_DB ON CACHE BOOL "" FORCE) @@ -254,7 +298,16 @@ FetchContent_Declare( mp11 URL ${DEP_URL_mp11} URL_HASH SHA1=${DEP_SHA1_mp11} + FIND_PACKAGE_ARGS NAMES Boost ) +onnxruntime_fetchcontent_makeavailable(mp11) +if(NOT TARGET Boost::mp11) + if(onnxruntime_USE_VCPKG) + find_package(Boost REQUIRED) + endif() + message(STATUS "Aliasing Boost::headers to Boost::mp11") + add_library(Boost::mp11 ALIAS Boost::headers) +endif() set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install OFF CACHE INTERNAL "") @@ -265,6 +318,7 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_json} FIND_PACKAGE_ARGS 3.10 NAMES nlohmann_json ) +onnxruntime_fetchcontent_makeavailable(nlohmann_json) #TODO: include clog first if (onnxruntime_ENABLE_CPUINFO) @@ -301,20 +355,6 @@ else() set(CPUINFO_SUPPORTED FALSE) endif() -# xnnpack depends on clog -# Android build should use the system's log library instead of clog -if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) - set(CLOG_BUILD_TESTS OFF CACHE BOOL "" FORCE) - FetchContent_Declare( - pytorch_clog - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - SOURCE_SUBDIR deps/clog - ) - set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) - set(ONNXRUNTIME_CLOG_TARGET_NAME clog) -endif() - if (CPUINFO_SUPPORTED) if (CMAKE_SYSTEM_NAME STREQUAL "iOS") set(IOS ON CACHE INTERNAL "") @@ -333,7 +373,7 @@ if (CPUINFO_SUPPORTED) set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if(onnxruntime_target_platform STREQUAL "ARM64EC") - message("Applying a patch for Windows ARM64EC in cpuinfo") + message(STATUS "Applying a patch for Windows ARM64EC in cpuinfo") FetchContent_Declare( pytorch_cpuinfo URL ${DEP_URL_pytorch_cpuinfo} @@ -350,20 +390,33 @@ if (CPUINFO_SUPPORTED) ) endif() set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo) + onnxruntime_fetchcontent_makeavailable(${ONNXRUNTIME_CPUINFO_PROJ}) + if(TARGET cpuinfo::cpuinfo AND NOT TARGET cpuinfo) + message(STATUS "Aliasing cpuinfo::cpuinfo to cpuinfo") + add_library(cpuinfo ALIAS cpuinfo::cpuinfo) + endif() endif() - -if (onnxruntime_BUILD_BENCHMARKS) - onnxruntime_fetchcontent_makeavailable(google_benchmark) -endif() - -if (NOT WIN32) - #nsync tests failed on Mac Build - set(NSYNC_ENABLE_TESTS OFF CACHE BOOL "" FORCE) - onnxruntime_fetchcontent_makeavailable(google_nsync) - if (google_nsync_SOURCE_DIR) - add_library(nsync::nsync_cpp ALIAS nsync_cpp) - target_include_directories(nsync_cpp PUBLIC ${google_nsync_SOURCE_DIR}/public) +# xnnpack depends on clog +# Android build should use the system's log library instead of clog +if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID) + set(CLOG_BUILD_TESTS OFF CACHE BOOL "" FORCE) + FetchContent_Declare( + pytorch_clog + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + SOURCE_SUBDIR deps/clog + FIND_PACKAGE_ARGS NAMES cpuinfo + ) + set(ONNXRUNTIME_CLOG_PROJ pytorch_clog) + onnxruntime_fetchcontent_makeavailable(${ONNXRUNTIME_CLOG_PROJ}) + set(ONNXRUNTIME_CLOG_TARGET_NAME clog) + # if cpuinfo is from find_package, use it with imported name + if(TARGET cpuinfo::clog) + set(ONNXRUNTIME_CLOG_TARGET_NAME cpuinfo::clog) + elseif(onnxruntime_USE_VCPKG) + # however, later cpuinfo versions may not contain clog. use cpuinfo + set(ONNXRUNTIME_CLOG_TARGET_NAME cpuinfo::cpuinfo) endif() endif() @@ -383,21 +436,51 @@ else() FIND_PACKAGE_ARGS 4.0 NAMES Microsoft.GSL ) endif() +set(GSL_TARGET "Microsoft.GSL::GSL") +set(GSL_INCLUDE_DIR "$") +onnxruntime_fetchcontent_makeavailable(GSL) +find_path(safeint_SOURCE_DIR NAMES "SafeInt.hpp") +if(NOT safeint_SOURCE_DIR) + unset(safeint_SOURCE_DIR) + FetchContent_Declare( + safeint + URL ${DEP_URL_safeint} + URL_HASH SHA1=${DEP_SHA1_safeint} + ) + + # use fetch content rather than makeavailable because safeint only includes unconditional test targets + FetchContent_Populate(safeint) +endif() +add_library(safeint_interface INTERFACE) +target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) + + +# Flatbuffers +# We do not need to build flatc for iOS or Android Cross Compile +if (CMAKE_SYSTEM_NAME STREQUAL "iOS" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set(FLATBUFFERS_BUILD_FLATC OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATC" FORCE) +endif() +set(FLATBUFFERS_BUILD_TESTS OFF CACHE BOOL "FLATBUFFERS_BUILD_TESTS" FORCE) +set(FLATBUFFERS_INSTALL OFF CACHE BOOL "FLATBUFFERS_INSTALL" FORCE) +set(FLATBUFFERS_BUILD_FLATHASH OFF CACHE BOOL "FLATBUFFERS_BUILD_FLATHASH" FORCE) +set(FLATBUFFERS_BUILD_FLATLIB ON CACHE BOOL "FLATBUFFERS_BUILD_FLATLIB" FORCE) +if(Patch_FOUND) + set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/flatbuffers/flatbuffers.patch) +else() + set(ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND "") +endif() + +#flatbuffers 1.11.0 does not have flatbuffers::IsOutRange, therefore we require 1.12.0+ FetchContent_Declare( - safeint - URL ${DEP_URL_safeint} - URL_HASH SHA1=${DEP_SHA1_safeint} + flatbuffers + URL ${DEP_URL_flatbuffers} + URL_HASH SHA1=${DEP_SHA1_flatbuffers} + PATCH_COMMAND ${ONNXRUNTIME_FLATBUFFERS_PATCH_COMMAND} + FIND_PACKAGE_ARGS 23.5.9 NAMES Flatbuffers flatbuffers ) -# use fetch content rather than makeavailable because safeint only includes unconditional test targets -FetchContent_Populate(safeint) -# The next line will generate an error message "fatal: not a git repository", but it is ok. It is from flatbuffers -onnxruntime_fetchcontent_makeavailable(utf8_range) -# protobuf's cmake/utf8_range.cmake has the following line -include_directories(${utf8_range_SOURCE_DIR}) - -onnxruntime_fetchcontent_makeavailable(Protobuf nlohmann_json mp11 re2 GSL flatbuffers ${ONNXRUNTIME_CPUINFO_PROJ} ${ONNXRUNTIME_CLOG_PROJ}) +onnxruntime_fetchcontent_makeavailable(flatbuffers) if(NOT flatbuffers_FOUND) if(NOT TARGET flatbuffers::flatbuffers) add_library(flatbuffers::flatbuffers ALIAS flatbuffers) @@ -424,54 +507,6 @@ namespace std { using ::getenv; } endif() endif() -if (onnxruntime_BUILD_UNIT_TESTS) - onnxruntime_fetchcontent_makeavailable(googletest) -endif() - -if(Protobuf_FOUND) - message("Protobuf version: ${Protobuf_VERSION}") -else() - # Adjust warning flags - if (TARGET libprotoc) - if (NOT MSVC) - target_compile_options(libprotoc PRIVATE "-w") - endif() - endif() - if (TARGET protoc) - add_executable(protobuf::protoc ALIAS protoc) - if (UNIX AND onnxruntime_ENABLE_LTO) - #https://github.com/protocolbuffers/protobuf/issues/5923 - target_link_options(protoc PRIVATE "-Wl,--no-as-needed") - endif() - if (NOT MSVC) - target_compile_options(protoc PRIVATE "-w") - endif() - get_target_property(PROTOC_OSX_ARCH protoc OSX_ARCHITECTURES) - if (PROTOC_OSX_ARCH) - if (${CMAKE_HOST_SYSTEM_PROCESSOR} IN_LIST PROTOC_OSX_ARCH) - message("protoc can run") - else() - list(APPEND PROTOC_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - set_target_properties(protoc PROPERTIES OSX_ARCHITECTURES "${CMAKE_HOST_SYSTEM_PROCESSOR}") - set_target_properties(libprotoc PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") - set_target_properties(libprotobuf PROPERTIES OSX_ARCHITECTURES "${PROTOC_OSX_ARCH}") - endif() - endif() - endif() - if (TARGET libprotobuf AND NOT MSVC) - target_compile_options(libprotobuf PRIVATE "-w") - endif() - if (TARGET libprotobuf-lite AND NOT MSVC) - target_compile_options(libprotobuf-lite PRIVATE "-w") - endif() -endif() -if (onnxruntime_USE_FULL_PROTOBUF) - set(PROTOBUF_LIB protobuf::libprotobuf) -else() - set(PROTOBUF_LIB protobuf::libprotobuf-lite) -endif() - - # ONNX if (NOT onnxruntime_USE_FULL_PROTOBUF) set(ONNX_USE_LITE_PROTO ON CACHE BOOL "" FORCE) @@ -490,27 +525,36 @@ FetchContent_Declare( URL ${DEP_URL_onnx} URL_HASH SHA1=${DEP_SHA1_onnx} PATCH_COMMAND ${ONNXRUNTIME_ONNX_PATCH_COMMAND} + FIND_PACKAGE_ARGS NAMES ONNX onnx ) - - - - - - -include(eigen) -include(wil) - if (NOT onnxruntime_MINIMAL_BUILD) - onnxruntime_fetchcontent_makeavailable(onnx) + onnxruntime_fetchcontent_makeavailable(onnx) else() include(onnx_minimal) endif() -set(GSL_TARGET "Microsoft.GSL::GSL") -set(GSL_INCLUDE_DIR "$") +if(TARGET ONNX::onnx AND NOT TARGET onnx) + message(STATUS "Aliasing ONNX::onnx to onnx") + add_library(onnx ALIAS ONNX::onnx) +endif() +if(TARGET ONNX::onnx_proto AND NOT TARGET onnx_proto) + message(STATUS "Aliasing ONNX::onnx_proto to onnx_proto") + add_library(onnx_proto ALIAS ONNX::onnx_proto) +endif() -add_library(safeint_interface INTERFACE) -target_include_directories(safeint_interface INTERFACE ${safeint_SOURCE_DIR}) +find_package(Eigen3 CONFIG) +if(Eigen3_FOUND) + get_target_property(eigen_INCLUDE_DIRS Eigen3::Eigen INTERFACE_INCLUDE_DIRECTORIES) +else() + include(eigen) # FetchContent +endif() + +if(onnxruntime_USE_VCPKG) + find_package(wil CONFIG REQUIRED) + set(WIL_TARGET "WIL::WIL") +else() + include(wil) # FetchContent +endif() # XNNPACK EP if (onnxruntime_USE_XNNPACK) @@ -539,9 +583,11 @@ set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${W # The other libs do not have the problem. All the sources are already there. We can compile them in any order. set(onnxruntime_EXTERNAL_DEPENDENCIES onnx_proto flatbuffers::flatbuffers) -target_compile_definitions(onnx PUBLIC $ PRIVATE "__ONNX_DISABLE_STATIC_REGISTRATION") -if (NOT onnxruntime_USE_FULL_PROTOBUF) - target_compile_definitions(onnx PUBLIC "__ONNX_NO_DOC_STRINGS") +if(NOT (onnx_FOUND OR ONNX_FOUND)) # building ONNX from source + target_compile_definitions(onnx PUBLIC $ PRIVATE "__ONNX_DISABLE_STATIC_REGISTRATION") + if (NOT onnxruntime_USE_FULL_PROTOBUF) + target_compile_definitions(onnx PUBLIC "__ONNX_NO_DOC_STRINGS") + endif() endif() if (onnxruntime_RUN_ONNX_TESTS) @@ -550,11 +596,12 @@ endif() if(onnxruntime_ENABLE_ATEN) - message("Aten fallback is enabled.") + message(STATUS "Aten fallback is enabled.") FetchContent_Declare( dlpack URL ${DEP_URL_dlpack} URL_HASH SHA1=${DEP_SHA1_dlpack} + FIND_PACKAGE_ARGS NAMES dlpack ) # We can't use onnxruntime_fetchcontent_makeavailable since some part of the the dlpack code is Linux only. # For example, dlpackcpp.h uses posix_memalign. @@ -568,6 +615,7 @@ if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxrunt cxxopts URL ${DEP_URL_cxxopts} URL_HASH SHA1=${DEP_SHA1_cxxopts} + FIND_PACKAGE_ARGS NAMES cxxopts ) set(CXXOPTS_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE) set(CXXOPTS_BUILD_TESTS OFF CACHE BOOL "" FORCE) @@ -585,7 +633,7 @@ if (onnxruntime_USE_COREML) FetchContent_Populate(coremltools) endif() -message("Finished fetching external dependencies") +message(STATUS "Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 927b4ac84b037..7e992fb33077c 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -332,6 +332,9 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # If it's an onnxruntime library, extract .o files from the original cmake build path to a separate directory for # each library to avoid any clashes with filenames (e.g. utils.o) foreach(_LIB ${onnxruntime_INTERNAL_LIBRARIES} ) + if(NOT TARGET ${_LIB}) # if we didn't build from source. it may not a target + continue() + endif() GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") set(CUR_STATIC_LIB_OBJ_DIR ${STATIC_LIB_TEMP_DIR}/$) @@ -362,6 +365,9 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) # for external libraries we create a symlink to the .a file foreach(_LIB ${onnxruntime_EXTERNAL_LIBRARIES}) + if(NOT TARGET ${_LIB}) # if we didn't build from source. it may not a target + continue() + endif() GET_TARGET_PROPERTY(_LIB_TYPE ${_LIB} TYPE) if(_LIB_TYPE STREQUAL "STATIC_LIBRARY") add_custom_command(TARGET onnxruntime POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 58dd08f15f4e2..4b880c4437dfd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -877,6 +877,7 @@ AddTest( DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) +target_include_directories(onnxruntime_test_all PRIVATE ${ONNXRUNTIME_ROOT}/core/flatbuffers/schema) # ort.fbs.h if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. @@ -982,6 +983,9 @@ target_compile_definitions(onnx_test_data_proto PRIVATE "-DONNX_API=") onnxruntime_add_include_to_target(onnx_test_data_proto onnx_proto) target_include_directories(onnx_test_data_proto PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) set_target_properties(onnx_test_data_proto PROPERTIES FOLDER "ONNXRuntimeTest") +if(NOT DEFINED onnx_SOURCE_DIR) + find_path(onnx_SOURCE_DIR NAMES "onnx/onnx-ml.proto3" "onnx/onnx-ml.proto" REQUIRED) +endif() onnxruntime_protobuf_generate(APPEND_PATH IMPORT_DIRS ${onnx_SOURCE_DIR} TARGET onnx_test_data_proto) # diff --git a/cmake/vcpkg-configuration.json b/cmake/vcpkg-configuration.json new file mode 100644 index 0000000000000..f3525977c7bb9 --- /dev/null +++ b/cmake/vcpkg-configuration.json @@ -0,0 +1,8 @@ +{ + "default-registry": { + "kind": "git", + "repository": "https://github.com/Microsoft/vcpkg", + "baseline": "3508985146f1b1d248c67ead13f8f54be5b4f5da" + }, + "registries": [] +} diff --git a/cmake/vcpkg.json b/cmake/vcpkg.json new file mode 100644 index 0000000000000..159b8654c1cb1 --- /dev/null +++ b/cmake/vcpkg.json @@ -0,0 +1,78 @@ +{ + "$schema": "https://raw.githubusercontent.com/microsoft/vcpkg-tool/main/docs/vcpkg.schema.json", + "name": "onnxruntime", + "version-date": "2024-09-10", + "description": "ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator", + "homepage": "https://onnxruntime.ai/", + "license": "MIT", + "dependencies": [ + "abseil", + { + "name": "boost-config", + "version>=": "1.82.0" + }, + { + "name": "boost-mp11", + "version>=": "1.82.0" + }, + "cpuinfo", + "cxxopts", + "date", + "dlpack", + { + "name": "flatbuffers", + "host": true, + "version>=": "23.5.26" + }, + { + "name": "flatbuffers", + "version>=": "23.5.26" + }, + "ms-gsl", + "nlohmann-json", + { + "name": "nsync", + "platform": "!windows" + }, + { + "name": "nsync", + "platform": "!windows", + "version>=": "1.26.0" + }, + "optional-lite", + { + "name": "protobuf", + "version>=": "3.21.12" + }, + { + "name": "protobuf", + "host": true, + "version>=": "3.21.12" + }, + "re2", + "safeint", + "utf8-range", + { + "name": "vcpkg-cmake", + "host": true + }, + { + "name": "vcpkg-cmake-config", + "host": true + }, + "wil", + { + "name": "zlib", + "platform": "windows" + } + ], + "features": { + "tests": { + "description": "Build ONNXRuntime unit tests", + "dependencies": [ + "benchmark", + "gtest" + ] + } + } +} diff --git a/onnxruntime/core/framework/kernel_type_str_resolver.cc b/onnxruntime/core/framework/kernel_type_str_resolver.cc index d05e02eb3ab32..3142f94f289b3 100644 --- a/onnxruntime/core/framework/kernel_type_str_resolver.cc +++ b/onnxruntime/core/framework/kernel_type_str_resolver.cc @@ -46,12 +46,8 @@ Status KernelTypeStrResolver::ResolveKernelTypeStr(const Node& node, std::string ORT_RETURN_IF(op_it == op_kernel_type_str_map_.end(), "Failed to find op_id: ", op_id); const auto& type_str_map = op_it->second; -#ifdef DISABLE_ABSEIL // TODO(edgchen1) maybe we can use transparent hash/eq to enable lookup with string_view const auto type_str_it = type_str_map.find(std::string(kernel_type_str)); -#else - const auto type_str_it = type_str_map.find(kernel_type_str); -#endif ORT_RETURN_IF(type_str_it == type_str_map.end(), "Failed to find args for kernel type string '", kernel_type_str, diff --git a/onnxruntime/core/framework/ort_value_name_idx_map.h b/onnxruntime/core/framework/ort_value_name_idx_map.h index 1b5f6bcee9bd0..76e7e369514d4 100644 --- a/onnxruntime/core/framework/ort_value_name_idx_map.h +++ b/onnxruntime/core/framework/ort_value_name_idx_map.h @@ -33,11 +33,7 @@ class OrtValueNameIdxMap { common::Status GetIdx(std::string_view name, int& idx) const { idx = -1; -#ifdef DISABLE_ABSEIL auto it = map_.find(std::string(name)); -#else - auto it = map_.find(name); -#endif if (it == map_.end()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Could not find OrtValue with name '", name, "'"); } diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 587d035541c45..902d15e8122b4 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -460,6 +460,12 @@ def convert_arg_line_to_args(self, arg_line): action="store_true", help="Disable memory leak checker from Windows build. By default it is enabled in Windows Debug build. This option is Windows only.", ) + # Dependency search with vcpkg + parser.add_argument( + "--use_vcpkg", + action="store_true", + help="Use vcpkg to search dependencies. Requires CMAKE_TOOLCHAIN_FILE for vcpkg.cmake", + ) # WebAssembly build parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly") @@ -999,6 +1005,7 @@ def generate_build_tree( # of them to get the best compatibility. "-DPython_EXECUTABLE=" + sys.executable, "-DPYTHON_EXECUTABLE=" + sys.executable, + "-Donnxruntime_USE_VCPKG=" + ("ON" if args.use_vcpkg else "OFF"), "-Donnxruntime_USE_MIMALLOC=" + ("ON" if args.use_mimalloc else "OFF"), "-Donnxruntime_ENABLE_PYTHON=" + ("ON" if args.enable_pybind else "OFF"), "-Donnxruntime_BUILD_CSHARP=" + ("ON" if args.build_csharp else "OFF"), From 20d94648bbb106c74c43ef4023e142dee0342155 Mon Sep 17 00:00:00 2001 From: Julius Tischbein Date: Wed, 11 Sep 2024 01:51:00 +0200 Subject: [PATCH 07/39] ConvTranpose using CUDNN Frontend with NHWC support (#21752) ### Description Added CUDNN Frontend and used it for NHWC ConvTranspose op including option for bias fusion. Similar to this [Conv PR](https://github.com/microsoft/onnxruntime/pull/19470) ### Backward compatible If ORT is built with cuDNN 8, cuDNN frontend will not be built into binary. Old kernels (using cudnn backend APIs) are used. ### Major Changes For cuDNN 9, we will enable cudnn frontend to fuse data gradient convolution and bias when a provider option fuse_conv_bias=1. ### Potential Issues cuDNN frontend uses TF32 by default. It can be disabled using use_tf32 cuda provider option, but in the case cuDNN frontend encounters issues building an operation graph it will fallback to using TF32. ### Follow ups This is one of the PRs that target to enable NHWC, here the ConvTranspose operation in CUDA EP by default if device supports it. There are other changes will follow up to make it possible. (1) Enable prefer_nhwc by default for device with sm >= 70. (2) Change fuse_conv_bias=1 by default after more testing. (3) Add other NHWC operators (like Resize or UpSample). ### Motivation and Context The new CUDNN Frontend library provides the functionality to fuse operations and provides new heuristics for kernel selection. Here it fuses the convolution data gradient operation (ConvTranspose) with the pointwise bias operation. ### Minor Change In the CUDA convolution operation was a small bug when `GetCudnnConv1dPadToNc1d ` was enabled. --- .../providers/cuda/cuda_execution_provider.cc | 3 +- onnxruntime/core/providers/cuda/nn/conv.cc | 2 +- .../core/providers/cuda/nn/conv_transpose.cc | 626 +++++++++++------- .../core/providers/cuda/nn/conv_transpose.h | 29 + .../core/providers/cuda/nn/conv_transpose_8.h | 266 ++++++++ 5 files changed, 702 insertions(+), 224 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/nn/conv_transpose_8.h diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index b54c572556220..82b29c7b0562e 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -2473,7 +2473,8 @@ static bool RNNNeedFallbackToCPU(const onnxruntime::Node& node, return false; } -static bool ConvTransposeNeedFallbackToCPU(const onnxruntime::Node& node, const logging::Logger& logger, +static bool ConvTransposeNeedFallbackToCPU([[maybe_unused]] const onnxruntime::Node& node, + [[maybe_unused]] const logging::Logger& logger, [[maybe_unused]] const GraphViewer& graph_viewer, [[maybe_unused]] const bool prefer_nhwc) { const auto& node_attributes = node.GetAttributes(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 95ba698b707ac..cc76198dc3ae9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -385,7 +385,7 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected if (cuda_ep->GetCudnnConv1dPadToNc1d()) { x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); - w_dims_cudnn.insert(w_dims.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); pads.insert(pads.begin() + kernel_rank, 0); pads.insert(pads.begin(), 0); kernel_shape.insert(kernel_shape.begin(), 1); diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index bac99d6a81ed2..d4876e1714861 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -7,6 +7,11 @@ #include "conv_transpose.h" #include "core/providers/cuda/tensor/transpose.h" +#if CUDNN_MAJOR < 9 +// if compiled with cuDNN 8 we want to use the legacy cuDNN API +#include "conv_transpose_8.h" +#endif + // To suppress FP static analyzer warnings: // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and // https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 @@ -38,48 +43,42 @@ REGISTER_KERNEL_TYPED(float, kMSInternalNHWCDomain, true) REGISTER_KERNEL_TYPED(MLFloat16, kMSInternalNHWCDomain, true) #endif -template -Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { - return DoConvTranspose(context, false); -} - +// First input (in this case X) is in case NHWC == true also in NHWC format, the other inputs in NCHW template Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, - PrePackedWeights* prepacked_weights) { + [[maybe_unused]] PrePackedWeights* prepacked_weights) { is_packed = false; // only layout of weight input is adjusted via PrePack - if constexpr (NHWC) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W + // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); - const auto rank = orig_shape.NumDimensions(); - - InlinedVector perm; - TensorShapeVector new_dims; - - // Input is { N, C, ...}. Output is { N, M, ...}. 'input channels' is C. 'output channels' is M. - // Transpose the output channels related dimension (M/group) to be last. Leave the input channels as-is. - if (rank == 3) { - // Transpose from {C, M/group, k1} to {C, k1, M/group} - perm = {0, 2, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[1]}; - } else if (rank == 4) { - // Transpose from {C, M/group, kH, kW} to {C, kH, kW, M/group} - perm = {0, 2, 3, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[1]}; - } else if (rank == 5) { - // Transpose from {C, M/group, k1, k2, k3} to {C, k1, k2, k3, M/group} - perm = {0, 2, 3, 4, 1}; - new_dims = TensorShapeVector{orig_shape[0], orig_shape[2], orig_shape[3], orig_shape[4], orig_shape[1]}; - } + auto shape_size = orig_shape.GetDims().size(); + + InlinedVector perm; + perm.push_back(0); + for (size_t i = 2; i < shape_size; i++) perm.push_back(i); + perm.push_back(1); + gsl::span permutation(perm.data(), shape_size); - gsl::span permutation(perm.data(), rank); - W_ = Tensor::Create(tensor.DataType(), TensorShape(new_dims), std::move(alloc)); + TensorShapeVector nhwc_dims; + for (size_t i = 0; i < shape_size; i++) { + nhwc_dims.push_back(orig_shape[perm[i]]); + } - ORT_RETURN_IF_ERROR(cuda::Transpose::DoTranspose(GetDeviceProp(), DefaultCudaStream(), DefaultCublasHandle(), - permutation, tensor, *W_)); + W_ = Tensor::Create(tensor.DataType(), TensorShape(nhwc_dims), std::move(alloc)); + auto status = cuda::Transpose::DoTranspose(GetDeviceProp(), + DefaultCudaStream(), + DefaultCublasHandle(), + permutation, tensor, *W_); + if (!status.IsOK()) { + return status; + } CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; + } else { + W_already_nhwc = true; } } else { ORT_UNUSED_PARAMETER(tensor); @@ -91,236 +90,419 @@ Status ConvTranspose::PrePack(const Tensor& tensor, int input_idx, Allo return Status::OK(); } -template -Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { - typedef typename ToCudaType::MappedType CudaT; +#if CUDNN_MAJOR >= 9 +#if !defined(__CUDACC__) + +template +Status ConvTranspose::CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const { + s_.bias_fused = fuse_bias; + s_.act_fused = fuse_act; + s_.variant_pack.clear(); // clear variant pack, as stored pointers to tensors change + s_.cudnn_fe_graph = std::make_unique(); + cudnn_frontend::DataType_t data_type = CudnnFeTensor::GetDataType(); + s_.cudnn_fe_graph->set_io_data_type(data_type).set_intermediate_data_type(data_type); + if (data_type == cudnn_frontend::DataType_t::HALF) { + s_.cudnn_fe_graph->set_compute_data_type(cudnn_frontend::DataType_t::FLOAT); + } else { + s_.cudnn_fe_graph->set_compute_data_type(data_type); + } - const Tensor* X = context->Input(0); - const TensorShape& x_shape = X->Shape(); - auto x_dims = x_shape.AsShapeVector(); - auto x_data = reinterpret_cast(X->Data()); - - auto x_dimensions = X->Shape().NumDimensions(); - if (x_dimensions < 3 || x_dimensions > 5) { - // TODO: the error message should tell which operator raises it. - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", - " X: ", X->Shape().ToString().c_str()); + s_.cudnn_fe_X = s_.cudnn_fe_graph->tensor(CudnnFeTensor(x_dims, "x", data_type, Layout == LAYOUT_NHWC).Get()); + s_.cudnn_fe_W = s_.cudnn_fe_graph->tensor(CudnnFeTensor(w_dims, "w", data_type, w_in_nhwc).Get()); + + auto conv_options = cudnn_frontend::graph::Conv_dgrad_attributes() + .set_pre_padding(std::vector(pads.begin(), + pads.begin() + pads.size() / 2)) + .set_post_padding(std::vector(pads.begin() + pads.size() / 2, pads.end())) + .set_stride(strides) + .set_dilation(dilations); + s_.cudnn_fe_conv_Y = s_.cudnn_fe_graph->conv_dgrad(s_.cudnn_fe_X, s_.cudnn_fe_W, conv_options); + auto cudnn_fe_y_tensor = CudnnFeTensor(y_dims, "y", data_type, Layout == LAYOUT_NHWC).Get(); + + if (B == nullptr) { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + } else { + int64_t bias_size; + if (B != nullptr) { + bias_size = B->Shape()[0]; + } else { + bias_size = w_dims[0]; + } + + if (fuse_bias) { + onnxruntime::TensorShapeVector b_dims; + for (size_t i = 0; i < x_dims.size(); i++) { + b_dims.push_back(i == 1 ? bias_size : 1); + } + auto bias_tensor = CudnnFeTensor(b_dims, "b", data_type, Layout == LAYOUT_NHWC).Get(); + auto bias_options = cudnn_frontend::graph::Pointwise_attributes().set_mode(cudnn_frontend::PointwiseMode_t::ADD); + s_.cudnn_fe_B = s_.cudnn_fe_graph->tensor(bias_tensor); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_conv_Y, s_.cudnn_fe_B, bias_options); + } else { + s_.cudnn_fe_Y = s_.cudnn_fe_conv_Y; + + TensorShapeVector b_dims(y_dims.size(), 1); + TensorShapeVector b_strides(y_dims.size(), 1); + b_dims[1] = bias_size; + b_strides[0] = bias_size; + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), b_strides)); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType(), cudnn_fe_y_tensor.get_stride())); + + /* Creating an own CUDNN Frontend graph for the bias addition. + s_.cudnn_fe_bias_graph = std::make_unique(); + s_.cudnn_fe_bias_graph->set_io_data_type(data_type) + .set_compute_data_type(data_type == cudnn_frontend::DataType_t::HALF ? + cudnn_frontend::DataType_t::FLOAT : data_type) + .set_intermediate_data_type(data_type); + s_.cudnn_fe_bias_X = s_.cudnn_fe_bias_graph->tensor(CudnnFeTensor(y_dims, "x", data_type).Get()); + + s_.cudnn_fe_B = s_.cudnn_fe_bias_graph->tensor(bias_tensor); + s_.cudnn_fe_bias_Y = s_.cudnn_fe_bias_graph->pointwise(s_.cudnn_fe_bias_X, s_.cudnn_fe_B, bias_options); + s_.cudnn_fe_bias_Y->set_output(true); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->validate()); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_operation_graph(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->create_execution_plans({heur_mode})); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->check_support(handle)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->build_plans(handle));*/ + } + } + if (fuse_act && s_.cudnn_fe_act_attr.has_value()) { + auto& activation_attr = s_.cudnn_fe_act_attr.value(); + s_.cudnn_fe_Y = s_.cudnn_fe_graph->pointwise(s_.cudnn_fe_Y, activation_attr); } - // use pre-packed W if available - const Tensor* W = W_ ? W_.get() : context->Input(1); + s_.cudnn_fe_Y->set_dim(cudnn_fe_y_tensor.get_dim()); + s_.cudnn_fe_Y->set_stride(cudnn_fe_y_tensor.get_stride()); + s_.cudnn_fe_Y->set_output(true); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->validate()); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_operation_graph(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->create_execution_plans({heur_mode})); + } catch (const std::exception& ex) { + std::string message = MakeString("Failed to initialize CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } - const TensorShape& w_shape = W->Shape(); - TensorShapeVector w_dims = w_shape.AsShapeVector(); - auto w_data = reinterpret_cast(W->Data()); + if (!use_tf32) s_.cudnn_fe_graph->deselect_numeric_notes({cudnn_frontend::NumericalNote_t::TENSOR_CORE}); + + try { + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->check_support(handle)); + CUDNN_FE_CALL_THROW(s_.cudnn_fe_graph->build_plans(handle)); + } catch (const std::exception& ex) { + if (!fuse_bias && !fuse_act && use_tf32) { + std::string message = MakeString("OP not supported by CUDNN Frontend", ex.what(), + "with the cudnn frontend json:\n", s_.cudnn_fe_graph->print()); + return Status(common::StatusCategory::ONNXRUNTIME, common::StatusCode::EP_FAIL, message); + } + + // Try fallback. + return CreateCudnnFeExecutionPlan(x_dims, w_dims, B, y_dims, handle, heur_mode, + pads, strides, dilations, false, false, w_in_nhwc, true); + } + + s_.workspace_bytes = s_.cudnn_fe_graph->get_workspace_size(); + return Status::OK(); +} + +#endif + +template +Status ConvTranspose::UpdateState(OpKernelContext* context, bool dynamic_padding) const { + constexpr bool channels_last = Layout == LAYOUT_NHWC; size_t num_inputs = OpKernel::Node().InputDefs().size(); bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; - CudaT* y_data = nullptr; + // set X + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + // X incl. x_dims is in NHWC Format iff. NHWC == true + const auto x_dims = x_shape.AsShapeVector(); + + s_.x_data = reinterpret_cast(X->Data()); + s_.element_size = X->DataType()->Size(); + + // set W + bool w_in_nhwc; + const Tensor* W; + if (!W_) { + W = context->Input(1); + w_in_nhwc = false; + // Dims and memory layout are in NCHW format + } else { + W = W_.get(); + w_in_nhwc = channels_last; + // W got prepacked, therefore if NHWC == true, then dims and memory layout are in NHWC + } + const TensorShape& w_shape = W->Shape(); + onnxruntime::TensorShapeVector w_dims = w_shape.AsShapeVector(); + s_.w_data = reinterpret_cast(W->Data()); + + // set B + // Always in NCHW format + const Tensor* B = nullptr; + if (has_bias) { + B = context->Input(dynamic_padding ? 3 : 2); + s_.b_data = reinterpret_cast(B->Data()); + } else { + s_.b_data = nullptr; + } - const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + const Tensor* Pads = dynamic_padding ? context->Input(2) : nullptr; - // convert 1D to 2D - if (x_dimensions == 3) { - // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use - // GetCudnnConv1dPadToNc1d to determine which is added. - // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension - const auto insert_at = NHWC ? 1 : 2; + bool input_dims_changed = (s_.last_x_dims != x_dims); + bool w_dims_changed = (s_.last_w_dims != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) + s_.last_x_dims = gsl::make_span(x_dims); - // NCHW: N, C, d1 -> N, C, 1, d1 - // NHWC: N, d1, C -> N, 1, d1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + } - // 'M' is channels dim in CUDA implementation - // NCHW: C, M/g, k1 -> C, M/g, 1, k1 - // NHWC: C, k1, M/g -> C, 1, k1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); - } else { - // add fake W dimension - const auto insert_at = NHWC ? 2 : 3; + // The following code is from ConvTransposeAttributes::PrepareForCompute - // NCHW: N, C, d1 -> N, C, d1, 1 - // NHWC: N, d1, C -> N, d1, 1, C - x_dims.insert(x_dims.begin() + insert_at, 1); + const int rank = static_cast(X->Shape().NumDimensions()); + TensorShape input_shape = X->Shape().Slice(channels_last ? 1 : 2, channels_last ? rank - 1 : rank); + const int64_t num_input_channels = channels_last ? X->Shape()[rank - 1] : X->Shape()[1]; + const int64_t N = X->Shape()[0]; + const int64_t num_output_channels_multiplier = w_in_nhwc ? w_shape[rank - 1] : w_shape[1]; + const int64_t num_output_channels = num_output_channels_multiplier * conv_transpose_attrs_.group; - // NCHW: C, M/g, k1 -> C, M/g, k1, 1 - // NHWC: C, k1, M/g -> C, k1, 1, M/g - w_dims.insert(w_dims.begin() + insert_at, 1); + if (conv_transpose_attrs_.group <= 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group count is <= 0", + " group: ", conv_transpose_attrs_.group); } - } - { - std::lock_guard lock(s_.mutex); - // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); - // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size - bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); - bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); - if (input_dims_changed || w_dims_changed) { - if (input_dims_changed) { - s_.last_x_dims = gsl::make_span(x_dims); - } + if (X->Shape().NumDimensions() != w_shape.NumDimensions()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "X num_dims does not match W num_dims.", + " X: ", X->Shape().ToString().c_str(), + " W: ", w_shape.ToString().c_str()); + } - if (w_dims_changed) { - s_.last_w_dims = gsl::make_span(w_dims); - s_.cached_benchmark_results.clear(); - } + if (w_shape[0] != num_input_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "filter number not equal to input channel number.", + " filter_number: ", w_shape[0], + " num_input_channels: ", num_input_channels); + } - ConvTransposeAttributes::Prepare p; - // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' - const bool transposed_input_channels = false; - ORT_RETURN_IF_ERROR( - conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, &w_shape, NHWC, transposed_input_channels)); - - auto y_dims = p.Y->Shape().AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // add fake H dimension of 1 - // NCHW: N, M, d1 -> N, M, 1, d1 or - // NHWC: N, d1, M -> N, 1, d1, M - y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); - p.kernel_shape.insert(p.kernel_shape.begin(), 1); - p.pads.insert(p.pads.begin(), 0); - p.pads.insert(p.pads.begin() + 2, 0); - p.strides.insert(p.strides.begin(), 1); - p.dilations.insert(p.dilations.begin(), 1); - } else { - // add fake W dimension of 1 - // NCHW: N, M, d1 -> N, M, d1, 1 or - // NHWC: N, d1, M -> N, d1, 1, M - y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); - p.kernel_shape.push_back(1); - p.pads.insert(p.pads.begin() + 1, 0); - p.pads.push_back(0); - p.strides.push_back(1); - p.dilations.push_back(1); - } - } + // it looks like num_output_channels is really k*group similar to how in the conv case + // num_input_channels is k*group. hence removing the check for num_output_channels here. - s_.y_dims = gsl::make_span(y_dims); + if (num_input_channels % conv_transpose_attrs_.group != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input channels is not divisible by group.", + " num_input_channels: ", num_input_channels, + " group: ", conv_transpose_attrs_.group); + } - if (w_dims_changed) { - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(w_dims[0]), static_cast(w_dims[3]), - static_cast(w_dims[1]), static_cast(w_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); - } - } + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(w_shape, kernel_shape, w_in_nhwc)); - // Special case when there is a dim value of 0 in the shape. - // Return only after we have cached the following for subsequent runs : - // 1) `w_dims` in the `w_desc` - // 2) `y_dims` in s_.y_dims - if (p.Y->Shape().Size() == 0) { - return Status::OK(); - } + const size_t kernel_rank = kernel_shape.size(); - if constexpr (NHWC) { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(x_dims[0]), static_cast(x_dims[3]), - static_cast(x_dims[1]), static_cast(x_dims[2]))); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), - static_cast(y_dims[0]), static_cast(y_dims[3]), - static_cast(y_dims[1]), static_cast(y_dims[2]))); - } else { - ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + TensorShapeVector local_output_padding(conv_transpose_attrs_.output_padding); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape.size(), 0); + } + ConvPadVector pads; + pads.reserve(2 * (input_shape.NumDimensions())); + if (dynamic_padding) { + for (int64_t i = 0; i < Pads->Shape().SizeFromDimension(0); ++i) { + pads.push_back(Pads->Data()[i]); } + } else { + pads.assign(conv_transpose_attrs_.pads.begin(), conv_transpose_attrs_.pads.end()); + } + if (pads.empty()) { + pads.resize(kernel_shape.size() * 2, 0); + } + TensorShapeVector dilations(conv_transpose_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(kernel_shape.size(), 1); + } + TensorShapeVector strides(conv_transpose_attrs_.strides); + if (strides.empty()) { + strides.resize(kernel_shape.size(), 1); + } - cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; - ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, - gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType(), - UseTF32())); - - if (has_bias) { - const auto& b_shape = p.B->Shape(); - ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); - TensorShapeVector b_dims(2 + p.kernel_shape.size()); - b_dims[0] = 1; // N - b_dims[NHWC ? 3 : 1] = b_shape[0]; // C - for (size_t i = 0; i < p.kernel_shape.size(); i++) { - b_dims[(NHWC ? 1 : 2) + i] = 1; - } - - ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); - } + TensorShapeVector y_dims; - y_data = reinterpret_cast(p.Y->MutableData()); - - if (!s_.cached_benchmark_results.contains(x_dims)) { - IAllocatorUniquePtr algo_search_workspace = - GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); - - // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); - } else if constexpr (std::is_same::value) { - if (!UseTF32()) { - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); - } - } - - cudnnConvolutionBwdDataAlgoPerf_t perf; - int algo_count = 1; - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, - &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); - s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); - } + conv_transpose_attrs_.ComputePadsAndOutputShape(input_shape, num_output_channels, kernel_shape, + strides, dilations, local_output_padding, N, &pads, &y_dims, channels_last); - const auto& perf = s_.cached_benchmark_results.at(x_dims); - CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); - s_.algo = perf.algo; - s_.workspace_bytes = perf.memory; - } + s_.y_dims = gsl::make_span(y_dims); + s_.Y = context->Output(0, s_.y_dims); - // The following block will be executed in case there has been no change in the shapes of the - // input and the filter compared to the previous run - if (!y_data) { - auto y_dims = s_.y_dims.AsShapeVector(); - if (x_dimensions == 3) { - if (cuda_ep->GetCudnnConv1dPadToNc1d()) { - // erase the fake H dimension - y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); - } else { - // erase the fake W dimension - y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); - } - } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); - Tensor* Y = context->Output(0, TensorShape(y_dims)); - y_data = reinterpret_cast(Y->MutableData()); + TensorShapeVector x_dims_cudnn{x_dims.begin(), x_dims.end()}; + TensorShapeVector y_dims_cudnn{y_dims.begin(), y_dims.end()}; + TensorShapeVector w_dims_cudnn{w_dims.begin(), w_dims.end()}; - // Bail out early if one of the output dimensions is zero. - if (Y->Shape().Size() == 0) { - return Status::OK(); + if constexpr (channels_last) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, *(x_dims_cudnn.end() - 1)); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, *(y_dims_cudnn.end() - 1)); + x_dims_cudnn.erase(x_dims_cudnn.end() - 1); + y_dims_cudnn.erase(y_dims_cudnn.end() - 1); + + if (w_in_nhwc) { + w_dims_cudnn.insert(w_dims_cudnn.begin() + 1, *(w_dims_cudnn.end() - 1)); + w_dims_cudnn.erase(w_dims_cudnn.end() - 1); } } - const auto alpha = Consts::One; - const auto beta = Consts::Zero; + if (kernel_rank < 2) { + // TODO: Explore padding the provided input shape [N, C, D] to [N, C, 1, D] + // especially for EXHAUSTIVE algo search which may result in a better algo selection. + // ORTModule uses different algo search options (HEURISTIC, and use max workspace size) compared to + // inference build (EXHAUSTIVE, 32M workspace size). We observed better perf when we pad input shape + // [N,C,D] to [N,C,1,D], expecially on A100, and especially for ConvGrad. + // PyTorch also pads to [N,C,1,D]. For inference build, we still pad it to [N, C, D, 1] as this seems + // to be the sweet spot for all algo search options: EXHAUSTIVE, HEURISTIC, and DEFAULT. + // See PR #7348 and #7702 for more context. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1); + y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1); + w_dims_cudnn.insert(w_dims_cudnn.begin() + 2, 1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims_cudnn.push_back(1); + y_dims_cudnn.push_back(1); + w_dims_cudnn.push_back(1); + pads.insert(pads.begin() + kernel_rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + // We must delay returning early until here so that the weight dims have been cached properly + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } - IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + auto handle = GetCudnnHandle(context); + + int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo(); +#if !defined(__CUDACC__) + cudnn_frontend::HeurMode_t heur_mode; + switch (cudnn_conv_algo) { + case 0: + heur_mode = cudnn_frontend::HeurMode_t::B; + break; + case 1: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + case 2: + heur_mode = cudnn_frontend::HeurMode_t::FALLBACK; + break; + default: + heur_mode = cudnn_frontend::HeurMode_t::A; + break; + } - CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, s_.x_tensor, - x_data, s_.conv_desc, s_.algo, workspace.get(), - s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + auto use_tf32 = cuda_ep->UseTF32(); + const auto fuse_bias = cuda_ep->IsFuseConvBias() || is_fused_node_; + const auto fuse_act = is_fused_node_; + + ORT_RETURN_IF_ERROR(CreateCudnnFeExecutionPlan(x_dims_cudnn, w_dims_cudnn, B, y_dims_cudnn, handle, heur_mode, + std::vector(pads.begin(), + pads.end()), + std::vector(strides.begin(), + strides.end()), + std::vector(dilations.begin(), + dilations.end()), + fuse_bias, fuse_act, w_in_nhwc, use_tf32)); +#endif + } else { + // set Y + s_.Y = context->Output(0, s_.y_dims); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + s_.y_data = reinterpret_cast(s_.Y->MutableData()); + } + return Status::OK(); +} - if (has_bias) { - const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); - auto b_data = reinterpret_cast(B->Data()); - CUDNN_RETURN_IF_ERROR( - cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + std::lock_guard lock(s_.mutex); + ORT_RETURN_IF_ERROR(UpdateState(context, dynamic_padding)); + if (s_.Y->Shape().Size() == 0) { + return Status::OK(); + } + const auto alpha = onnxruntime::cuda::Consts::One; + auto cudnn_handle = GetCudnnHandle(context); +#if !defined(__CUDACC__) + s_.variant_pack.insert_or_assign(s_.cudnn_fe_X, const_cast(s_.x_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_W, const_cast(s_.w_data)); + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Y, s_.y_data); + if (s_.bias_fused && s_.b_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + } + if (s_.bias_fused && s_.z_data != nullptr) { + s_.variant_pack.insert_or_assign(s_.cudnn_fe_Z, const_cast(s_.z_data)); + if (Layout == LAYOUT_NCHW && s_.z_data == s_.y_data) { + // memset Z if it's required for a succesful fusion + CUDA_RETURN_IF_ERROR(cudaMemset(s_.y_data, 0, s_.Y->SizeInBytes())); } } + auto ws = GetWorkSpace(context->GetComputeStream()); + + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_graph->execute(cudnn_handle, + s_.variant_pack, + ws.get())); + + if (!s_.bias_fused && s_.z_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.z_tensor, s_.z_data, + &alpha, s_.y_tensor, s_.y_data)); + } + if (!s_.bias_fused && s_.b_data != nullptr) { + CUDNN_RETURN_IF_ERROR(cudnnAddTensor(cudnn_handle, &alpha, s_.b_tensor, s_.b_data, + &alpha, s_.y_tensor, s_.y_data)); + + /* For the standalone bias addition graph. + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_X, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_bias_Y, s_.y_data); + s_.variant_pack_bias.insert_or_assign(s_.cudnn_fe_B, const_cast(s_.b_data)); + CUDNN_FE_RETURN_IF_ERROR(s_.cudnn_fe_bias_graph->execute(cudnn_handle, + s_.variant_pack_bias, + GetWorkSpace(context->GetComputeStream()).get()));*/ + } +#endif return Status::OK(); } +#endif + +template +Status ConvTranspose::ComputeInternal(OpKernelContext* context) const { + return DoConvTranspose(context, false); +} } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.h b/onnxruntime/core/providers/cuda/nn/conv_transpose.h index 71ad3ee6e2147..1a6957164d22f 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.h +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.h @@ -18,6 +18,8 @@ namespace cuda { template class ConvTranspose : public CudaKernel { public: + using CudaT = typename ToCudaType::MappedType; + ConvTranspose(const OpKernelInfo& info) : CudaKernel(info), conv_transpose_attrs_(info) {}; Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; @@ -29,6 +31,33 @@ class ConvTranspose : public CudaKernel { mutable CudnnConvState s_; std::unique_ptr W_; + + bool is_nhwc_domain_; // prepack is only needed for the Conv in kMSInternalNHWCDomain + bool is_fused_node_ = false; // ensures the node is fused although the session option is not set + bool W_already_nhwc = false; // In case NHWC == true and Conv is not in kMSInternalNHWCDomain + + protected: + inline IAllocatorUniquePtr GetWorkSpace(onnxruntime::Stream* stream) const { + return GetScratchBuffer(s_.workspace_bytes, stream); + } + + Status UpdateState(OpKernelContext* context, bool bias_expected) const; + +#if !defined(__CUDACC__) && CUDNN_MAJOR >= 9 + Status CreateCudnnFeExecutionPlan(const onnxruntime::TensorShapeVector& x_dims, + const onnxruntime::TensorShapeVector& w_dims, + const Tensor* B, + const TensorShapeVector& y_dims, + cudnnContext* handle, + const cudnn_frontend::HeurMode_t heur_mode, + const std::vector& pads, + const std::vector& strides, + const std::vector& dilations, + const bool fuse_bias, + const bool fuse_act, + const bool w_in_nhwc, + const bool use_tf32) const; +#endif }; } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h new file mode 100644 index 0000000000000..b46d41b887e41 --- /dev/null +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose_8.h @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) 2023 NVIDIA Corporation. +// Licensed under the MIT License. + +#include + +#include "conv_transpose.h" +#include + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" +#include "core/providers/cpu/nn/conv_transpose_attributes.h" + +#include "core/providers/cuda/tensor/transpose.h" + +// To suppress FP static analyzer warnings: +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944928 and +// https://msdata.visualstudio.com/Vienna/_workitems/edit/1944950 +#ifdef _WIN32 +#pragma warning(push) +#pragma warning(disable : 26110) +#pragma warning(disable : 26117) +#endif + +namespace onnxruntime { +namespace cuda { + +template +Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dynamic_padding) const { + const Tensor* X = context->Input(0); + const TensorShape& x_shape = X->Shape(); + auto x_dims = x_shape.AsShapeVector(); + auto x_data = reinterpret_cast(X->Data()); + + auto x_dimensions = X->Shape().NumDimensions(); + if (x_dimensions < 3 || x_dimensions > 5) { + // TODO: the error message should tell which operator raises it. + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input X must be 3-, 4- or 5-dimensional.", + " X: ", X->Shape().ToString().c_str()); + } + + // use pre-packed W if available + const Tensor* W = W_ ? W_.get() : context->Input(1); + + const TensorShape& w_shape = W->Shape(); + TensorShapeVector w_dims = w_shape.AsShapeVector(); + auto w_data = reinterpret_cast(W->Data()); + + size_t num_inputs = OpKernel::Node().InputDefs().size(); + bool has_bias = dynamic_padding ? num_inputs == 4 : num_inputs == 3; + + CudaT* y_data = nullptr; + + const auto* cuda_ep = static_cast(Info().GetExecutionProvider()); + + // convert 1D to 2D + if (x_dimensions == 3) { + // we can either add a fake H or W dimension with a value of 1. to be consistent with the Conv behavior we use + // GetCudnnConv1dPadToNc1d to determine which is added. + // see Conv::UpdateState in /onnxruntime/core/providers/cuda/nn/conv.cc for more details. + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension + const auto insert_at = NHWC ? 1 : 2; + + // NCHW: N, C, d1 -> N, C, 1, d1 + // NHWC: N, d1, C -> N, 1, d1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // 'M' is channels dim in CUDA implementation + // NCHW: C, M/g, k1 -> C, M/g, 1, k1 + // NHWC: C, k1, M/g -> C, 1, k1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } else { + // add fake W dimension + const auto insert_at = NHWC ? 2 : 3; + + // NCHW: N, C, d1 -> N, C, d1, 1 + // NHWC: N, d1, C -> N, d1, 1, C + x_dims.insert(x_dims.begin() + insert_at, 1); + + // NCHW: C, M/g, k1 -> C, M/g, k1, 1 + // NHWC: C, k1, M/g -> C, k1, 1, M/g + w_dims.insert(w_dims.begin() + insert_at, 1); + } + } + + { + std::lock_guard lock(s_.mutex); + // CUDNN_CONFIG_RETURN_IF_ERROR(cudnnSetStream(CudnnHandle(), Stream(context))); + // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with + // different batch_size + bool input_dims_changed = (s_.last_x_dims.AsShapeVector() != x_dims); + bool w_dims_changed = (s_.last_w_dims.AsShapeVector() != w_dims); + if (input_dims_changed || w_dims_changed) { + if (input_dims_changed) { + s_.last_x_dims = gsl::make_span(x_dims); + } + + if (w_dims_changed) { + s_.last_w_dims = gsl::make_span(w_dims); + s_.cached_benchmark_results.clear(); + } + + ConvTransposeAttributes::Prepare p; + // PrePack moves the M/group dimension of W to the end, with 'M' being interpreted as 'output channels' + const bool transposed_input_channels = false; + ORT_RETURN_IF_ERROR( + conv_transpose_attrs_.PrepareForCompute(context, has_bias, p, dynamic_padding, + &w_shape, NHWC, transposed_input_channels)); + + auto y_dims = p.Y->Shape().AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // add fake H dimension of 1 + // NCHW: N, M, d1 -> N, M, 1, d1 or + // NHWC: N, d1, M -> N, 1, d1, M + y_dims.insert(y_dims.begin() + (NHWC ? 1 : 2), 1); + p.kernel_shape.insert(p.kernel_shape.begin(), 1); + p.pads.insert(p.pads.begin(), 0); + p.pads.insert(p.pads.begin() + 2, 0); + p.strides.insert(p.strides.begin(), 1); + p.dilations.insert(p.dilations.begin(), 1); + } else { + // add fake W dimension of 1 + // NCHW: N, M, d1 -> N, M, d1, 1 or + // NHWC: N, d1, M -> N, d1, 1, M + y_dims.insert(y_dims.begin() + (NHWC ? 2 : 3), 1); + p.kernel_shape.push_back(1); + p.pads.insert(p.pads.begin() + 1, 0); + p.pads.push_back(0); + p.strides.push_back(1); + p.dilations.push_back(1); + } + } + + s_.y_dims = gsl::make_span(y_dims); + + if (w_dims_changed) { + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(w_dims[0]), static_cast(w_dims[3]), + static_cast(w_dims[1]), static_cast(w_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType())); + } + } + + // Special case when there is a dim value of 0 in the shape. + // Return only after we have cached the following for subsequent runs : + // 1) `w_dims` in the `w_desc` + // 2) `y_dims` in s_.y_dims + if (p.Y->Shape().Size() == 0) { + return Status::OK(); + } + + if constexpr (NHWC) { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(x_dims[0]), static_cast(x_dims[3]), + static_cast(x_dims[1]), static_cast(x_dims[2]))); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(CUDNN_TENSOR_NHWC, CudnnTensor::GetDataType(), + static_cast(y_dims[0]), static_cast(y_dims[3]), + static_cast(y_dims[1]), static_cast(y_dims[2]))); + } else { + ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims, CudnnTensor::GetDataType())); + ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims, CudnnTensor::GetDataType())); + } + + cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; + ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, + gsl::narrow_cast(conv_transpose_attrs_.group), mode, + CudnnTensor::GetDataType(), + UseTF32())); + + if (has_bias) { + const auto& b_shape = p.B->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + p.kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[NHWC ? 3 : 1] = b_shape[0]; // C + for (size_t i = 0; i < p.kernel_shape.size(); i++) { + b_dims[(NHWC ? 1 : 2) + i] = 1; + } + + ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType(), NHWC)); + } + + y_data = reinterpret_cast(p.Y->MutableData()); + + if (!s_.cached_benchmark_results.contains(x_dims)) { + IAllocatorUniquePtr algo_search_workspace = + GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); + + // set math type to tensor core before algorithm search + if constexpr (std::is_same::value) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } + + cudnnConvolutionBwdDataAlgoPerf_t perf; + int algo_count = 1; + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + GetCudnnHandle(context), s_.w_desc, w_data, s_.x_tensor, x_data, s_.conv_desc, s_.y_tensor, y_data, 1, + &algo_count, &perf, algo_search_workspace.get(), AlgoSearchWorkspaceSize)); + s_.cached_benchmark_results.insert(x_dims, {perf.algo, perf.memory, perf.mathType}); + } + + const auto& perf = s_.cached_benchmark_results.at(x_dims); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType)); + s_.algo = perf.algo; + s_.workspace_bytes = perf.memory; + } + + // The following block will be executed in case there has been no change in the shapes of the + // input and the filter compared to the previous run + if (!y_data) { + auto y_dims = s_.y_dims.AsShapeVector(); + if (x_dimensions == 3) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + // erase the fake H dimension + y_dims.erase(y_dims.begin() + (NHWC ? 1 : 2)); + } else { + // erase the fake W dimension + y_dims.erase(y_dims.begin() + (NHWC ? 2 : 3)); + } + } + + Tensor* Y = context->Output(0, TensorShape(y_dims)); + y_data = reinterpret_cast(Y->MutableData()); + + // Bail out early if one of the output dimensions is zero. + if (Y->Shape().Size() == 0) { + return Status::OK(); + } + } + + const auto alpha = Consts::One; + const auto beta = Consts::Zero; + + IAllocatorUniquePtr workspace = GetScratchBuffer(s_.workspace_bytes, context->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardData(GetCudnnHandle(context), &alpha, s_.w_desc, w_data, + s_.x_tensor, x_data, s_.conv_desc, s_.algo, workspace.get(), + s_.workspace_bytes, &beta, s_.y_tensor, y_data)); + + if (has_bias) { + const Tensor* B = dynamic_padding ? context->Input(3) : context->Input(2); + auto b_data = reinterpret_cast(B->Data()); + CUDNN_RETURN_IF_ERROR( + cudnnAddTensor(GetCudnnHandle(context), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor, y_data)); + } + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace onnxruntime + +#ifdef _WIN32 +#pragma warning(pop) +#endif From e91ff9438bd368300b7d6d95aabfffe8cb9547b6 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Wed, 11 Sep 2024 09:54:15 -0700 Subject: [PATCH 08/39] Enable Pad->Conv(no pads) fusion (#22001) ### Description ### Motivation and Context For some model has pattern Pad -> Conv. If the Conv doesn't have pads attributes, the Pad can be fused into Conv. --- onnxruntime/core/optimizer/pad_fusion.cc | 12 ++--- .../test/optimizer/graph_transform_test.cc | 48 ++++++++++++++++++ .../transform/fusion/fuse-pad-nopadsconv.onnx | Bin 0 -> 397 bytes 3 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx diff --git a/onnxruntime/core/optimizer/pad_fusion.cc b/onnxruntime/core/optimizer/pad_fusion.cc index 3391e20cf0bb7..25afed52403c4 100644 --- a/onnxruntime/core/optimizer/pad_fusion.cc +++ b/onnxruntime/core/optimizer/pad_fusion.cc @@ -31,15 +31,15 @@ bool VerifyNotCastChild(const Node& child_node) { return false; } - // This pass currently assumed that this attribute already exists on the child node - if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) { - return false; - } - return true; } void UpdatePaddingAttribute(Node& child_node, const std::vector& pads_values, const uint32_t pads_size) { + if (child_node.GetAttributes().find("pads") == child_node.GetAttributes().end()) { + std::vector pads(pads_size - 4, 0); + child_node.AddAttribute("pads", pads); + } + auto child_pads = child_node.GetMutableAttributes()["pads"].mutable_ints(); uint32_t child_pads_size = static_cast(child_pads->size()); @@ -162,4 +162,4 @@ Status PadFusion::Apply(Graph& graph, Node& pad_node, RewriteRuleEffect& rule_ef rule_effect = RewriteRuleEffect::kRemovedCurrentNode; return Status::OK(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6ae66e35e7853..3aec0d5a67e94 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1469,6 +1469,54 @@ TEST_F(GraphTransformationTests, FusePadWithConv) { } } +TEST_F(GraphTransformationTests, FusePadWithNoPadsConv) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-nopadsconv.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::vector expected_pads; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "Pad") { + const auto* pads_proto = graph_utils::GetConstantInitializer(graph, node.InputDefs()[1]->Name()); + Initializer pads{*pads_proto, graph.ModelPath()}; + gsl::span pads_values = pads.DataAsSpan(); + expected_pads.resize(pads_values.size() - 4); + + for (uint32_t pads_index = 2, index = 0; pads_index < pads_values.size() / 2; pads_index++, index++) { + expected_pads[index] = pads_values[pads_index]; + expected_pads[index + (expected_pads.size() / 2)] = pads_values[pads_index + (pads_values.size() / 2)]; + } + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Pad"], 0); + ASSERT_EQ(op_to_count["Conv"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Conv") { + auto child_pads = node.GetMutableAttributes()["pads"].mutable_ints(); + ASSERT_EQ(child_pads->size(), static_cast(expected_pads.size())) + << "fusion should produce the same size of pads integer as the Conv node"; + for (uint32_t index = 0; index < expected_pads.size(); index++) { + ASSERT_EQ(expected_pads[index], child_pads->Get(index)) + << "fusion does not produce correct padding value"; + } + } + } +} + TEST_F(GraphTransformationTests, FusePadWithMaxPool) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-pad-maxpool.onnx"; diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-pad-nopadsconv.onnx new file mode 100644 index 0000000000000000000000000000000000000000..145847cdc47fadd7a2feaaefadf5eeef5c1be270 GIT binary patch literal 397 zcmduN`%;;npvEIflD0$fZSj8M!9q*Q9XgFj2WnmIZ1&F>OCw5u>)0b LI Date: Wed, 11 Sep 2024 19:41:04 +0200 Subject: [PATCH 09/39] Improve hash_function used by TreeEnsemble (#22043) ### Description unordered_map are implemented in a different way on VisualStudio and gcc. It seems that inserting consecutive keys has a poor performance on Windows. ### Motivation and Context Improve the performance of onnxruntime when initializing trees. --- onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index b9f3050e59c5b..34c6db61982b5 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -23,7 +23,9 @@ struct TreeNodeElementId { } struct hash_fn { std::size_t operator()(const TreeNodeElementId& key) const { - return static_cast(static_cast(key.tree_id) << 32 | static_cast(key.node_id)); + // unordered_map has poor performance on Windows when inserting consecutive keys. + // keys are usually inserted with key.node_id being incremented at each iteration. + return static_cast(static_cast(key.tree_id) | static_cast(key.node_id) << 32); } }; }; From 4d824045444756ba70223c32ae11693a252adde6 Mon Sep 17 00:00:00 2001 From: Bin Miao Date: Thu, 12 Sep 2024 05:16:36 +0800 Subject: [PATCH 10/39] [WebNN EP] Support GRU operator (#20405) This PR support Gru operator for WebNN EP. @Honry , @fdwr thanks! --- js/web/docs/webnn-operators.md | 1 + js/web/test/suite-test-list.jsonc | 6 +- .../core/providers/shared/utils/utils.cc | 9 + .../core/providers/shared/utils/utils.h | 1 + .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/gru_op_builder.cc | 250 ++++++++++++++++++ .../webnn/builders/op_builder_factory.cc | 4 + .../webnn/builders/op_builder_factory.h | 1 + 8 files changed, 270 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 48b06b780dfc7..164096b4fda9a 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -41,6 +41,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | GlobalLpPool| ai.onnx(7+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 'p' value is 2 | | Greater | ai.onnx(7-8, 9-12, 13+) | greater | ✓ | ✓ | | | GreaterOrEqual | ai.onnx(12-15, 16+) | greaterOrEqual | ✓ | ✓ | | +| GRU | ai.onnx(7-13, 14-21, 22+) | gru | ✓ | ✓ | Only supports 'layout' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' | | HardSigmoid | ai.onnx(7+) | hardSigmoid | ✓ | ✓ | | | HardSwish | ai.onnx(14+) | hardSwish | ✓ | ✓ | | | Identity | ai.onnx(7-13, 14-15, 16-18, 19-20, 21+) | identity | ✓ | ✓ | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 7f0c1cc3e420c..5c1e2e27a6eff 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1812,9 +1812,9 @@ // // "test_gridsample_zeros_padding", // // "test_gridsample", // // "test_gru_batchwise", - // // "test_gru_defaults", - // // "test_gru_seq_length", - // // "test_gru_with_initial_bias", + "test_gru_defaults", + "test_gru_seq_length", + "test_gru_with_initial_bias", // // "test_hammingwindow_expanded", // // "test_hammingwindow_symmetric_expanded", // // "test_hammingwindow_symmetric", diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 2088618538de5..5b2f2c1fa1b2e 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -192,6 +192,15 @@ std::vector NodeAttrHelper::Get(const std::string& key, const std::vect return def_val; } +std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.strings(); + return std::vector{values.cbegin(), values.cend()}; + } + + return def_val; +} + std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { const auto& values = entry->second.floats(); diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 5813dcc48d72b..ddbae42534711 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -57,6 +57,7 @@ class NodeAttrHelper { std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 4d723a3c59ee2..b51092619db22 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -183,6 +183,7 @@ static const InlinedHashMap op_map = { {"GlobalLpPool", "l2Pool2d"}, {"Greater", "greater"}, {"GreaterOrEqual", "greaterOrEqual"}, + {"Gru", "gru"}, {"HardSigmoid", "hardSigmoid"}, {"HardSwish", "hardSwish"}, {"Identity", "identity"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc new file mode 100644 index 0000000000000..23cc7f1b11459 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -0,0 +1,250 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class GruOpBuilder : public BaseOpBuilder { + // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const override; +}; + +void GruOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.InputDefs().size() > 4 && node.InputDefs()[4]->Exists()) { + model_builder.AddInitializerToSkip(node.InputDefs()[4]->Name()); // sequence_lens + model_builder.AddInputToSkip(node.InputDefs()[4]->Name()); + } +} + +Status GruOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const { + NodeAttrHelper helper(node); + uint32_t hidden_size = helper.Get("hidden_size", 1); + + const auto& input_defs = node.InputDefs(); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape"); + uint32_t steps = static_cast(input_shape[0]); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val weight = model_builder.GetOperand(input_defs[1]->Name()); + emscripten::val recurrent_weight = model_builder.GetOperand(input_defs[2]->Name()); + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + options.set("layout", emscripten::val("zrn")); + + if (input_defs.size() > 3 && input_defs[3]->Exists()) { + emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); + emscripten::val split_options = emscripten::val::object(); + split_options.set("label", node.Name() + "_split"); + split_options.set("axis", 1); + // Split it to bias and recurrentBias. + emscripten::val splitted_biases = + model_builder.GetBuilder().call("split", bias, /*splits*/ 2, split_options); + options.set("bias", splitted_biases[0]); + options.set("recurrentBias", splitted_biases[1]); + } + + if (input_defs.size() > 5 && input_defs[5]->Exists()) { + options.set("initialHiddenState", model_builder.GetOperand(input_defs[5]->Name())); + } + + bool linear_before_reset = !!helper.Get("linear_before_reset ", 0); + options.set("resetAfter", linear_before_reset); + + const auto& output_defs = node.OutputDefs(); + bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); + bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); + options.set("returnSequence", has_Y); + + std::string direction = helper.Get("direction", "forward"); + if (direction == "forward") { + options.set("direction", emscripten::val("forward")); + } else if (direction == "reverse") { + options.set("direction", emscripten::val("backward")); + } else if (direction == "bidirectional") { + options.set("direction", emscripten::val("both")); + } + + if (helper.HasAttr("activations")) { + const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh"}); + emscripten::val recurrent_network_activations = emscripten::val::array(); + for (size_t i = 0; i < 2; ++i) { + const std::string& activation = activations[i]; + if (activation == "Relu") { + recurrent_network_activations.call("push", emscripten::val("relu")); + } else if (activation == "Sigmoid") { + recurrent_network_activations.call("push", emscripten::val("sigmoid")); + } else if (activation == "Tanh") { + recurrent_network_activations.call("push", emscripten::val("tanh")); + } + } + + options.set("activations", recurrent_network_activations); + } + + emscripten::val outputs = model_builder.GetBuilder().call("gru", input, weight, recurrent_weight, + steps, hidden_size, options); + + if (has_Y) { + model_builder.AddOperand(output_defs[0]->Name(), outputs[1]); + } + if (has_Y_h) { + model_builder.AddOperand(output_defs[1]->Name(), outputs[0]); + } + + return Status::OK(); +} + +bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 3) { + LOGS(logger, ERROR) << "GRU: input size must greater than or equal to 3"; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger) || input_shape.empty()) { + LOGS(logger, ERROR) << "Cannot get input's shape"; + return false; + } + int32_t steps = static_cast(input_shape[0]); + + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + if (!Contains(initializers, input_defs[4]->Name())) { + LOGS(logger, ERROR) << "GRU: sequence_lens must be constant"; + return false; + } + + const auto& sequence_lens_tensor = *initializers.at(input_defs[4]->Name()); + std::vector sequence_lens; + if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) { + LOGS(logger, ERROR) << "Cannot read sequence lens tensor"; + return false; + } + if (!std::all_of(sequence_lens.begin(), sequence_lens.end(), + [steps](int32_t lens) -> bool { return steps == lens; })) { + LOGS(logger, ERROR) << "GRU: every sequence length must be equal to input shape[0]"; + return false; + } + } + + NodeAttrHelper helper(node); + if (helper.HasAttr("activations")) { + const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh"}); + + if (activations.size() >= 4) { + if (activations[0] != activations[2] || activations[1] != activations[3]) { + LOGS(logger, ERROR) << "GRU: forward and reverse directions must have the same activations"; + return false; + } + } + + const InlinedHashSet supported_activations = {"Relu", "Tanh", "Sigmoid"}; + if (!std::all_of(activations.begin(), activations.end(), + [&supported_activations](const std::string& activation) -> bool { + return supported_activations.contains(activation); + })) { + LOGS(logger, ERROR) << "GRU: activations must be one of Relu, Tanh, Sigmoid"; + return false; + } + } + + if (helper.Get("clip", std::numeric_limits::max()) != std::numeric_limits::max()) { + LOGS(logger, ERROR) << "GRU: clip is not supported"; + return false; + } + + if (helper.Get("layout", 0) != 0) { + LOGS(logger, ERROR) << "GRU: batchwise (layout == 1) is not supported"; + return false; + } + + return true; +} + +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type = 0; // input data type + int32_t input1_type = 0; // weight data type + int32_t input2_type = 0; // recurrentWeight data type + int32_t input3_type = 0; // bias data type + int32_t input4_type = 0; // recurrentBias data type + int32_t input5_type = 0; // initialHiddenState data type + bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists(); + bool has_input4 = input_defs.size() > 4 && input_defs[4]->Exists(); + bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + !GetType(*input_defs[2], input2_type, logger) || + (has_input3 && !GetType(*input_defs[3], input3_type, logger)) || + (has_input4 && !GetType(*input_defs[4], input4_type, logger)) || + (has_input5 && !GetType(*input_defs[5], input5_type, logger))) { + return false; + } + + std::unordered_set supported_data_types; + if (device_type == WebnnDeviceType::CPU) { + // WebNN CPU backend only support float32 input data type. + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + }; + } else if (device_type == WebnnDeviceType::GPU) { + supported_data_types = { + ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, + }; + } + + if (!IsSupportedDataType(input0_type, supported_data_types)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input type: [" << input0_type + << "] is not supported for now"; + return false; + } + + if (input0_type != input1_type || + input0_type != input2_type || + (has_input3 && input0_type != input3_type) || + (has_input4 && input0_type != input4_type) || + (has_input5 && input0_type != input5_type)) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same."; + return false; + } + + return true; +} + +void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.builders.push_back(std::make_unique()); + op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 862cf5ded15bc..01761290f07e3 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -108,6 +108,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateGemmOpBuilder("MatMulInteger", op_registrations); } + { // GRU + CreateGruOpBuilder("GRU", op_registrations); + } + { // Logical CreateLogicalOpBuilder("Equal", op_registrations); CreateLogicalOpBuilder("Greater", op_registrations); diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index e11938d8fa406..b66218cc9a902 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -33,6 +33,7 @@ void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); From b80032862800f516b9810ac60a4a30f0a565b8e4 Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 11 Sep 2024 14:52:18 -0700 Subject: [PATCH 11/39] [ROCm EP/ MIGraphx EP] matmul_nbits: Use GPU_WARP_SIZE_HOST for host side code (#22045) ### Description For ROCm device, the host side code needs to call GPU_WARP_SIZE_HOST to query warpSize of the underlying GPU device. ### Motivation and Context Fixes MatMulNBits tests on gfx1100/01 which has warpSize of 32. Signed-off-by: Jagadish Krishnamoorthy --- onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu index af9e87eaf225d..ce6c07fbed2bc 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu @@ -289,7 +289,7 @@ bool TryMatMul4Bits( return false; } dim3 blocks((n + kColsPerThreadBlock - 1) / kColsPerThreadBlock, m); - dim3 threads(kWarpSize, kColsPerThreadBlock); + dim3 threads(GPU_WARP_SIZE_HOST, kColsPerThreadBlock); int blocks_per_K = (k + block_size - 1) / block_size; int shared_mem_size = sizeof(T) * blocks_per_K * kColsPerThreadBlock + (zero_points != nullptr ? (blocks_per_K + 1) / 2 * kColsPerThreadBlock * 2 : 0); From 0309c5f02fa8ffc11c3ab74c582e88c3997969e0 Mon Sep 17 00:00:00 2001 From: sfatimar Date: Thu, 12 Sep 2024 03:25:40 +0530 Subject: [PATCH 12/39] Ovep release lnl 1.2.1 (#22027) Error Codes are added to catch compilation error and signal recompile. Remote Tensors are added to ensure direct memory access for NPU inferencing. UMD Bypass cache enabled with 2024.4 will eliminate need to disk caching ### Motivation and Context The changes are needed to ensure backward compatibility UMD Bypass caching eliminates driver caching Remote Tensors lead to performance improvement with inferencing on NPU --------- Co-authored-by: Preetha Veeramalai Co-authored-by: Srirammaswamy Co-authored-by: saurabh Co-authored-by: Javier E. Martinez Co-authored-by: Eric Crawford Co-authored-by: jatinwadhwa921 --- cmake/onnxruntime_providers_openvino.cmake | 4 + .../onnxruntime/core/framework/allocator.h | 2 + onnxruntime/core/framework/allocator.cc | 4 + .../providers/openvino/backend_manager.cc | 38 ++++- .../openvino/backends/basic_backend.cc | 161 +++++++++++++++--- .../openvino/backends/basic_backend.h | 9 + .../openvino/openvino_execution_provider.cc | 17 ++ .../openvino/openvino_execution_provider.h | 4 +- .../core/providers/openvino/ov_allocator.cc | 55 ++++++ .../core/providers/openvino/ov_allocator.h | 24 +++ .../core/providers/openvino/ov_interface.h | 1 + onnxruntime/test/perftest/ort_test_session.cc | 59 ++++++- onnxruntime/test/perftest/ort_test_session.h | 3 + 13 files changed, 338 insertions(+), 43 deletions(-) create mode 100644 onnxruntime/core/providers/openvino/ov_allocator.cc create mode 100644 onnxruntime/core/providers/openvino/ov_allocator.h diff --git a/cmake/onnxruntime_providers_openvino.cmake b/cmake/onnxruntime_providers_openvino.cmake index e559583fae8f5..2eb3611bae902 100644 --- a/cmake/onnxruntime_providers_openvino.cmake +++ b/cmake/onnxruntime_providers_openvino.cmake @@ -21,6 +21,10 @@ message(FATAL_ERROR "OpenVINO 2024.0 and newer are supported. Please, use latest OpenVINO release") endif() + if(OpenVINO_VERSION VERSION_GREATER_EQUAL 2024.4) + add_definitions(-DUSE_OVEP_NPU_MEMORY=1) + endif() + if (WIN32) unset(CMAKE_MAP_IMPORTED_CONFIG_RELWITHDEBINFO) endif() diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 097873c5e3653..abab118efd04f 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -50,6 +50,8 @@ constexpr const char* HIP = "Hip"; constexpr const char* HIP_PINNED = "HipPinned"; constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; +constexpr const char* OpenVINO_RT = "OpenVINO_RT"; +constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; constexpr size_t kAllocAlignment = 256; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index c3e96e450c59b..5e66f2b99fded 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -145,6 +145,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, + mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index be41b125e4440..4fca4037301fb 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -107,12 +108,15 @@ BackendManager::BackendManager(const GlobalContext& global_context, subgraph_context_, ep_ctx_handle_); } catch (const OnnxRuntimeException& ex) { + std::string exception_str = ex.what(); + bool eligible_for_cpu_fallback = device_type.find("NPU") != std::string::npos && + !GetGlobalContext().disable_cpu_fallback && + !ep_ctx_handle_.IsValidOVEPCtxGraph(); #if defined(OPENVINO_DISABLE_NPU_FALLBACK) - ORT_THROW(ex.what()); + eligible_for_cpu_fallback = false; #else - if (device_type.find("NPU") != std::string::npos && - !GetGlobalContext().disable_cpu_fallback) { - LOGS_DEFAULT(WARNING) << ex.what(); + if (eligible_for_cpu_fallback) { + LOGS_DEFAULT(VERBOSE) << exception_str; LOGS_DEFAULT(WARNING) << "Model compilation failed at OV NPU." << "Falling back to OV CPU for execution"; GetGlobalContext().device_type = "CPU"; @@ -125,10 +129,32 @@ BackendManager::BackendManager(const GlobalContext& global_context, } catch (std::string const& msg) { ORT_THROW(msg); } - } else { - ORT_THROW(ex.what()); } #endif + if (!eligible_for_cpu_fallback) { + if (device_type.find("NPU") != std::string::npos && + exception_str.find("intel_npu") != std::string::npos) { + // Handle NPU device related errors +#ifndef NDEBUG + ORT_THROW(exception_str + "\nModel needs to be recompiled\n"); +#else + std::string error_message = "UNKNOWN NPU ERROR"; + std::string error_code = "code 0x0"; + std::regex error_message_pattern(R"(\bZE_\w*\b)"); + std::regex error_code_pattern("code 0x[0-9a-fA-F]+"); + std::smatch matches; + if (std::regex_search(exception_str, matches, error_message_pattern)) { + error_message = matches[0]; + } + if (std::regex_search(exception_str, matches, error_code_pattern)) { + error_code = matches[0]; + } + throw std::runtime_error(error_message + ", " + error_code + "\nModel needs to be recompiled\n"); +#endif + } else { + ORT_THROW(exception_str); + } + } } } if (global_context_.export_ep_ctx_blob && !ep_ctx_handle_.IsValidOVEPCtxGraph()) { diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 8d340e2daf4b5..1f9c61780f27a 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -48,14 +48,6 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr // Set the inference_num_threads property of the CPU SetNumThreads(device_config); -#ifndef NDEBUG - if (IsDebugEnabled()) { - std::string file_name = subgraph_context.subgraph_name + "_static.onnx"; - std::fstream outfile(file_name, std::ios::out | std::ios::trunc | std::ios::binary); - model_proto->SerializeToOstream(outfile); - } -#endif - try { std::string dev_prec = global_context.device_type + "_" + global_context_.precision_str; @@ -180,6 +172,11 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { device_property = std::make_pair("NPU_COMPILER_TYPE", env_npu_compiler_type); } device_config.emplace(ov::device::properties("NPU", device_property)); +#if (OPENVINO_VERSION_MAJOR >= 2024) && (OPENVINO_VERSION_MINOR > 3) + if (global_context_.export_ep_ctx_blob) { + global_context_.ie_core.Get().set_property("NPU", ov::intel_npu::bypass_umd_caching(true)); + } +#endif } } @@ -295,16 +292,104 @@ void BasicBackend::StartAsyncInference(Ort::KernelContext& context, OVInferReque ORT_THROW(msg); } } else { - OVTensorPtr graph_input_blob; - try { - graph_input_blob = infer_request->GetTensor(input_name); - } catch (const char* msg) { - ORT_THROW(msg); + if ((global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + OVTensorPtr graph_input_blob; + try { + graph_input_blob = infer_request->GetTensor(input_name); + } catch (const char* msg) { + ORT_THROW(msg); + } + FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); + } else { + auto tensor = context.GetInput(subgraph_context_.input_names.at(input_name)); + auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); + ov_tensor_data_t ov_tensor_key; + ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_key = it->second; + } else { + // Does this make sense for both types of allocators? + auto input = graph_input_info.at(input_idx); + if (allocator_name == OpenVINO_RT_NPU) { + ov_tensor_key.copy_needed = false; + ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape(), + (void*)tensor.GetTensorRawData()); + } else { + ov_tensor_key.copy_needed = true; + ov_tensor_key.tensor_ptr = std::make_shared(input.get_element_type(), input.get_shape()); + } + ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_key); + + if (ov_tensor_key.copy_needed) { + const char* ort_tensor_data = tensor.GetTensorData(); + size_t tensor_data_size = ov_tensor_key.tensor_ptr->get_byte_size(); + auto ort_batch_memory_offset = ort_tensor_data + tensor_data_size * batch_slice_idx; + std::memcpy(ov_tensor_key.tensor_ptr->data(), ort_batch_memory_offset, tensor_data_size); + } + + try { + infer_request->SetTensor(input_name, ov_tensor_key.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } } - FillInputBlob(std::move(graph_input_blob), batch_slice_idx, std::move(input_name), context, subgraph_context_); } input_idx++; } + if (global_context_.device_type.find("NPU") != std::string::npos) { + // Set the output blob as remote blob + auto graph_output_info = exe_network_.Get().outputs(); + auto output_idx = 0; + for (auto output_info_iter = graph_output_info.begin(); + output_info_iter != graph_output_info.end(); ++output_info_iter) { + auto output_names = output_info_iter->get_names(); + std::string onnx_output_name; + std::string output_name; + // using the output name retrieved from ONNX original to match with the output names returned by OV tensors + for (auto it = subgraph_context_.output_names.begin(); it != subgraph_context_.output_names.end(); ++it) { + onnx_output_name = it->first; + if (output_names.find(onnx_output_name) != output_names.end()) { + // Assigning the output_name + output_name = it->first; + break; + } + } + size_t batch_size = 1; + Ort::UnownedValue tensor = GetOutputTensor(context, + batch_size, + infer_request, + output_name, + subgraph_context_.output_names); + auto allocator_name = tensor.GetTensorMemoryInfo().GetAllocatorName(); + + ov_tensor_data_t ov_tensor_data; + ort_tensor_key_t ort_tensor_key{tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_data = it->second; + } else { + auto output = graph_output_info.at(output_idx); + if (allocator_name == OpenVINO_RT_NPU) { + ov_tensor_data.copy_needed = false; + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape(), + (void*)tensor.GetTensorRawData()); + } else { + ov_tensor_data.copy_needed = true; + ov_tensor_data.tensor_ptr = std::make_shared(output.get_element_type(), output.get_shape()); + } + ort_ov_tensor_map.emplace(ort_tensor_key, ov_tensor_data); + + try { + infer_request->SetTensor(output_name, ov_tensor_data.tensor_ptr); + } catch (const char* msg) { + ORT_THROW(msg); + } + } + output_idx++; + } + } + // Start Async inference infer_request->StartAsync(); } catch (const char* msg) { @@ -454,20 +539,42 @@ void BasicBackend::CompleteAsyncInference(Ort::KernelContext& context, OVInferRe " doesn't exist in the " "list of OpenVINO output tensor names"); } - try { - graph_output_blob = infer_request->GetTensor(output_name); - } catch (const char* msg) { - ORT_THROW(msg); - } - size_t batch_size = 1; - Ort::UnownedValue output_tensor = - GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); - auto mem_info = output_tensor.GetTensorMemoryInfo(); - if (mem_info.GetAllocatorName() == OpenVINO_GPU) { - return; + if ((global_context_.device_type.find("CPU") != std::string::npos || + global_context_.device_type.find("GPU") != std::string::npos)) { + try { + graph_output_blob = infer_request->GetTensor(output_name); + } catch (const char* msg) { + ORT_THROW(msg); + } + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto mem_info = output_tensor.GetTensorMemoryInfo(); + if (mem_info.GetAllocatorName() == OpenVINO_GPU) { + return; + } else { + size_t batch_slice = 0; + FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + } } else { - size_t batch_slice = 0; - FillOutputBlob(std::move(graph_output_blob), output_tensor, batch_slice); + size_t batch_size = 1; + Ort::UnownedValue output_tensor = + GetOutputTensor(context, batch_size, infer_request, std::move(output_name), subgraph_context_.output_names); + auto allocator_name = output_tensor.GetTensorMemoryInfo().GetAllocatorName(); + ov_tensor_data_t ov_tensor_data; + ort_tensor_key_t ort_tensor_key{output_tensor.GetTensorRawData(), allocator_name}; + if (const auto& it = ort_ov_tensor_map.find(ort_tensor_key); it != ort_ov_tensor_map.end()) { + ov_tensor_data = it->second; + } else { + ORT_THROW(log_tag + "Expected all outputs to have associated OV::Tensor's"); + } + + if (ov_tensor_data.copy_needed) { + auto ort_tensor_data = output_tensor.GetTensorMutableData(); + size_t tensor_data_size = ov_tensor_data.tensor_ptr->get_byte_size(); + auto ort_batch_memory_offset = ort_tensor_data /*+ tensor_data_size * batch_size*/; + std::memcpy(ort_batch_memory_offset, ov_tensor_data.tensor_ptr->data(), tensor_data_size); + } } } diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index cd242a06b27d4..cd69e88f994b9 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -11,6 +11,7 @@ #include #include #include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/providers/openvino/contexts.h" @@ -20,6 +21,11 @@ namespace onnxruntime { namespace openvino_ep { +struct ov_tensor_data_t { + OVTensorPtr tensor_ptr; + bool copy_needed; +}; + class InferRequestsQueue; class BasicBackend : public IBackend { public: @@ -60,6 +66,9 @@ class BasicBackend : public IBackend { #if defined IO_BUFFER_ENABLED OVRemoteContextPtr remote_context_; #endif + + using ort_tensor_key_t = std::pair; + std::map ort_ov_tensor_map; }; class InferRequestsQueue { diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 29c45916795d3..08144651319cf 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -10,6 +10,9 @@ #include "core/providers/openvino/onnx_ctx_model_helper.h" #include "core/providers/openvino/ov_versions/capability.h" #include "openvino/core/version.hpp" +#ifdef USE_OVEP_NPU_MEMORY +#include "core/providers/openvino/ov_allocator.h" +#endif #define MEMCPY_S(dest, src, destsz, srcsz) memcpy(dest, src, std::min(destsz, srcsz)) @@ -180,4 +183,18 @@ common::Status OpenVINOExecutionProvider::Compile( return Status::OK(); } +#ifdef USE_OVEP_NPU_MEMORY +std::vector OpenVINOExecutionProvider::CreatePreferredAllocators() { + AllocatorCreationInfo npu_allocator_info{ + [this](OrtDevice::DeviceId device_id) { + return std::make_unique(global_context_->ie_core.Get(), OrtDevice::NPU, device_id, OpenVINO_RT_NPU); + }, + 0, + }; + + // fill in allocator + return std::vector{CreateAllocator(npu_allocator_info)}; +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.h b/onnxruntime/core/providers/openvino/openvino_execution_provider.h index 030e5bba71b67..8b1c62c607f6e 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.h +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.h @@ -189,7 +189,9 @@ class OpenVINOExecutionProvider : public IExecutionProvider { const void* GetExecutionHandle() const noexcept override { return nullptr; } - +#ifdef USE_OVEP_NPU_MEMORY + std::vector CreatePreferredAllocators() override; +#endif private: std::unique_ptr global_context_; openvino_ep::EPCtxHandler ep_ctx_handle_{}; diff --git a/onnxruntime/core/providers/openvino/ov_allocator.cc b/onnxruntime/core/providers/openvino/ov_allocator.cc new file mode 100644 index 0000000000000..6700244b754d8 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_allocator.cc @@ -0,0 +1,55 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#ifdef USE_OVEP_NPU_MEMORY +#include "core/providers/openvino/ov_allocator.h" +#include "core/providers/openvino/ov_interface.h" +#include "openvino/runtime/intel_npu/level_zero/level_zero.hpp" +#include "openvino/runtime/intel_npu/properties.hpp" + +namespace onnxruntime { + +using namespace openvino_ep; + +constexpr size_t default_alignment = 4096; + +static inline size_t align_up(size_t size, size_t pow2_alignment) { + return (size + pow2_alignment - 1) & ~(pow2_alignment - 1); +} + +OVRTAllocator::OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name) : IAllocator(OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(device_type, OrtDevice::MemType::DEFAULT, device_id), device_id, OrtMemTypeCPUInput)), core_(core) { + if (device_type == OrtDevice::NPU) { + remote_ctx_ = core_.get_default_context("NPU").as(); + } else { + ORT_THROW("Invalid device type"); + } +} + +void* OVRTAllocator::Alloc(size_t size) { + try { + size_t alloc_size = align_up(size + sizeof(ov::Tensor*) + default_alignment, default_alignment); + ov::Tensor* tensor = new ov::Tensor(remote_ctx_.create_host_tensor(ov::element::Type_t::u8, + {alloc_size})); + uintptr_t data_ptr = reinterpret_cast(tensor->data()); + + ov::Tensor** ptr = reinterpret_cast(align_up(data_ptr + sizeof(ov::Tensor*), default_alignment)); + ptr[-1] = tensor; + + return reinterpret_cast(ptr); + + } catch (const ov::Exception& e) { + ORT_THROW(std::string("Alloc failed: ") + e.what()); + } + return nullptr; +} + +void OVRTAllocator::Free(void* p) { + try { + ov::Tensor** ptr = reinterpret_cast(p); + delete ptr[-1]; + } catch (const ov::Exception& e) { + ORT_THROW(std::string("Free failed: ") + e.what()); + } +} + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/openvino/ov_allocator.h b/onnxruntime/core/providers/openvino/ov_allocator.h new file mode 100644 index 0000000000000..083cfc4d5aed3 --- /dev/null +++ b/onnxruntime/core/providers/openvino/ov_allocator.h @@ -0,0 +1,24 @@ +// Copyright (C) Intel Corporation +// Licensed under the MIT License +#ifdef USE_OVEP_NPU_MEMORY +#pragma once + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include "openvino/runtime/remote_context.hpp" + +namespace onnxruntime { + +class OVRTAllocator : public IAllocator { + public: + OVRTAllocator(ov::Core& core, OrtDevice::DeviceType device_type, OrtDevice::DeviceId device_id, const char* name); + void* Alloc(size_t size) override; + void Free(void* p) override; + + private: + ov::Core& core_; + ov::RemoteContext remote_ctx_; +}; + +} // namespace onnxruntime +#endif diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index fa22e0f3cb03d..f4da4ea3e3244 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -10,6 +10,7 @@ #include #include "openvino/openvino.hpp" +#include "openvino/runtime/intel_npu/properties.hpp" #include "openvino/pass/convert_fp32_to_fp16.hpp" #include "openvino/frontend/manager.hpp" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 837aeb3c37acd..ae7680571ced1 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -34,10 +34,18 @@ std::chrono::duration OnnxRuntimeTestSession::Run() { // Randomly pick one OrtValueArray from test_inputs_. (NOT ThreadSafe) const std::uniform_int_distribution::param_type p(0, static_cast(test_inputs_.size() - 1)); const size_t id = static_cast(dist_(rand_engine_, p)); + auto& input = test_inputs_.at(id); auto start = std::chrono::high_resolution_clock::now(); - auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), - output_names_raw_ptr.data(), output_names_raw_ptr.size()); + + if (!use_device_mem) { + auto output_values = session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), output_names_raw_ptr.size()); + } else { + session_.Run(Ort::RunOptions{nullptr}, input_names_.data(), input.data(), input_names_.size(), + output_names_raw_ptr.data(), outputs_.data(), output_names_raw_ptr.size()); + } + auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration duration_seconds = end - start; return duration_seconds; @@ -815,6 +823,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [OpenVINO] The value for the key 'export_ep_ctx_blob' " "should be a boolean i.e. true or false. Default value is false.\n"); } + } else if (key == "use_device_mem") { + if (value == "true" || value == "True") { + use_device_mem = true; + } } else { ORT_THROW("[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO. ['device_type', 'device_id', 'enable_npu_fast_compile', 'num_of_threads', 'cache_dir', 'num_streams', 'enable_opencl_throttling', 'disable_dynamic_shapes'] \n"); } @@ -858,6 +870,27 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); input_names_str_[i] = m.GetInputName(i); input_names_[i] = input_names_str_[i].c_str(); } + + if (use_device_mem) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo("OpenVINO_RT_NPU", OrtArenaAllocator, 0, OrtMemTypeCPUOutput); + custom_allocator_ = std::make_unique(session_, memory_info); + for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { + Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + + std::vector output_shape = tensor_info.GetShape(); + + // free dimensions are treated as 1 if not overridden + for (int64_t& dim : output_shape) { + if (dim == -1) { + dim = 1; + } + } + + outputs_.push_back(Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)output_shape.data(), + output_shape.size(), tensor_info.GetElementType())); + } + } } template @@ -944,9 +977,11 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { // iterate over all input nodes for (size_t i = 0; i < static_cast(input_length_); i++) { Ort::TypeInfo type_info = session_.GetInputTypeInfo(i); - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); if (type_info.GetONNXType() == ONNX_TYPE_TENSOR) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + if (!use_device_mem) { + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + } std::vector input_node_dim = tensor_info.GetShape(); // free dimensions are treated as 1 if not overridden @@ -955,12 +990,18 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed) { dim = 1; } } - - auto allocator = Ort::AllocatorWithDefaultOptions(); - Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), - input_node_dim.size(), tensor_info.GetElementType()); - InitializeTensorWithSeed(seed, input_tensor); - PreLoadTestData(0, i, std::move(input_tensor)); + if (use_device_mem) { + Ort::Value input_tensor = Ort::Value::CreateTensor(*custom_allocator_, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } else { + auto allocator = Ort::AllocatorWithDefaultOptions(); + Ort::Value input_tensor = Ort::Value::CreateTensor(allocator, (const int64_t*)input_node_dim.data(), + input_node_dim.size(), tensor_info.GetElementType()); + InitializeTensorWithSeed(seed, input_tensor); + PreLoadTestData(0, i, std::move(input_tensor)); + } } } return true; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index f1a4220ab325e..e33041a2a0958 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -38,6 +38,8 @@ class OnnxRuntimeTestSession : public TestSession { std::mt19937 rand_engine_; std::uniform_int_distribution dist_; std::vector> test_inputs_; + std::unique_ptr custom_allocator_; + std::vector outputs_; std::vector output_names_; // The same size with output_names_. // TODO: implement a customized allocator, then we can remove output_names_ to simplify this code @@ -46,6 +48,7 @@ class OnnxRuntimeTestSession : public TestSession { std::vector input_names_str_; const int input_length_; std::string provider_name_; + bool use_device_mem = false; }; } // namespace perftest From d8e64bb529c1d0f18efd47710d179205c96ffbca Mon Sep 17 00:00:00 2001 From: Lennart Hannink Date: Thu, 12 Sep 2024 01:05:37 +0200 Subject: [PATCH 13/39] Refactor CoreMLExecution to C++ bridge class (#21857) Refactor Objective-C++ class `CoreMLExecution` into existing C++ bridge class `onnxruntime::coreml::Execution`. --- .../core/providers/coreml/model/model.mm | 398 ++++++++---------- 1 file changed, 171 insertions(+), 227 deletions(-) diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 4d20061820e71..68460ff7c9b31 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -30,8 +30,8 @@ // to manually do this asm(".linker_option \"-framework\", \"CoreML\""); -using namespace onnxruntime; -using namespace onnxruntime::coreml; +namespace onnxruntime { +namespace coreml { namespace { /** @@ -247,213 +247,6 @@ Status CopyMLMultiArrayBuffer(const void* mlmultiarray_buffer, void* tensor_buff } } // namespace -NS_ASSUME_NONNULL_BEGIN - -// Execution for a CoreML model, it performs -// 1. Compile the model by given path for execution -// 2. Predict using given OnnxTensorFeatureProvider input and copy the output data back ORT -// 3. The compiled model will be removed in dealloc or removed using cleanup function -@interface CoreMLExecution : NSObject { - NSString* coreml_model_path_; - NSString* compiled_model_path_; - const logging::Logger* logger_; - uint32_t coreml_flags_; -} - -- (instancetype)initWithPath:(const std::string&)path - logger:(const logging::Logger&)logger - coreml_flags:(uint32_t)coreml_flags; -- (void)cleanup; -- (void)dealloc; -- (Status)loadModel API_AVAILABLE_COREML3; -- (Status)predict:(const std::unordered_map&)inputs - outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_COREML3; - -@property(nullable) MLModel* model API_AVAILABLE_COREML3; - -@end - -@implementation CoreMLExecution - -- (instancetype)initWithPath:(const std::string&)path - logger:(const logging::Logger&)logger - coreml_flags:(uint32_t)coreml_flags { - if (self = [super init]) { - coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); - logger_ = &logger; - coreml_flags_ = coreml_flags; - } - return self; -} - -- (void)cleanup { - NSError* error = nil; - if (compiled_model_path_ != nil) { - [[NSFileManager defaultManager] removeItemAtPath:compiled_model_path_ error:&error]; - if (error != nil) { - LOGS(*logger_, ERROR) << "Failed cleaning up the compiled model: " << [compiled_model_path_ UTF8String] - << ", error message: " << [[error localizedDescription] UTF8String]; - } - compiled_model_path_ = nil; - } - -#if !defined(NDEBUG) - std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); - if (!path_override.empty()) { - // don't cleanup - coreml_model_path_ = nil; - } -#endif - - if (coreml_model_path_ != nil) { - error = nil; - [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; - if (error != nil) { - LOGS(*logger_, ERROR) << "Failed cleaning up the coreml model: " << [coreml_model_path_ UTF8String] - << ", error message: " << [[error localizedDescription] UTF8String]; - } - coreml_model_path_ = nil; - } -} - -- (void)dealloc { - [self cleanup]; -} - -- (Status)loadModel { - NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; - if (modelUrl == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); - } - - // TODO: Update this to version with callback handler as the API used here is deprecated. - // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl - // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the - // background. We will have to check for completion in `predict` and block until it is done. - NSError* error = nil; - NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; - - if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", - [[error localizedDescription] UTF8String]); - } - - compiled_model_path_ = [compileUrl path]; - - MLModelConfiguration* config = [MLModelConfiguration alloc]; - config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) - ? MLComputeUnitsCPUOnly - : MLComputeUnitsAll; - _model = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; - - if (error != nil || _model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", - (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); - } - - return Status::OK(); -} - -- (Status)predict:(const std::unordered_map&)inputs - outputs:(const std::unordered_map&)outputs - getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&)get_output_tensor_mutable_raw_data_fn { - Status status = Status::OK(); - ORT_TRY { - if (_model == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); - } - - id input_features; - InlinedVector> conversion_buffers; - ORT_RETURN_IF_ERROR(CreateInputFeatureProvider(inputs, *logger_, &input_features, conversion_buffers)); - - MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; - NSError* error = nil; - id output_features = [_model predictionFromFeatures:input_features - options:options - error:&error]; - - if (error != nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", - [[error localizedDescription] UTF8String]); - } - - for (const auto& [output_name, output_tensor_info] : outputs) { - MLFeatureValue* output_value = - [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())]; - - if (output_value == nil) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); - } - - MLMultiArray* data = [output_value multiArrayValue]; - - const auto coreml_static_output_shape = [data]() { - InlinedVector result; - result.reserve(data.shape.count); - for (NSNumber* dim in data.shape) { - const auto dim_value = dim.longLongValue; - result.push_back(dim_value); - } - return result; - }(); - - const auto static_output_shape = GetStaticOutputShape(output_tensor_info.shape, coreml_static_output_shape, - *logger_); - - void* output_buffer = get_output_tensor_mutable_raw_data_fn(output_name, output_tensor_info.data_type, - static_output_shape); - - if (const size_t num_elements = data.count; num_elements > 0) { - if (const auto shape_size = ShapeSize(static_output_shape); - shape_size < 0 || num_elements != static_cast(shape_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "CoreML MLMultiArray count (", num_elements, ") and shape size (", shape_size, - ") do not match"); - } - - // support a non-contiguous array, provided only one dimension is not contiguous - int64_t num_blocks = 0; - int64_t block_size = 0; - int64_t stride = 0; - - ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); - - __block Status copy_status; - const auto* tensor_info = &output_tensor_info; - // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions - if (@available(macOS 12.3, iOS 15.4, *)) { - [data getBytesWithHandler:^(const void* bytes, NSInteger size) { - copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, - num_blocks, block_size, stride, tensor_info); - }]; - } else { - copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, - num_blocks, block_size, stride, tensor_info); - } - - ORT_RETURN_IF_ERROR(copy_status); - } - } - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception: ", e.what()); - }); - } - - return status; -} - -@end - -NS_ASSUME_NONNULL_END - -namespace onnxruntime { -namespace coreml { - Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, int64_t& num_blocks, int64_t& block_size, int64_t& stride) { const auto* shape = array.shape; @@ -498,11 +291,14 @@ Status GetMLMultiArrayCopyInfo(const MLMultiArray* _Nonnull array, } // Internal Execution class -// This class will bridge Model (c++) with CoreMLExecution (objective c++) +// This class is part of the model class and handles the calls into CoreML. Specifically, it performs +// 1. Compile the model by given path for execution +// 2. Predict using given OnnxTensorFeatureProvider input and copy the output data back ORT +// 3. The compiled model will be removed in dealloc or removed using cleanup function class Execution { public: Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - ~Execution() {}; + ~Execution(); Status LoadModel(); Status Predict(const std::unordered_map& inputs, @@ -510,30 +306,97 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); private: - bool model_loaded{false}; - CoreMLExecution* execution_; + void cleanup(); + NSString* coreml_model_path_{nil}; + NSString* compiled_model_path_{nil}; + const logging::Logger& logger_; + uint32_t coreml_flags_{0}; + MLModel* model_{nil}; }; -Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) { +Execution::Execution(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) + : logger_(logger), + coreml_flags_(coreml_flags) { @autoreleasepool { - execution_ = [[CoreMLExecution alloc] initWithPath:path - logger:logger - coreml_flags:coreml_flags]; + coreml_model_path_ = util::Utf8StringToNSString(path.c_str()); + } +} + +Execution::~Execution() { + @autoreleasepool { + cleanup(); + } +} + +void Execution::cleanup() { + NSError* error = nil; + if (compiled_model_path_ != nil) { + [[NSFileManager defaultManager] removeItemAtPath:compiled_model_path_ error:&error]; + if (error != nil) { + LOGS(logger_, ERROR) << "Failed cleaning up the compiled model: " << [compiled_model_path_ UTF8String] + << ", error message: " << [[error localizedDescription] UTF8String]; + } + compiled_model_path_ = nil; + } + +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + // don't cleanup + coreml_model_path_ = nil; + } +#endif + + if (coreml_model_path_ != nil) { + error = nil; + [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; + if (error != nil) { + LOGS(logger_, ERROR) << "Failed cleaning up the coreml model: " << [coreml_model_path_ UTF8String] + << ", error message: " << [[error localizedDescription] UTF8String]; + } + coreml_model_path_ = nil; } } Status Execution::LoadModel() { - if (model_loaded) { + if (model_ != nil) { return Status::OK(); } if (HAS_COREML3_OR_LATER) { - Status status{}; @autoreleasepool { - status = [execution_ loadModel]; + NSError* error = nil; + + NSURL* modelUrl = [NSURL URLWithString:coreml_model_path_]; + if (modelUrl == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); + } + + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. + NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; + if (error != nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error compiling model: ", + [[error localizedDescription] UTF8String]); + } + + compiled_model_path_ = [compileUrl path]; + + MLModelConfiguration* config = [MLModelConfiguration alloc]; + config.computeUnits = (coreml_flags_ & COREML_FLAG_USE_CPU_ONLY) + ? MLComputeUnitsCPUOnly + : MLComputeUnitsAll; + model_ = [MLModel modelWithContentsOfURL:compileUrl configuration:config error:&error]; + + if (error != nil || model_ == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create MLModel", + (error != nil) ? MakeString(", error: ", [[error localizedDescription] UTF8String]) : ""); + } + + return Status::OK(); } - model_loaded = status.IsOK(); - return status; } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::LoadModel requires macos 10.15+ or ios 13+"); @@ -542,13 +405,94 @@ Status Predict(const std::unordered_map& inputs, Status Execution::Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { - ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_COREML3_OR_LATER) { @autoreleasepool { - return [execution_ predict:inputs - outputs:outputs - getOutputTensorDataFn:get_output_tensor_mutable_raw_data_fn]; + Status status = Status::OK(); + ORT_TRY { + if (model_ == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Model is not loaded"); + } + + id input_features; + InlinedVector> conversion_buffers; + ORT_RETURN_IF_ERROR(CreateInputFeatureProvider(inputs, logger_, &input_features, conversion_buffers)); + + MLPredictionOptions* options = [[MLPredictionOptions alloc] init]; + NSError* error = nil; + id output_features = [model_ predictionFromFeatures:input_features + options:options + error:&error]; + + if (error != nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Error executing model: ", + [[error localizedDescription] UTF8String]); + } + + for (const auto& [output_name, output_tensor_info] : outputs) { + MLFeatureValue* output_value = + [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())]; + + if (output_value == nil) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name); + } + + MLMultiArray* data = [output_value multiArrayValue]; + + const auto coreml_static_output_shape = [data]() { + InlinedVector result; + result.reserve(data.shape.count); + for (NSNumber* dim in data.shape) { + const auto dim_value = dim.longLongValue; + result.push_back(dim_value); + } + return result; + }(); + + const auto static_output_shape = GetStaticOutputShape(output_tensor_info.shape, coreml_static_output_shape, + logger_); + + void* output_buffer = get_output_tensor_mutable_raw_data_fn(output_name, output_tensor_info.data_type, + static_output_shape); + + if (const size_t num_elements = data.count; num_elements > 0) { + if (const auto shape_size = ShapeSize(static_output_shape); + shape_size < 0 || num_elements != static_cast(shape_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "CoreML MLMultiArray count (", num_elements, ") and shape size (", shape_size, + ") do not match"); + } + + // support a non-contiguous array, provided only one dimension is not contiguous + int64_t num_blocks = 0; + int64_t block_size = 0; + int64_t stride = 0; + + ORT_RETURN_IF_ERROR(GetMLMultiArrayCopyInfo(data, num_blocks, block_size, stride)); + + __block Status copy_status; + const auto* tensor_info = &output_tensor_info; + // `getBytesWithHandler` replaces deprecated `.dataPointer` on new versions + if (@available(macOS 12.3, iOS 15.4, *)) { + [data getBytesWithHandler:^(const void* bytes, NSInteger size) { + copy_status = CopyMLMultiArrayBuffer(bytes, output_buffer, data, + num_blocks, block_size, stride, tensor_info); + }]; + } else { + copy_status = CopyMLMultiArrayBuffer(data.dataPointer, output_buffer, data, + num_blocks, block_size, stride, tensor_info); + } + + ORT_RETURN_IF_ERROR(copy_status); + } + } + } + ORT_CATCH(const std::exception& e) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Exception: ", e.what()); + }); + } + + return status; } } From d495e6cf1c477098255511c4136bb7ea43a7c0dc Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 11 Sep 2024 22:02:30 -0700 Subject: [PATCH 14/39] adds support for Uint8ClampedArray (#21985) Fixes https://github.com/microsoft/onnxruntime/issues/21753 --- js/common/lib/tensor-impl.ts | 19 ++++++++++++++++--- js/common/lib/tensor.ts | 17 +++++++++++++++++ .../type-tests/tensor/create-new-uint8.ts | 19 +++++++++++++++++++ .../unit-tests/tensor/constructor-type.ts | 8 ++++++++ 4 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 js/common/test/type-tests/tensor/create-new-uint8.ts diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 4e0ef821dde57..342f5e3a467eb 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -51,13 +51,16 @@ export class Tensor implements TensorInterface { */ constructor( type: TensorType, - data: TensorDataType | readonly string[] | readonly number[] | readonly boolean[], + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly number[] | readonly boolean[], dims?: readonly number[], ); /** * Construct a new CPU tensor object from the given data and dims. Type is inferred from data. */ - constructor(data: TensorDataType | readonly string[] | readonly boolean[], dims?: readonly number[]); + constructor( + data: TensorDataType | Uint8ClampedArray | readonly string[] | readonly boolean[], + dims?: readonly number[], + ); /** * Construct a new tensor object from the pinned CPU data with the given type and dims. * @@ -90,12 +93,13 @@ export class Tensor implements TensorInterface { arg0: | TensorType | TensorDataType + | Uint8ClampedArray | readonly string[] | readonly boolean[] | CpuPinnedConstructorParameters | TextureConstructorParameters | GpuBufferConstructorParameters, - arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], + arg1?: TensorDataType | Uint8ClampedArray | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { // perform one-time check for BigInt/Float16Array support @@ -216,6 +220,12 @@ export class Tensor implements TensorInterface { } } else if (arg1 instanceof typedArrayConstructor) { data = arg1; + } else if (arg1 instanceof Uint8ClampedArray) { + if (arg0 === 'uint8') { + data = Uint8Array.from(arg1); + } else { + throw new TypeError(`A Uint8ClampedArray tensor's data must be type of uint8`); + } } else { throw new TypeError(`A ${type} tensor's data must be type of ${typedArrayConstructor}`); } @@ -243,6 +253,9 @@ export class Tensor implements TensorInterface { } else { throw new TypeError(`Invalid element type of data array: ${firstElementType}.`); } + } else if (arg0 instanceof Uint8ClampedArray) { + type = 'uint8'; + data = Uint8Array.from(arg0); } else { // get tensor type from TypedArray const mappedType = NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.get( diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 70396bbe1e9a3..8a1197994393b 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -192,6 +192,15 @@ export interface TensorConstructor extends TensorFactory { dims?: readonly number[], ): TypedTensor<'bool'>; + /** + * Construct a new uint8 tensor object from a Uint8ClampedArray, data and dims. + * + * @param type - Specify the element type. + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (type: 'uint8', data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new 64-bit integer typed tensor object from the given type, data and dims. * @@ -245,6 +254,14 @@ export interface TensorConstructor extends TensorFactory { */ new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; + /** + * Construct a new uint8 tensor object from the given data and dims. + * + * @param data - Specify the CPU tensor data. + * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. + */ + new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; + /** * Construct a new uint16 tensor object from the given data and dims. * diff --git a/js/common/test/type-tests/tensor/create-new-uint8.ts b/js/common/test/type-tests/tensor/create-new-uint8.ts new file mode 100644 index 0000000000000..46438f97ca2e7 --- /dev/null +++ b/js/common/test/type-tests/tensor/create-new-uint8.ts @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as ort from 'onnxruntime-common'; + +// construct from Uint8Array +// +// {type-tests}|pass +new ort.Tensor(new Uint8Array(1)); + +// construct from Uint8ClampedArray +// +// {type-tests}|pass +new ort.Tensor(new Uint8ClampedArray(1)); + +// construct from type (bool), data (Uint8ClampedArray) and shape (number array) +// +// {type-tests}|fail|1|2769 +new ort.Tensor('bool', new Uint8ClampedArray([255, 256]), [2]); diff --git a/js/common/test/unit-tests/tensor/constructor-type.ts b/js/common/test/unit-tests/tensor/constructor-type.ts index def711684d7f5..02390800e8611 100644 --- a/js/common/test/unit-tests/tensor/constructor-type.ts +++ b/js/common/test/unit-tests/tensor/constructor-type.ts @@ -82,6 +82,14 @@ describe('Tensor Constructor Tests - check types', () => { assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); }); + it('[uint8] new Tensor(uint8ClampedArray, dims): uint8 tensor can be constructed from Uint8ClampedArray', () => { + const uint8ClampedArray = new Uint8ClampedArray(2); + uint8ClampedArray[0] = 0; + uint8ClampedArray[1] = 256; // clamped + const tensor = new Tensor('uint8', uint8ClampedArray, [2]); + assert.equal(tensor.type, 'uint8', "tensor.type should be 'uint8'"); + }); + it("[bool] new Tensor('bool', uint8Array, dims): tensor can be constructed from Uint8Array", () => { const tensor = new Tensor('bool', new Uint8Array([1, 0, 1, 0]), [2, 2]); assert.equal(tensor.type, 'bool', "tensor.type should be 'bool'"); From ae39c40e5b65874735cd07aca692287aa1cf1b62 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Thu, 12 Sep 2024 19:07:42 +0800 Subject: [PATCH 15/39] fix typo in iOS pipeline (#22067) ### Description ### Motivation and Context The parameter isn't correct. Maybe it hasn't negative impact by chance so far. https://github.com/microsoft/onnxruntime/blob/d8e64bb529c1d0f18efd47710d179205c96ffbca/cmake/CMakeLists.txt#L1712-L1717 --- tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index 48d48156fe913..74211bc5dbd7c 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -53,7 +53,7 @@ jobs: python3 $(Build.SourcesDirectory)/tools/ci_build/build.py \ --skip_submodule_sync \ --build_dir $(Build.BinariesDirectory)/iOS \ - --build_shared \ + --build_shared_lib \ --use_coreml \ --use_xnnpack \ --ios \ From 951b1b7160b0efc21d97ead1051777410e2ca775 Mon Sep 17 00:00:00 2001 From: mindest <30493312+mindest@users.noreply.github.com> Date: Fri, 13 Sep 2024 01:54:32 +0900 Subject: [PATCH 16/39] [CI] Linux ROCm CI Pipeline: fix error, set trigger rules. (#22069) ### Description * Correct the wrong EP name for ROCm, fix CI error. * Update `set-trigger-rules.py`. * Modify the .yml via `set-trigger-rules.py` --- onnxruntime/test/python/onnxruntime_test_python.py | 2 +- .../ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml | 1 + tools/ci_build/set-trigger-rules.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index feabd648f8385..24151932a6681 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1694,7 +1694,7 @@ def test_register_custom_e_ps_library(self): available_eps = C.get_available_providers() # skip amd gpu build - if "RocmExecutionProvider" in available_eps: + if "ROCMExecutionProvider" in available_eps: return if sys.platform.startswith("win"): diff --git a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml index 7b77281b0efe2..50f3862761320 100644 --- a/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-rocm-ci-pipeline.yml @@ -1,4 +1,5 @@ ##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +### please do rerun set-trigger-rules.py ### trigger: branches: include: diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py index 583e5b05ed6d8..fb6aa44cdf31a 100644 --- a/tools/ci_build/set-trigger-rules.py +++ b/tools/ci_build/set-trigger-rules.py @@ -24,6 +24,7 @@ "linux-migraphx-ci-pipeline.yml", "linux-openvino-ci-pipeline.yml", "linux-qnn-ci-pipeline.yml", + "linux-rocm-ci-pipeline.yml", "mac-ci-pipeline.yml", "mac-coreml-ci-pipeline.yml", "mac-ios-ci-pipeline.yml", From 84f73327f55b3dadbf20b69bc1a12cc2811986ed Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 12 Sep 2024 10:33:37 -0700 Subject: [PATCH 17/39] allow scalar axes for Unsqueeze for WebGPU (#22054) ### Description Align with CPU behavior. https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/tensor/unsqueeze.cc#L60-L62 --- onnxruntime/core/providers/js/operators/unsqueeze.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/unsqueeze.h b/onnxruntime/core/providers/js/operators/unsqueeze.h index 7cbfdc38b742d..f15a3008895aa 100644 --- a/onnxruntime/core/providers/js/operators/unsqueeze.h +++ b/onnxruntime/core/providers/js/operators/unsqueeze.h @@ -26,8 +26,9 @@ class Unsqueeze final : public JsKernel, public UnsqueezeBase { if (num_inputs == 2) { // axes is an input const Tensor* axes_tensor = context->Input(1); ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); - ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, - "An axes tensor must be a vector tensor."); + ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 0 || + axes_tensor->Shape().NumDimensions() == 1, + "An axes tensor must be a scalar or a vector tensor."); auto nDims = static_cast(axes_tensor->Shape()[0]); const auto* data = axes_tensor->Data(); axes.assign(data, data + nDims); From 10883d7997ed4b53f989a49bd4387c5769fbd12f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20P=C3=A9ron?= Date: Thu, 12 Sep 2024 18:46:27 +0100 Subject: [PATCH 18/39] Suppress GCC warning in TreeEnsembleAggregator (#22062) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description When building with GCC 14.2.1, I got the following warning: onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h:329:59: error: template-id not allowed for constructor in C++20 [-Werror=template-id-cdtor] Remove template parameters from the constructor: The constructor TreeAggregatorMax has been simplified to TreeAggregatorMax, because the compiler already knows the template parameters from the class definition. ### Motivation and Context Fix the build issue Signed-off-by: Clément Péron --- .../core/providers/cpu/ml/tree_ensemble_aggregator.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index 34c6db61982b5..b031a6f0cefa3 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -328,11 +328,10 @@ class TreeAggregatorMin : public TreeAggregator class TreeAggregatorMax : public TreeAggregator { public: - TreeAggregatorMax(size_t n_trees, - const int64_t& n_targets_or_classes, - POST_EVAL_TRANSFORM post_transform, - const std::vector& base_values) : TreeAggregator(n_trees, n_targets_or_classes, - post_transform, base_values) {} + TreeAggregatorMax(size_t n_trees, + const int64_t& n_targets_or_classes, + POST_EVAL_TRANSFORM post_transform, + const std::vector& base_values) : TreeAggregator(n_trees, n_targets_or_classes, post_transform, base_values) {} // 1 output From d539c27de82b9d1631b743b941f9c3ade49e7a05 Mon Sep 17 00:00:00 2001 From: wangshuai09 <391746016@qq.com> Date: Fri, 13 Sep 2024 02:42:17 +0800 Subject: [PATCH 19/39] Fix version check for using -mavxvnni (#21616) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Change the `CMAKE_CXX_COMPILER_VERSION` greater than `11` for using '-mavxvnni'. ### Motivation and Context `CMakeFiles/onnxruntime_mlas.dir/root/Git.d/onnxruntime/onnxruntime/core/mlas/lib/x86_64/QgemmU8S8KernelAvx2.S.o cc: error: unrecognized command-line option ‘-mavxvnni’; did you mean ‘-mavx512vnni’?` using `gcc (GCC) 10.3.1`. `-mavxnni` is supported since [GCC 11 Release](https://gcc.gnu.org/gcc-11/changes.html), this PR change the version check. --- cmake/onnxruntime_mlas.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index cf23416943c1f..b612b3ead4658 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -578,7 +578,7 @@ else() message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "10") +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") else() From 5c361106e61b94213784c7a6953e8a099235c7e4 Mon Sep 17 00:00:00 2001 From: 0xdr3dd Date: Fri, 13 Sep 2024 00:20:34 +0530 Subject: [PATCH 20/39] [Fuzzer] Add two new ORT libfuzzer (Linux clang support for now) (#22055) ### Description This PR adds two new libfuzzer in fuzzer project. 1. Binary libfuzzer 2. libprotobuf-fuzzer To compile run below cmd on linux: ``` LLVM_PROFILE_FILE="%p.profraw" CFLAGS="-g -fsanitize=address,fuzzer-no-link -shared-libasan -fprofile-instr-generate -fcoverage-mapping" CXXFLAGS="-g -shared-libasan -fsanitize=address,fuzzer-no-link -fprofile-instr-generate -fcoverage-mapping" CC=clang CXX=clang++ ./build.sh --update --build --config Debug --compile_no_warning_as_error --build_shared_lib --skip_submodule_sync --use_full_protobuf --parallel --fuzz_testing --build_dir build/ ``` Run fuzzer: ``` LD_PRELOAD=$(clang -print-file-name=libclang_rt.asan-x86_64.so) build/Debug/onnxruntime_libfuzzer_fuzz testinput -rss_limit_mb=8196 -max_total_time=472800 -fork=2 -jobs=4 -workers=4 -ignore_crashes=1 -max_len=2097152 2>&1 | grep -v "\[libprotobuf ERROR" ``` ### Motivation and Context The existing custom fuzzer is not coverage guided and it's slow and it will work on one model mutation at a time. The new fuzzers are coverage guided, and we can use more models' files as a corpus to increase the coverage. --- cmake/onnxruntime_fuzz_test.cmake | 145 ++++++++++++------ .../fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp | 42 +++++ .../ort_libfuzzer/OrtProtoLibfuzzer.cpp | 94 ++++++++++++ 3 files changed, 236 insertions(+), 45 deletions(-) create mode 100644 onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp create mode 100644 onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp diff --git a/cmake/onnxruntime_fuzz_test.cmake b/cmake/onnxruntime_fuzz_test.cmake index 26d41e98687d4..eea411d938176 100644 --- a/cmake/onnxruntime_fuzz_test.cmake +++ b/cmake/onnxruntime_fuzz_test.cmake @@ -4,23 +4,24 @@ # Check that the options are properly set for # the fuzzing project if (onnxruntime_FUZZ_ENABLED) - message(STATUS "Building dependency protobuf-mutator and libfuzzer") - - # set the options used to control the protobuf-mutator build - set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB}) - set(LIB_PROTO_MUTATOR_TESTING OFF) - - # include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes - # with google test - add_subdirectory("external/libprotobuf-mutator/src") - - # add the appropriate include directory and compilation flags - # needed by the protobuf-mutator target and the libfuzzer - set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator") - onnxruntime_add_include_to_target(protobuf-mutator ${PROTOBUF_LIB}) - onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) - target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) - target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + message(STATUS "Building dependency protobuf-mutator and libfuzzer") + + # set the options used to control the protobuf-mutator build + set(PROTOBUF_LIBRARIES ${PROTOBUF_LIB}) + set(LIB_PROTO_MUTATOR_TESTING OFF) + + # include the protobuf-mutator CMakeLists.txt rather than the projects CMakeLists.txt to avoid target clashes + # with google test + add_subdirectory("external/libprotobuf-mutator/src") + + # add the appropriate include directory and compilation flags + # needed by the protobuf-mutator target and the libfuzzer + set(PROTOBUF_MUT_INCLUDE_DIRS "external/libprotobuf-mutator") + onnxruntime_add_include_to_target(protobuf-mutator ${PROTOBUF_LIB}) + onnxruntime_add_include_to_target(protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + target_include_directories(protobuf-mutator PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + target_include_directories(protobuf-mutator-libfuzzer PRIVATE ${INCLUDE_DIRECTORIES} ${PROTOBUF_MUT_INCLUDE_DIRS}) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # MSVC-specific compiler options target_compile_options(protobuf-mutator PRIVATE "/wd4244" "/wd4245" "/wd4267" "/wd4100" "/wd4456") @@ -44,42 +45,96 @@ if (onnxruntime_FUZZ_ENABLED) ) endif() - # add Fuzzing Engine Build Configuration - message(STATUS "Building Fuzzing engine") + # add Fuzzing Engine Build Configuration + message(STATUS "Building Fuzzing engine") + + # set Fuzz root directory + set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing) + + # Security fuzzing engine src file reference + set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp" + "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/src/test.cpp") + + # compile the executables + onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC}) + + # compile with c++17 + target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17) - # set Fuzz root directory - set(SEC_FUZZ_ROOT ${TEST_SRC_DIR}/fuzzing) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime) - # Security fuzzing engine src file reference - set(SEC_FUZ_SRC "${SEC_FUZZ_ROOT}/src/BetaDistribution.cpp" - "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" - "${SEC_FUZZ_ROOT}/src/testlog.cpp" - "${SEC_FUZZ_ROOT}/src/test.cpp") + # Assign all include to one variable + set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include") + set(INCLUDE_FILES ${SEC_FUZ_INC} "$") - # compile the executables - onnxruntime_add_executable(onnxruntime_security_fuzz ${SEC_FUZ_SRC}) + # add all these include directory to the Fuzzing engine + target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES}) - # compile with c++17 - target_compile_features(onnxruntime_security_fuzz PUBLIC cxx_std_17) + # add link libraries to the project + target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) - # Security fuzzing engine header file reference - onnxruntime_add_include_to_target(onnxruntime_security_fuzz onnx onnxruntime) + # add the dependencies + add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) - # Assign all include to one variable - set(SEC_FUZ_INC "${SEC_FUZZ_ROOT}/include") - set(INCLUDE_FILES ${SEC_FUZ_INC} "$") + # copy the shared libraries (DLLs on Windows, SOs on Linux) to the execution directory + add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) - # add all these include directory to the Fuzzing engine - target_include_directories(onnxruntime_security_fuzz PRIVATE ${INCLUDE_FILES}) + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + # Add a second fuzzer that uses libFuzzer in fuzzer/libfuzzer + message(STATUS "Building libProtoBufFuzzer-based fuzzer") - # add link libraries the project - target_link_libraries(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) + # Set source files for the libFuzzer + set(LIBFUZZER_SRC "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/ort_libfuzzer/OrtProtoLibfuzzer.cpp") - # add the dependencies - add_dependencies(onnxruntime_security_fuzz onnx_proto onnxruntime protobuf-mutator ${PROTOBUF_LIB}) + # Compile the libFuzzer-based fuzzer + onnxruntime_add_executable(onnxruntime_proto_libfuzzer ${LIBFUZZER_SRC}) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_proto_libfuzzer onnx onnxruntime) + # Set include directories for libFuzzer + target_include_directories(onnxruntime_proto_libfuzzer PRIVATE ${INCLUDE_FILES}) - # copy the dlls to the execution directory - add_custom_command(TARGET onnxruntime_security_fuzz POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ - COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + # Add link libraries for libFuzzer + target_link_libraries(onnxruntime_proto_libfuzzer onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer -fsanitize=fuzzer,address ${PROTOBUF_LIB}) + + # Add the dependencies for libFuzzer + add_dependencies(onnxruntime_proto_libfuzzer onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + + # Copy shared libraries for libFuzzer + add_custom_command(TARGET onnxruntime_proto_libfuzzer POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + # Add a second fuzzer that uses libFuzzer in fuzzer/libfuzzer + message(STATUS "Building libBufFuzzer-based fuzzer") + + # Set source files for the libFuzzer + set(LIBFUZZER_SRC "${SEC_FUZZ_ROOT}/src/OnnxPrediction.cpp" + "${SEC_FUZZ_ROOT}/src/testlog.cpp" + "${SEC_FUZZ_ROOT}/ort_libfuzzer/OrtLibfuzzer.cpp") + + # Compile the libFuzzer-based fuzzer + onnxruntime_add_executable(onnxruntime_libfuzzer_fuzz ${LIBFUZZER_SRC}) + # Security fuzzing engine header file reference + onnxruntime_add_include_to_target(onnxruntime_libfuzzer_fuzz onnx onnxruntime) + # Set include directories for libFuzzer + target_compile_definitions(onnxruntime_libfuzzer_fuzz PRIVATE GOOGLE_PROTOBUF_NO_LOGGING=1) + target_include_directories(onnxruntime_libfuzzer_fuzz PRIVATE ${INCLUDE_FILES}) + + # Add link libraries for libFuzzer + target_link_libraries(onnxruntime_libfuzzer_fuzz onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer -fsanitize=fuzzer,address ${PROTOBUF_LIB}) + + # Add the dependencies for libFuzzer + add_dependencies(onnxruntime_libfuzzer_fuzz onnx_proto onnxruntime protobuf-mutator protobuf-mutator-libfuzzer ${PROTOBUF_LIB}) + + # Copy shared libraries for libFuzzer + add_custom_command(TARGET onnxruntime_libfuzzer_fuzz POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $ + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ $) + endif() endif() diff --git a/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp new file mode 100644 index 0000000000000..406aca722bb67 --- /dev/null +++ b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtLibfuzzer.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "OnnxPrediction.h" +#include "onnxruntime_session_options_config_keys.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "fuzzer/FuzzedDataProvider.h" + +Ort::Env env; + +void predict(onnx::ModelProto& msg, unsigned int seed, Ort::Env& env) { + // Create object for prediction + // + OnnxPrediction predict(msg, env); + + // Give predict a function to generate the data + // to run prediction on. + // + predict.SetupInput(GenerateDataForInputTypeTensor, seed); + + // Run the prediction on the data + // + predict.RunInference(); +} + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + FuzzedDataProvider data_provider(data, size); + onnx::ModelProto msg; + try { + if (!msg.ParseFromArray(data, static_cast(size))) { + return 0; // Ignore invalid inputs + } + predict(msg, data_provider.ConsumeIntegral(), env); + } catch (const std::exception& e) { + // Optionally log or suppress the exception + // std::cerr << "Caught exception: " << e.what() << std::endl; + } catch (...) { + // Handle any other exceptions + // std::cerr << "Caught unknown exception." << std::endl; + } + return 0; +} diff --git a/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp new file mode 100644 index 0000000000000..607d9cfd9c755 --- /dev/null +++ b/onnxruntime/test/fuzzing/ort_libfuzzer/OrtProtoLibfuzzer.cpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "src/mutator.h" +#include "OnnxPrediction.h" +#include "onnxruntime_session_options_config_keys.h" +#include "src/libfuzzer/libfuzzer_macro.h" +#include "onnx/onnx_pb.h" + +#include + +Ort::Env env; + +std::string wstring_to_string(const std::wstring& wstr) { + std::wstring_convert> converter; + return converter.to_bytes(wstr); +} + +void predict(onnx::ModelProto& msg, unsigned int seed, Ort::Env& env) { + // Create object for prediction + // + OnnxPrediction predict(msg, env); + + // Give predict a function to generate the data + // to run prediction on. + // + predict.SetupInput(GenerateDataForInputTypeTensor, seed); + + // Run the prediction on the data + // + predict.RunInference(); + + // View the output + // + predict.PrintOutputValues(); +} + +template +using PostProcessor = + protobuf_mutator::libfuzzer::PostProcessorRegistration; + +// Helper function to generate random strings +std::string generate_random_string(size_t length, std::mt19937& rng) { + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + std::uniform_int_distribution<> dist(0, characters.size() - 1); + std::string result; + for (size_t i = 0; i < length; ++i) { + result += characters[dist(rng)]; + } + return result; +} + +// Helper function to generate random float +float generate_random_float(std::mt19937& rng) { + std::uniform_real_distribution dist(0.0f, 1.0f); + return dist(rng); +} + +// PostProcessor for ONNX ModelProto with random values +static PostProcessor reg1 = { + [](onnx::ModelProto* model_proto, unsigned int seed) { + std::mt19937 rng(seed); + + // Set model's IR version + model_proto->set_ir_version(7); + + model_proto->set_producer_name("onnx"); + model_proto->set_producer_version("7.0"); + model_proto->set_domain("example.com"); + + // Add a dummy opset import + auto* opset_import = model_proto->add_opset_import(); + opset_import->set_version(10); + + // Access the graph from the model + auto* graph = model_proto->mutable_graph(); + + // Set a random name for the graph + graph->set_name(generate_random_string(10, rng)); + }}; + +DEFINE_PROTO_FUZZER(const onnx::ModelProto& msg) { + try { + auto seed = static_cast(std::chrono::system_clock::now().time_since_epoch().count()); + onnx::ModelProto msg_proto = msg; + predict(msg_proto, seed, env); + } catch (const std::exception& e) { + // Optionally log or suppress the exception + // std::cerr << "Caught exception: " << e.what() << std::endl; + } catch (...) { + // Handle any other exceptions + // std::cerr << "Caught unknown exception." << std::endl; + } +} From 55ab13e7ca8c5147ff5d7e82da5b6bde01720f7d Mon Sep 17 00:00:00 2001 From: mingyueliuh <131847423+mingyueliuh@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:23:09 -0400 Subject: [PATCH 21/39] [VitisAI] support memory buffer contains the TensorProto external data (#22042) ### Description Extend VitisAI EP `tensor_proto_as_raw` API to support memory buffer containing the TensorProto external data ### Motivation and Context For reduce peak memory usage, VitisAI EP need support ORT format model and setting session option `session.use_ort_model_bytes_for_initializers` for enable directly use the model bytes for initializers. Co-authored-by: mingyue --- .../providers/vitisai/imp/tensor_proto.cc | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index 4b2b7610cf7ea..872d022e85264 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -9,9 +9,44 @@ #include "core/providers/shared_library/provider_api.h" namespace vaip { using namespace onnxruntime; + +static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { + auto tensor_proto = const_cast(&tensor); + auto file = std::string(); + uintptr_t offset = 0; + size_t size = 0; + if (tensor_proto->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { + auto external_data = tensor_proto->mutable_external_data(); + auto external_data_size = external_data->size(); + for (auto i = 0; i < external_data_size; ++i) { + auto& data = external_data->at(i); + char* end = nullptr; + if (*data.mutable_key() == "location") { + file = *data.mutable_value(); + } else if (*data.mutable_key() == "offset") { + offset = (uintptr_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "length") { + size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "checksum") { + // checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } + } + if (file == "*/_ORT_MEM_ADDR_/*") { + auto addr = reinterpret_cast(offset); + return {addr, size}; + } + } + return {}; +} + gsl::span tensor_proto_as_raw(const onnxruntime::Graph& graph, const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { + auto maybe_external_memory_address = process_ext_address(tensor); + if (!maybe_external_memory_address.empty()) { + return maybe_external_memory_address; + } + std::vector unpacked_tensor; auto path = graph.ModelPath(); auto s = onnxruntime::utils::UnpackInitializerData(tensor, path, unpacked_tensor); From f7bf5a19baf0a7caa9cca7dc08bf192e392a14e4 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 12 Sep 2024 17:18:50 -0700 Subject: [PATCH 22/39] [QNN EP] Ensure QNN EP rejects nodes with I/O of dynamic shape (#22066) ### Description Updates QNN EP to properly reject nodes that have inputs or outputs with dynamic shapes. ### Motivation and Context Currently, QNN EP does not properly offload subgraphs with dynamic shapes to the CPU EP. This PR ensures that QNN EP rejects nodes that consume or generate I/O with dynamic shapes. --- .../qnn/builder/qnn_model_wrapper.cc | 4 +- .../test/providers/qnn/qnn_basic_test.cc | 57 +++++++++++++++++++ .../test/providers/qnn/qnn_test_utils.cc | 4 +- .../test/providers/qnn/qnn_test_utils.h | 6 +- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 3c029fda9cd52..2c7f3c8b22ddd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -308,8 +308,10 @@ bool QnnModelWrapper::GetOnnxShape(const NodeArg& node_arg, std::vectordim()) { + if (!dim.has_dim_value()) { + return false; // Do not support dynamic shapes. + } shape.push_back(SafeInt(dim.dim_value())); } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9d19c36dc94b2..c4367aeb52edc 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -948,6 +948,63 @@ TEST_F(QnnHTPBackendTests, Float32ModelWithFP16PrecisionTest) { 0.008f); } +// Test that QNN EP only handles nodes with static shapes and rejects nodes with dynamic shape I/O. +TEST_F(QnnHTPBackendTests, EPRejectsDynamicShapesF32) { + // Local function that builds a model in which the last two nodes use dynamic shapes. + auto model_build_fn = [](ModelTestBuilder& builder) { + NodeArg* input1 = builder.MakeInput(std::vector{1, 2, 8, 8}, + GetFloatDataInRange(0.0f, 1.0f, 128)); + NodeArg* input2 = builder.MakeInput(std::vector{3}, std::vector{1, 2, 49}); + + // Add a Conv with known shapes. QNN EP should support it. + NodeArg* weight = builder.MakeInitializer(std::vector{2, 2, 2, 2}, + GetFloatDataInRange(-0.3f, 0.3f, 16)); + NodeArg* bias = builder.MakeInitializer(std::vector{2}, {0.0f, 1.0f}); + + auto* conv_output = builder.MakeIntermediate(); + builder.AddNode("Conv", {input1, weight, bias}, {conv_output}); + + // Add a Reshape to a dynamic shape. QNN EP should reject this node. + auto* reshape_output = builder.MakeIntermediate(); + builder.AddNode("Reshape", {conv_output, input2}, {reshape_output}); + + // Add a Softmax. QNN EP should reject this node because its input has a dynamic shape. + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Softmax", {reshape_output}, {output}); + }; + + // Local function that checks that the nodes with dynamic shape I/O were assigned to CPU EP. + std::function ep_graph_checker = [](const Graph& graph) { + for (const Node& node : graph.Nodes()) { + const std::string& ep_name = node.GetExecutionProviderType(); + const std::string& op_type = node.OpType(); + if (op_type == "Reshape" || op_type == "Softmax") { + EXPECT_EQ(ep_name, kCpuExecutionProvider); + } else { + EXPECT_EQ(ep_name, kQnnExecutionProvider); + } + } + }; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_fp16_precision"] = "1"; // QNN EP will use fp16 precision. + // CPU EP will use fp32, so we can relax accuracy requirements. + + RunQnnModelTest(model_build_fn, + provider_options, + /*opset*/ 19, + ExpectedEPNodeAssignment::Some, + /*abs_err*/ 1e-4f, + logging::Severity::kERROR, + /*verify_output*/ true, + &ep_graph_checker); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index afaa5a341d5e9..8a4f7f2a1f6b5 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -98,10 +98,12 @@ void TryEnableQNNSaver(ProviderOptions& qnn_options) { void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, - float fp32_abs_err, logging::Severity log_severity, bool verify_outputs) { + float fp32_abs_err, logging::Severity log_severity, bool verify_outputs, + std::function* ep_graph_checker) { EPVerificationParams verification_params; verification_params.ep_node_assignment = expected_ep_assignment; verification_params.fp32_abs_err = fp32_abs_err; + verification_params.graph_verifier = ep_graph_checker; // Add kMSDomain to cover contrib op like Gelu const std::unordered_map domain_to_version = {{"", opset_version}, {kMSDomain, 1}}; diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 3a6753e9b6131..bb77c92668853 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -1033,12 +1033,16 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( * \param expected_ep_assignment How many nodes are expected to be assigned to QNN (All, Some, or None). * \param fp32_abs_err The acceptable error between CPU EP and QNN EP. * \param log_severity The logger's minimum severity level. + * \param verify_outputs True to verify that the outputs match (within tolerance). + * \param ep_graph_checker Function called on the Graph generated for the EP's session. Used to check node + * EP assignment. */ void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR, - bool verify_outputs = true); + bool verify_outputs = true, + std::function* ep_graph_checker = nullptr); enum class BackendSupport { SUPPORT_UNKNOWN, From 22437b581b8559702fe9f5a5fe2309a495bd9e15 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Thu, 12 Sep 2024 22:38:17 -0400 Subject: [PATCH 23/39] [java] Fix for OnnxTensor creation when passing in a ByteBuffer containing elements of a different type (#21774) ### Description Fixes a bug where the buffer offset and position was incorrectly computed if the user supplied a `ByteBuffer` to `createTensor` but set the type of the tensor to something other than `INT8`. This would be more common if the user was trying to load the initializers from a serialized representation and didn't want to bother with the type information (which is the case in #21321). ### Motivation and Context Partial fix for #21321. The remainder of the fix is to add a helper which allows users to load initializers out of an `onnx_data` file, but that will require adding protobuf as a dependency for the Java API to allow the parsing of an ONNX file separately from the native code. It might be nicer to put that functionality into ORT's C API so it can return the lengths & offsets of the initializers when provided with an ONNX file containing external initializers. We hit this kind of thing in Java more often than other languages as in Java models can be supplied as classpath resources which we can easily read, but not materialize on disk for the ORT native library to read. --- .../src/main/java/ai/onnxruntime/OrtUtil.java | 13 ++++++---- .../java/ai/onnxruntime/OnnxTensorTest.java | 26 ++++++++++++++++++- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 5b2e9b2efac4c..4f3dee3c00b91 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -483,9 +483,12 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { if (type == OnnxJavaType.STRING || type == OnnxJavaType.UNKNOWN) { throw new IllegalStateException("Cannot create a " + type + " tensor from a buffer"); } + // This buffer could be a ByteBuffer which is being used to carry data of another type, if so, + // it's type.size should be 1 to compute the correct buffer size and offset. + int elementSize = data instanceof ByteBuffer ? 1 : type.size; int bufferPos; - long bufferSizeLong = data.remaining() * (long) type.size; - if (bufferSizeLong > (Integer.MAX_VALUE - (8 * type.size))) { + long bufferSizeLong = data.remaining() * (long) elementSize; + if (bufferSizeLong > (Integer.MAX_VALUE - (8L * elementSize))) { // The maximum direct byte buffer size is a little below Integer.MAX_VALUE depending // on the JVM, so we check for something 8 elements below the maximum size which // should be allocatable (assuming there is enough memory) on all 64-bit JVMs. @@ -496,11 +499,11 @@ static BufferTuple prepareBuffer(Buffer data, OnnxJavaType type) { + type); } // Now we know we're in range - int bufferSize = data.remaining() * type.size; + int bufferSize = data.remaining() * elementSize; Buffer tmp; if (data.isDirect()) { tmp = data; - bufferPos = data.position() * type.size; + bufferPos = data.position() * elementSize; } else { // Copy the data to a new direct buffer, then restore the state of the input. int origPosition = data.position(); diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java index c060cf73ecf14..ea210d96c1507 100644 --- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java +++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -218,6 +218,30 @@ public void testUint8Creation() throws OrtException { } } + @Test + public void testByteBufferCreation() throws OrtException { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + ByteBuffer byteBuf = ByteBuffer.allocateDirect(Float.BYTES * 5).order(ByteOrder.nativeOrder()); + FloatBuffer floatBuf = byteBuf.asFloatBuffer(); + floatBuf.put(1.0f); + floatBuf.put(2.0f); + floatBuf.put(3.0f); + floatBuf.put(4.0f); + floatBuf.put(5.0f); + floatBuf.position(1); + float[] expected = new float[floatBuf.remaining()]; + floatBuf.get(expected); + floatBuf.position(1); + byteBuf.position(4); + try (OnnxTensor t = + OnnxTensor.createTensor( + env, byteBuf, new long[] {floatBuf.remaining()}, OnnxJavaType.FLOAT)) { + Assertions.assertNotNull(t); + float[] actual = (float[]) t.getValue(); + Assertions.assertArrayEquals(expected, actual); + } + } + @Test public void testEmptyTensor() throws OrtException { OrtEnvironment env = OrtEnvironment.getEnvironment(); From 904b850b445ccfb3dc935e39b96cfb3dbfb52673 Mon Sep 17 00:00:00 2001 From: Michael Tyler <67695629+MichaelTylerArm@users.noreply.github.com> Date: Fri, 13 Sep 2024 04:51:59 +0100 Subject: [PATCH 24/39] Update Arm Compute Library Execution Provider (#22032) ### Description This PR makes the following updates to the Arm Compute Library execution provider: - Target Arm Compute Library 24.07 - Add support for the following operators: - Conv (FP16) - NhwcConv - QLinearConv - MatMul - FusedMatMul - MatMulIntegerToFloat - Optimize memory usage and performance - Expose the enable_fast_math setting - Use the main runtime thread pool ### Motivation and Context These updates improve performance and memory usage, and enable use of a more recent version of Arm Compute Library. @microsoft-github-policy-service agree company="Arm Ltd" --------- Signed-off-by: Michael Tyler --- cmake/CMakeLists.txt | 39 +- .../core/providers/acl/acl_provider_factory.h | 4 +- .../main/java/ai/onnxruntime/OrtSession.java | 10 +- ...ai_onnxruntime_OrtSession_SessionOptions.c | 8 +- .../core/optimizer/graph_transformer_utils.cc | 33 +- .../core/optimizer/nhwc_transformer.cc | 4 +- .../qdq_selector_action_transformer.cc | 5 +- onnxruntime/core/providers/acl/acl_common.cc | 155 ++++- onnxruntime/core/providers/acl/acl_common.h | 27 +- .../providers/acl/acl_execution_provider.cc | 110 +++- .../providers/acl/acl_execution_provider.h | 30 +- .../providers/acl/acl_provider_factory.cc | 16 +- .../acl/acl_provider_factory_creator.h | 3 +- onnxruntime/core/providers/acl/math/gemm.h | 33 +- onnxruntime/core/providers/acl/math/matmul.cc | 404 ++++++++++++ onnxruntime/core/providers/acl/math/matmul.h | 64 ++ .../core/providers/acl/nn/batch_norm.cc | 5 +- onnxruntime/core/providers/acl/nn/conv.cc | 605 +++++++++++------- onnxruntime/core/providers/acl/nn/conv.h | 57 +- .../core/providers/acl/nn/fused_conv.cc | 9 +- onnxruntime/core/providers/acl/nn/pool.cc | 21 +- onnxruntime/core/providers/acl/scheduler.cc | 44 ++ onnxruntime/core/providers/acl/scheduler.h | 33 + .../core/providers/acl/tensor/concat.cc | 9 +- .../python/onnxruntime_pybind_schema.cc | 3 +- .../python/onnxruntime_pybind_state.cc | 22 +- .../python/onnxruntime_pybind_state_common.h | 3 +- onnxruntime/test/onnx/main.cc | 3 +- .../test/perftest/command_args_parser.cc | 2 + onnxruntime/test/perftest/ort_test_session.cc | 38 +- onnxruntime/test/providers/cpu/model_tests.cc | 3 +- onnxruntime/test/util/default_providers.cc | 7 +- .../test/util/include/default_providers.h | 3 +- tools/ci_build/build.py | 10 +- 34 files changed, 1396 insertions(+), 426 deletions(-) create mode 100644 onnxruntime/core/providers/acl/math/matmul.cc create mode 100644 onnxruntime/core/providers/acl/math/matmul.h create mode 100644 onnxruntime/core/providers/acl/scheduler.cc create mode 100644 onnxruntime/core/providers/acl/scheduler.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index fb3b75fda4eaf..3d4f055bb6f53 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. # Minimum CMake required @@ -132,11 +133,6 @@ option(onnxruntime_USE_DML "Build with DirectML support" OFF) option(onnxruntime_USE_MIGRAPHX "Build with AMDMIGraphX support" OFF) option(onnxruntime_USE_WINML "Build with WinML support" OFF) option(onnxruntime_USE_ACL "Build with ACL support" OFF) -option(onnxruntime_USE_ACL_1902 "Build with ACL version 1902 support" OFF) -option(onnxruntime_USE_ACL_1905 "Build with ACL version 1905 support" OFF) -option(onnxruntime_USE_ACL_1908 "Build with ACL version 1908 support" OFF) -option(onnxruntime_USE_ACL_2002 "Build with ACL version 2002 support" OFF) -option(onnxruntime_USE_ACL_2308 "Build with ACL version 2308 support" OFF) option(onnxruntime_USE_ARMNN "Build with ArmNN support" OFF) option(onnxruntime_ARMNN_RELU_USE_CPU "Use the CPU implementation for the Relu operator for the ArmNN EP" ON) option(onnxruntime_ARMNN_BN_USE_CPU "Use the CPU implementation for the Batch Normalization operator for the ArmNN EP" ON) @@ -1207,25 +1203,8 @@ function(onnxruntime_add_include_to_target dst_target) endfunction() # ACL -if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 OR onnxruntime_USE_ACL_1908 OR onnxruntime_USE_ACL_2002 OR onnxruntime_USE_ACL_2308) +if (onnxruntime_USE_ACL) set(onnxruntime_USE_ACL ON) - if (onnxruntime_USE_ACL_1902) - add_definitions(-DACL_1902=1) - else() - if (onnxruntime_USE_ACL_1908) - add_definitions(-DACL_1908=1) - else() - if (onnxruntime_USE_ACL_2002) - add_definitions(-DACL_2002=1) - else() - if (onnxruntime_USE_ACL_2308) - add_definitions(-DACL_2308=1) - else() - add_definitions(-DACL_1905=1) - endif() - endif() - endif() - endif() if (NOT ${onnxruntime_ACL_LIBS} STREQUAL "") add_library(arm_compute SHARED IMPORTED) @@ -1233,18 +1212,13 @@ if (onnxruntime_USE_ACL OR onnxruntime_USE_ACL_1902 OR onnxruntime_USE_ACL_1905 IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") - add_library(arm_compute_core SHARED IMPORTED) - set_target_properties(arm_compute_core PROPERTIES - IMPORTED_NO_SONAME 1 - IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") - add_library(arm_compute_graph SHARED IMPORTED) set_target_properties(arm_compute_graph PROPERTIES IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_graph.so") endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_core arm_compute_graph) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES arm_compute arm_compute_graph) endif() @@ -1263,11 +1237,6 @@ if (onnxruntime_USE_ARMNN) IMPORTED_NO_SONAME 1 IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute.so") - add_library(arm_compute_core SHARED IMPORTED) - set_target_properties(arm_compute_core PROPERTIES - IMPORTED_NO_SONAME 1 - IMPORTED_LOCATION "${onnxruntime_ACL_LIBS}/libarm_compute_core.so") - add_library(arm_compute_graph SHARED IMPORTED) set_target_properties(arm_compute_graph PROPERTIES IMPORTED_NO_SONAME 1 @@ -1281,7 +1250,7 @@ if (onnxruntime_USE_ARMNN) IMPORTED_LOCATION "${onnxruntime_ARMNN_LIBS}/libarmnn.so") endif() - list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_core arm_compute_graph) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES armnn arm_compute arm_compute_graph) endif() if (onnxruntime_USE_DNNL) diff --git a/include/onnxruntime/core/providers/acl/acl_provider_factory.h b/include/onnxruntime/core/providers/acl/acl_provider_factory.h index 0dc0ec27ff345..8875a83a39f54 100644 --- a/include/onnxruntime/core/providers/acl/acl_provider_factory.h +++ b/include/onnxruntime/core/providers/acl/acl_provider_factory.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "onnxruntime_c_api.h" @@ -10,7 +11,8 @@ extern "C" { /** * \param use_arena zero: false. non-zero: true. */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, + bool enable_fast_math) ORT_ALL_ARGS_NONNULL; #ifdef __cplusplus diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8ab4a1cb26bb1..8fe73ff69e169 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,6 @@ /* * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ package ai.onnxruntime; @@ -1181,12 +1182,12 @@ public void addDirectML(int deviceId) throws OrtException { /** * Adds the ARM Compute Library as an execution backend. * - * @param useArena If true use the arena memory allocator. + * @param enableFastMath Enable fast math mode in ACL. * @throws OrtException If there was an error in native code. */ - public void addACL(boolean useArena) throws OrtException { + public void addACL(boolean enableFastMath) throws OrtException { checkClosed(); - addACL(OnnxRuntime.ortApiHandle, nativeHandle, useArena ? 1 : 0); + addACL(OnnxRuntime.ortApiHandle, nativeHandle, enableFastMath); } /** @@ -1354,7 +1355,8 @@ private native void addTvm(long apiHandle, long nativeHandle, String settings) private native void addDirectML(long apiHandle, long nativeHandle, int deviceId) throws OrtException; - private native void addACL(long apiHandle, long nativeHandle, int useArena) throws OrtException; + private native void addACL(long apiHandle, long nativeHandle, boolean enableFastMath) + throws OrtException; private native void addArmNN(long apiHandle, long nativeHandle, int useArena) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 337f4c1921c6e..ff9348c299e90 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -1,5 +1,6 @@ /* * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ #include @@ -644,12 +645,13 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDir * Signature: (JJI)V */ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addACL - (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint useArena) { + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jboolean enableFastMath) { (void)jobj; #ifdef USE_ACL - checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle,useArena)); + checkOrtStatus(jniEnv,(const OrtApi*)apiHandle, + OrtSessionOptionsAppendExecutionProvider_ACL((OrtSessionOptions*) handle, enableFastMath)); #else - (void)apiHandle;(void)handle;(void)useArena; // Parameters used when ACL is defined. + (void)apiHandle;(void)handle;(void)enableFastMath; // Parameters used when ACL is defined. throwOrtException(jniEnv,convertErrorCode(ORT_INVALID_ARGUMENT),"This binary was not compiled with ACL support."); #endif } diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 0530ab771e0be..997d99441d36d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/optimizer/graph_transformer_utils.h" @@ -196,6 +197,8 @@ InlinedVector> GenerateTransformers( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsDisableQuantQDQ, "0") == "1"; #ifndef DISABLE_CONTRIB_OPS const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; + const InlinedHashSet cpu_acl_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; AllocatorPtr cpu_allocator = std::make_shared(); @@ -285,6 +288,11 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_acl_cuda_dml_rocm_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kDmlExecutionProvider}; const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kAclExecutionProvider, @@ -296,8 +304,9 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider}; - const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kDmlExecutionProvider}; + const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kDmlExecutionProvider, + onnxruntime::kAclExecutionProvider}; const int64_t qdq_matmulnbits_accuracy_level = ParseStringWithClassicLocale( session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsQDQMatMulNBitsAccuracyLevel, @@ -323,26 +332,26 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_dml_eps)); - transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_dml_acl_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_eps)); transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); // GeluApproximation has side effects which may change results. It needs to be manually enabled, // or alternatively the model can be updated offline using a model conversion script @@ -367,7 +376,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); #endif - transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(dml_ep)); #ifdef MLAS_TARGET_AMD64_IX86 diff --git a/onnxruntime/core/optimizer/nhwc_transformer.cc b/onnxruntime/core/optimizer/nhwc_transformer.cc index e67557dcf9391..ee79fa620374e 100644 --- a/onnxruntime/core/optimizer/nhwc_transformer.cc +++ b/onnxruntime/core/optimizer/nhwc_transformer.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -183,7 +184,8 @@ Status NhwcTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, modified = false; for (std::unique_ptr& node : api_graph->Nodes()) { // If the node is not supported in the CPU EP, skip it - if (node->GetExecutionProviderType() != kCpuExecutionProvider) { + const auto ep = node->GetExecutionProviderType(); + if ((ep != kCpuExecutionProvider) && (ep != kAclExecutionProvider)) { continue; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc index adfa680878945..1c506bafd1d14 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -381,9 +382,9 @@ QDQSelectorActionTransformer::QDQSelectorActionTransformer( CreateSelectorActionRegistry(is_int8_allowed, qdq_matmulnbits_accuracy_level, intra_op_thread_pool, p_buffered_tensors), apply_context, - // this transformer is compatible with CPU, DML and CUDA EP. + // this transformer is compatible with CPU, DML, ACL and CUDA EP. // There is further EP control on the rule level. - {kCpuExecutionProvider, kDmlExecutionProvider, kCudaExecutionProvider}} { + {kCpuExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, kCudaExecutionProvider}} { } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_common.cc b/onnxruntime/core/providers/acl/acl_common.cc index f1ab6682a8259..c8d878a81bd1a 100644 --- a/onnxruntime/core/providers/acl/acl_common.cc +++ b/onnxruntime/core/providers/acl/acl_common.cc @@ -1,5 +1,6 @@ // Copyright(C) 2018 Intel Corporation // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License #ifdef _WIN32 @@ -8,14 +9,45 @@ #include "core/providers/acl/acl_common.h" -#include "arm_compute/runtime/PoolManager.h" -#include "arm_compute/runtime/BlobLifetimeManager.h" - -#undef ACL_1902 - namespace onnxruntime { namespace acl { +void PopulateWorkspace(const arm_compute::experimental::MemoryRequirements& reqs, + Workspace& workspace, arm_compute::MemoryGroup& memory_group, + arm_compute::ITensorPack& run_pack, arm_compute::ITensorPack& prep_pack) { + for (const arm_compute::experimental::MemoryInfo& req : reqs) { + if (req.size == 0) { + continue; + } + + arm_compute::Tensor* aux_tensor; + if (req.lifetime == arm_compute::experimental::MemoryLifetime::Temporary) { + workspace.temporary_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.temporary_tensors.back().get(); + + memory_group.manage(aux_tensor); + } else if (req.lifetime == arm_compute::experimental::MemoryLifetime::Prepare) { + workspace.prepare_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.prepare_tensors.back().get(); + + prep_pack.add_tensor(req.slot, aux_tensor); + } else { + workspace.persistent_tensors.emplace_back(std::make_unique()); + aux_tensor = workspace.persistent_tensors.back().get(); + + prep_pack.add_tensor(req.slot, aux_tensor); + } + run_pack.add_tensor(req.slot, aux_tensor); + + const auto aux_info = arm_compute::TensorInfo{arm_compute::TensorShape(req.size), 1, arm_compute::DataType::U8}; + aux_tensor->allocator()->init(aux_info, req.alignment); + } + + for (const std::unique_ptr& tensor : workspace.temporary_tensors) { + tensor->allocator()->allocate(); + } +} + arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned int extDim) { arm_compute::TensorShape shape; unsigned int inDim = tensorShape.NumDimensions(); @@ -36,27 +68,112 @@ arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned return shape; } +Status GetArgShape(const NodeArg* tensor, TensorShape& outShape) { + const auto& inShape = tensor->Shape(); + TensorShapeVector shapeVec; + + for (int i = 0; i < inShape->dim_size(); i++) { + const auto& dim = inShape->dim(i); + ORT_RETURN_IF_NOT(dim.has_dim_value(), "ACL does not support unknown tensor shapes: ", tensor->Name()); + shapeVec.push_back(dim.dim_value()); + } + + outShape = TensorShape(shapeVec); + return Status::OK(); +} + void ACLPrintTensorShape(const char* s, arm_compute::Tensor& t) { for (unsigned int i = 0; i < t.info()->tensor_shape().num_dimensions(); i++) LOGS_DEFAULT(VERBOSE) << "ACL " << s << " " << t.info()->tensor_shape()[i]; LOGS_DEFAULT(VERBOSE) << std::endl; } -std::shared_ptr ACLCreateMemoryManager() { - auto lifetime_mgr = std::make_shared(); - auto pool_mgr = std::make_shared(); - auto mm = std::make_shared(lifetime_mgr, pool_mgr); +arm_compute::DataType ACLDataType(const std::string& dtype) { + if (dtype == "tensor(float)") { + return arm_compute::DataType::F32; + } + if (dtype == "tensor(float16)") { + return arm_compute::DataType::F16; + } + if (dtype == "tensor(bfloat16)") { + return arm_compute::DataType::BFLOAT16; + } + if (dtype == "tensor(uint8)") { + return arm_compute::DataType::QASYMM8; + } + if (dtype == "tensor(int8)") { + return arm_compute::DataType::QASYMM8_SIGNED; + } + if (dtype == "tensor(int32)") { + return arm_compute::DataType::S32; + } + ORT_THROW("ACL execution provider does not support data type ", dtype); +} - return mm; +int GetIntScalar(const Tensor* tensor) { + ORT_ENFORCE(tensor->Shape().Size() == 1, "Tensor is not a scalar"); + if (tensor->IsDataType()) { + return *tensor->Data(); + } + if (tensor->IsDataType()) { + return *tensor->Data(); + } + ORT_THROW("Unsupported int type: ", DataTypeImpl::ToString(tensor->DataType())); } -arm_compute::Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size) { -#ifdef ACL_1902 - return allocator->import_memory(memory, size); -#else +Status LoadQuantizationInfo(const OpKernelInfo& info, arm_compute::Tensor* tensor, + const int scaleIdx, const int zpIdx, bool flipZeroPoint) { + const Tensor* scaleTensor = nullptr; + ORT_RETURN_IF_NOT(info.TryGetConstantInput(scaleIdx, &scaleTensor), "Scale must be constant"); + + const Tensor* zeroPointTensor = nullptr; + ORT_RETURN_IF_NOT(info.TryGetConstantInput(zpIdx, &zeroPointTensor), "Zero point must be constant"); + + const float* scale = scaleTensor->Data(); + const int zeroPoint = GetIntScalar(zeroPointTensor); + tensor->info()->set_quantization_info(arm_compute::QuantizationInfo(*scale, flipZeroPoint ? -zeroPoint : zeroPoint)); + + return Status::OK(); +} + +void GetPackingInfo(std::vector>& state, size_t& packedSize, size_t& alignment) { + alignment = 0; + for (auto& tensor : state) { + alignment = std::max(alignment, tensor->allocator()->alignment()); + } + + packedSize = 0; + for (auto& tensor : state) { + const size_t size = tensor->info()->total_size(); + packedSize += ((size - 1) / alignment + 1) * alignment; + } +} + +Status LoadPackedTensors(std::vector>& state, void* packed, + const size_t packedSize, const size_t alignment) { + auto buffSize = packedSize + alignment; + uint8_t* alignedPtr = (uint8_t*)(alignment == 0 ? packed : std::align(alignment, packedSize, packed, buffSize)); + + uint8_t* currentPtr = alignedPtr; + for (auto& tensor : state) { + ORT_RETURN_IF_ERROR(ACLImportMemory(tensor->allocator(), currentPtr, 0)); + + const size_t size = tensor->info()->total_size(); + currentPtr += ((size - 1) / alignment + 1) * alignment; + } + + return Status::OK(); +} + +Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size) { ORT_UNUSED_PARAMETER(size); - return allocator->import_memory(memory); -#endif + arm_compute::Status status = allocator->import_memory(memory); + + if (status) { + return Status::OK(); + } else { + return Status(common::ONNXRUNTIME, common::FAIL, status.error_description()); + } } template @@ -71,12 +188,13 @@ void importDataToTensor(arm_compute::Tensor* tensor, const T* data) { arm_compute::execute_window_loop( aclInpuWindow, [&](const arm_compute::Coordinates& co) { - *reinterpret_cast(aclInputIt.ptr()) = data[index]; + *reinterpret_cast(aclInputIt.ptr()) = data[index]; index++; }, aclInputIt); } template void importDataToTensor(arm_compute::Tensor*, const float*); +template void importDataToTensor(arm_compute::Tensor*, const MLFloat16*); template void importDataFromTensor(arm_compute::Tensor* tensor, T* data) { @@ -89,12 +207,13 @@ void importDataFromTensor(arm_compute::Tensor* tensor, T* data) { arm_compute::execute_window_loop( aclInpuWindow, [&](const arm_compute::Coordinates& co) { - data[index] = *reinterpret_cast(aclInputIt.ptr()); + data[index] = *reinterpret_cast(aclInputIt.ptr()); index++; }, aclInputIt); } template void importDataFromTensor(arm_compute::Tensor*, float*); +template void importDataFromTensor(arm_compute::Tensor*, MLFloat16*); } // namespace acl } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_common.h b/onnxruntime/core/providers/acl/acl_common.h index 899736c477165..f2e89de15efd9 100644 --- a/onnxruntime/core/providers/acl/acl_common.h +++ b/onnxruntime/core/providers/acl/acl_common.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -7,17 +8,37 @@ #include "core/framework/op_kernel.h" // ACL +#include "arm_compute/core/experimental/Types.h" #include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" -#include "arm_compute/runtime/MemoryManagerOnDemand.h" namespace onnxruntime { namespace acl { +struct Workspace { + std::vector> temporary_tensors; + std::vector> prepare_tensors; + std::vector> persistent_tensors; +}; + +void PopulateWorkspace(const arm_compute::experimental::MemoryRequirements& reqs, + Workspace& workspace, arm_compute::MemoryGroup& memory_group, + arm_compute::ITensorPack& run_pack, arm_compute::ITensorPack& prep_pack); + arm_compute::TensorShape ACLTensorShape(const TensorShape& tensorShape, unsigned int extDim = 0); +Status GetArgShape(const NodeArg* tensor, TensorShape& outShape); void ACLPrintTensorShape(const char*, arm_compute::Tensor& t); -std::shared_ptr ACLCreateMemoryManager(); -arm_compute::Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size); +arm_compute::DataType ACLDataType(const std::string& dtype); + +int GetIntScalar(const Tensor* tensor); +Status LoadQuantizationInfo(const OpKernelInfo& info, arm_compute::Tensor* tensor, + const int scaleIdx, const int zpIdx, bool flipZeroPoint); + +void GetPackingInfo(std::vector>& state, size_t& packedSize, size_t& alignment); +Status LoadPackedTensors(std::vector>& state, void* packed, + const size_t packedSize, const size_t alignment); + +Status ACLImportMemory(arm_compute::TensorAllocator* allocator, void* memory, size_t size); template void importDataToTensor(arm_compute::Tensor* tensor, const T* data); template diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.cc b/onnxruntime/core/providers/acl/acl_execution_provider.cc index d19dc15e17f6d..8d34e36fe7cd6 100644 --- a/onnxruntime/core/providers/acl/acl_execution_provider.cc +++ b/onnxruntime/core/providers/acl/acl_execution_provider.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "acl_execution_provider.h" @@ -7,13 +8,19 @@ #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/framework/compute_capability.h" +#include "core/providers/acl/math/matmul.h" +#include "core/providers/acl/nn/conv.h" +#include "core/session/inference_session.h" #include "contrib_ops/cpu/cpu_contrib_kernels.h" #include "acl_fwd.h" +#include "scheduler.h" -namespace onnxruntime { +#include "arm_compute/runtime/Scheduler.h" +#include "arm_compute/runtime/PoolManager.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/Allocator.h" -constexpr const char* ACL = "Acl"; -constexpr const char* ACL_CPU = "AclCpu"; +namespace onnxruntime { namespace acl { @@ -22,7 +29,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 6, Rel class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Gemm); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 1, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, MLFloat16, Conv); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 8, 11, float, MaxPool); @@ -39,6 +47,22 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDoma class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 11, Concat); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, float, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, NhwcConv); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 13, MatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 13, MLFloat16, MatMul); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, FusedMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, uint8_t, MatMulIntegerToFloat); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, int8_t, MatMulIntegerToFloat); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kOnnxDomain, 10, int8_t, QLinearConv); + +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, uint8_t, QLinearConv); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kAclExecutionProvider, kMSDomain, 1, int8_t, QLinearConv); Status RegisterACLKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -48,7 +72,8 @@ Status RegisterACLKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -67,6 +92,22 @@ Status RegisterACLKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -85,10 +126,22 @@ std::shared_ptr GetAclKernelRegistry() { return kernel_registry; } +std::shared_ptr ACLCreateMemoryManager() { + auto lifetime_mgr = std::make_shared(); + auto pool_mgr = std::make_shared(); + auto mm = std::make_shared(lifetime_mgr, pool_mgr); + + return mm; +} + } // namespace acl -ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo&) - : IExecutionProvider{onnxruntime::kAclExecutionProvider} {} +ACLExecutionProvider::ACLExecutionProvider(const ACLExecutionProviderInfo& info) + : IExecutionProvider{onnxruntime::kAclExecutionProvider}, + info(info), + memory_manager(onnxruntime::acl::ACLCreateMemoryManager()) { + arm_compute::Scheduler::set(std::make_shared(this)); +} ACLExecutionProvider::~ACLExecutionProvider() {} @@ -97,4 +150,47 @@ std::shared_ptr ACLExecutionProvider::GetKernelRegistry() const return kernel_registry; } +std::vector> +ACLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const { + std::vector> result; + for (const auto& node : graph.Nodes()) { + if (const KernelCreateInfo* kernel_create_info = kernel_lookup.LookUpKernel(node); + kernel_create_info != nullptr) { + Status support_status = Status::OK(); + const std::string op_name = kernel_create_info->kernel_def->OpName(); + + if (op_name == "Conv" || op_name == "NhwcConv" || op_name == "QLinearConv") { + support_status = onnxruntime::acl::ValidateConv(node); + } + if (op_name == "MatMul" || op_name == "FusedMatMul" || op_name == "MatMulIntegerToFloat") { + support_status = onnxruntime::acl::ValidateMatMul(node); + } + + if (support_status.IsOK()) { + std::unique_ptr sub_graph = std::make_unique(); + sub_graph->nodes.push_back(node.Index()); + result.push_back(std::make_unique(std::move(sub_graph))); + } else { + LOGS_DEFAULT(WARNING) << "ACL supports operator " << op_name + << ", but not with these parameters. Using fallback for node: " << node.Name() + << " Reason: " << support_status.ErrorMessage(); + } + } + } + + return result; +} + +Status ACLExecutionProvider::OnRunStart(const onnxruntime::RunOptions&) { + arm_compute::Allocator alloc{}; + memory_manager->populate(alloc, 1); + return Status::OK(); +}; + +Status ACLExecutionProvider::OnRunEnd(bool, const onnxruntime::RunOptions&) { + memory_manager->clear(); + return Status::OK(); +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_execution_provider.h b/onnxruntime/core/providers/acl/acl_execution_provider.h index 126656e0956bb..1c267d8713673 100755 --- a/onnxruntime/core/providers/acl/acl_execution_provider.h +++ b/onnxruntime/core/providers/acl/acl_execution_provider.h @@ -1,20 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/framework/execution_provider.h" #include "core/graph/constants.h" +#include "core/platform/threadpool.h" + +#include "arm_compute/runtime/MemoryManagerOnDemand.h" namespace onnxruntime { // Information needed to construct ACL execution providers. struct ACLExecutionProviderInfo { - bool create_arena{true}; + bool enable_fast_math{false}; - explicit ACLExecutionProviderInfo(bool use_arena) - : create_arena(use_arena) {} + explicit ACLExecutionProviderInfo(bool enable_fast_math) + : enable_fast_math(enable_fast_math) {} ACLExecutionProviderInfo() = default; }; @@ -31,6 +35,26 @@ class ACLExecutionProvider : public IExecutionProvider { } std::shared_ptr GetKernelRegistry() const override; + + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph, + const IKernelLookup& kernel_lookup) const override; + + Status OnRunStart(const onnxruntime::RunOptions&) override; + + Status OnRunEnd(bool, const onnxruntime::RunOptions&) override; + + void SetThreadPool(concurrency::ThreadPool* thread_pool) { + thread_pool_ = thread_pool; + } + + concurrency::ThreadPool* GetThreadPool() const { + return thread_pool_; + } + + const ACLExecutionProviderInfo info; + const std::shared_ptr memory_manager; + concurrency::ThreadPool* thread_pool_ = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/acl_provider_factory.cc b/onnxruntime/core/providers/acl/acl_provider_factory.cc index 4eb11b222e576..26a41afeeee36 100755 --- a/onnxruntime/core/providers/acl/acl_provider_factory.cc +++ b/onnxruntime/core/providers/acl/acl_provider_factory.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/providers/acl/acl_provider_factory.h" @@ -11,27 +12,28 @@ namespace onnxruntime { struct ACLProviderFactory : IExecutionProviderFactory { - ACLProviderFactory(bool create_arena) : create_arena_(create_arena) {} + ACLProviderFactory(bool enable_fast_math) : enable_fast_math_(enable_fast_math) {} ~ACLProviderFactory() override {} std::unique_ptr CreateProvider() override; private: - bool create_arena_; + bool enable_fast_math_; }; std::unique_ptr ACLProviderFactory::CreateProvider() { ACLExecutionProviderInfo info; - info.create_arena = create_arena_; + info.enable_fast_math = enable_fast_math_; return std::make_unique(info); } -std::shared_ptr ACLProviderFactoryCreator::Create(int use_arena) { - return std::make_shared(use_arena != 0); +std::shared_ptr ACLProviderFactoryCreator::Create(bool enable_fast_math) { + return std::make_shared(enable_fast_math); } } // namespace onnxruntime -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, int use_arena) { - options->provider_factories.push_back(onnxruntime::ACLProviderFactoryCreator::Create(use_arena)); +ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_ACL, _In_ OrtSessionOptions* options, + bool enable_fast_math) { + options->provider_factories.push_back(onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math)); return nullptr; } diff --git a/onnxruntime/core/providers/acl/acl_provider_factory_creator.h b/onnxruntime/core/providers/acl/acl_provider_factory_creator.h index 2eee50ee710da..31a596f2d4bbc 100644 --- a/onnxruntime/core/providers/acl/acl_provider_factory_creator.h +++ b/onnxruntime/core/providers/acl/acl_provider_factory_creator.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -10,7 +11,7 @@ namespace onnxruntime { struct ACLProviderFactoryCreator { - static std::shared_ptr Create(int use_arena); + static std::shared_ptr Create(bool enable_fast_math); }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/math/gemm.h b/onnxruntime/core/providers/acl/math/gemm.h index f5288d7f231b0..5db2372705184 100644 --- a/onnxruntime/core/providers/acl/math/gemm.h +++ b/onnxruntime/core/providers/acl/math/gemm.h @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -28,7 +29,6 @@ namespace acl { typedef struct { std::shared_ptr layer; std::shared_ptr a, b, c, d; - std::shared_ptr mm_layer; } ACLNEGEMM; typedef std::map::iterator GEMMLayersIterator; @@ -37,6 +37,9 @@ template class Gemm : public onnxruntime::Gemm { public: Gemm(const OpKernelInfo& info) : onnxruntime::Gemm(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + int64_t temp; ORT_ENFORCE(info.GetAttr("transA", &temp).IsOK()); @@ -49,12 +52,11 @@ class Gemm : public onnxruntime::Gemm { } Status Compute(OpKernelContext* context) const override { -#ifdef ACL_2308 if (this->packed_b_) { // Prepacked RHS not supported, defaulting to cpu execution provider return onnxruntime::Gemm::Compute(context); } -#endif + const auto A = context->Input(0); const auto B = context->Input(1); const auto C = context->Input(2); @@ -96,19 +98,20 @@ class Gemm : public onnxruntime::Gemm { (cShape[1] == 1 && cShape[0] != (long unsigned int)N)) { return onnxruntime::Gemm::Compute(context); } -#ifdef ACL_2308 cShape = arm_compute::TensorShape(N); LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {" << N << "}"; -#else - cShape = arm_compute::TensorShape(1, N); - LOGS_DEFAULT(VERBOSE) << "Bias reshaped to: {1," << N << "}"; -#endif } int64_t K = helper.K(); - if (A) LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str(); - if (B) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); - if (C) LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str(); + if (A) { + LOGS_DEFAULT(VERBOSE) << "A " << A->Shape().ToString().c_str(); + } + if (B) { + LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); + } + if (C) { + LOGS_DEFAULT(VERBOSE) << "C " << C->Shape().ToString().c_str(); + } LOGS_DEFAULT(VERBOSE) << "D " << D->Shape().ToString().c_str(); LOGS_DEFAULT(VERBOSE) << "M " << (int)M << ", N " << (int)N << ", K " << (int)K; LOGS_DEFAULT(VERBOSE) << "Alfa " << alpha_ << ", Beta " << beta_; @@ -131,10 +134,8 @@ class Gemm : public onnxruntime::Gemm { // dimensions are stored in the opposite order to ACL's tGEMM.d->allocator()->init(arm_compute::TensorInfo(arm_compute::TensorShape(N, M), arm_compute::Format::F32)); - tGEMM.mm_layer = ACLCreateMemoryManager(); - if (FC) { - auto layer = std::make_shared(tGEMM.mm_layer); + auto layer = std::make_shared(provider_->memory_manager); arm_compute::FullyConnectedLayerInfo fc_info; fc_info.transpose_weights = trans_B_ == CblasTrans; layer->configure(tGEMM.a.get(), tGEMM.b.get(), useC ? tGEMM.c.get() : nullptr, tGEMM.d.get(), fc_info); @@ -173,10 +174,7 @@ class Gemm : public onnxruntime::Gemm { ACLPrintTensorShape("c", *pGEMM->c); ACLPrintTensorShape("d", *pGEMM->d); - arm_compute::Allocator alloc_mm{}; - pGEMM->mm_layer->populate(alloc_mm, 1); pGEMM->layer->run(); - pGEMM->mm_layer->clear(); if (D->Shape().Size() != 0 && pGEMM->d->info()->has_padding()) { importDataFromTensor(pGEMM->d.get(), d_data); @@ -195,6 +193,7 @@ class Gemm : public onnxruntime::Gemm { } private: + ACLExecutionProvider* provider_; static thread_local std::map gemmLayers; CBLAS_TRANSPOSE trans_A_; diff --git a/onnxruntime/core/providers/acl/math/matmul.cc b/onnxruntime/core/providers/acl/math/matmul.cc new file mode 100644 index 0000000000000..468b394471c13 --- /dev/null +++ b/onnxruntime/core/providers/acl/math/matmul.cc @@ -0,0 +1,404 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include +#include +#include "core/common/status.h" +#include "core/framework/op_kernel_info.h" +#include "core/framework/op_node_proto_helper.h" +#include "core/framework/tensor_shape.h" +#ifdef _WIN32 +#pragma warning(disable : 4244) +#endif +#include +#include + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/util/math.h" +#include "core/util/math_cpuonly.h" + +#include "core/providers/acl/math/matmul.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_fwd.h" + +// ACL +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/NEON/functions/NEMatMul.h" +#include "src/cpu/operators/CpuGemm.h" +#include "src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h" +#include "src/cpu/operators/CpuMatMul.h" + +namespace onnxruntime { + +namespace acl { + +TensorShape BroadcastInput(const TensorShape& shape, bool prependDim) { + const auto nd = shape.NumDimensions(); + if (nd == 0) { + ORT_THROW("MatMul by scalar not allowed"); + } + + int64_t batchSize = 1; + if (nd == 1) { + if (prependDim) { + return {1, 1, shape[0]}; + } else { + return {1, shape[0], 1}; + } + } + + for (size_t i = 0; i < nd - 2; i++) { + batchSize *= shape[i]; + } + + return {batchSize, shape[nd - 2], shape[nd - 1]}; +} + +struct MatMulConfig { + bool isQuantized; + float alpha; + bool transA; + bool transB; + TensorShape aShapeBroadcast; + TensorShape bShapeBroadcast; +}; + +Status ParseMatMul(const onnxruntime::Node& node, MatMulConfig& config) { + onnxruntime::ProtoHelperNodeContext ctx(node); + onnxruntime::OpNodeProtoHelper attrs(&ctx); + const auto inputDefs = node.InputDefs(); + + config.isQuantized = node.OpType() == "MatMulIntegerToFloat"; + + config.alpha = 1; + attrs.GetAttr("alpha", &config.alpha); + + int64_t transA = 0; + attrs.GetAttr("transA", &transA); + int64_t transB = 0; + attrs.GetAttr("transB", &transB); + + config.transA = transA; + config.transB = transB; + + const int64_t transBatchA = attrs.GetAttrOrDefault("transBatchA", 0); + const int64_t transBatchB = attrs.GetAttrOrDefault("transBatchB", 0); + + ORT_RETURN_IF(transBatchA, "transBatchA not supported by ACL"); + ORT_RETURN_IF(transBatchB, "transBatchB not supported by ACL"); + + ORT_RETURN_IF(config.isQuantized && inputDefs.size() >= 7, "ACL MatMulIntegerToFloat does not support bias"); + + TensorShape aShapeIn; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[0], aShapeIn)); + + TensorShape bShapeIn; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[1], bShapeIn)); + + config.aShapeBroadcast = BroadcastInput(aShapeIn, !config.transA); + config.bShapeBroadcast = BroadcastInput(bShapeIn, config.transB); + + ORT_RETURN_IF(!(config.bShapeBroadcast[0] == 1 || (config.aShapeBroadcast[0] == config.bShapeBroadcast[0])), + "ACL does not support broadcasting"); + + ORT_RETURN_IF(config.alpha != 1 && config.bShapeBroadcast[0] > 1, + "ACL does not support alpha scaling with batched B"); + + return Status::OK(); +} + +Status ValidateMatMul(const onnxruntime::Node& node) { + MatMulConfig config; + return ParseMatMul(node, config); +} + +MatMul::MatMul(const OpKernelInfo& info) : onnxruntime::OpKernel(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + + const auto inputDefs = OpKernel::Node().InputDefs(); + const auto outputDefs = OpKernel::Node().OutputDefs(); + + const Tensor* tmp = nullptr; + const bool aIsConst = info.TryGetConstantInput(0, &tmp); + const bool bIsConst = info.TryGetConstantInput(1, &tmp); + + MatMulConfig config; + ORT_THROW_IF_ERROR(ParseMatMul(OpKernel::Node(), config)); + + ORT_THROW_IF_ERROR(GetArgShape(outputDefs[0], outShape)); + if (outShape.Size() == 0) { + return; + } + + const TensorShape aShape{ + config.aShapeBroadcast[0], + config.aShapeBroadcast[config.transA ? 2 : 1], + config.aShapeBroadcast[config.transA ? 1 : 2]}; + + const TensorShape bShape{ + config.bShapeBroadcast[0], + config.bShapeBroadcast[config.transB ? 2 : 1], + config.bShapeBroadcast[config.transB ? 1 : 2]}; + + const TensorShape outShapeBroadcast{aShape[0], aShape[1], bShape[2]}; + + ORT_ENFORCE(outShape.Size() == outShapeBroadcast.Size(), "Output sizes do not match"); + + arm_compute::DataType aType = ACLDataType(*inputDefs[0]->Type()); + arm_compute::DataType bType = ACLDataType(*inputDefs[1]->Type()); + arm_compute::DataType outType = ACLDataType(*outputDefs[0]->Type()); + + arm_compute::GEMMInfo gemmInfo(false, false, bIsConst); + gemmInfo.set_fast_math(provider_->info.enable_fast_math); + + a = std::make_shared(); + b = std::make_shared(); + out = std::make_shared(); + + a->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(config.aShapeBroadcast), 1, aType)); + b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(config.bShapeBroadcast), 1, bType)); + out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(outShapeBroadcast), 1, outType)); + + if (config.isQuantized) { + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, a.get(), 2, 4, true)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, b.get(), 3, 5, true)); + } + + arm_compute::ITensor* a_to_use = a.get(); + if (config.transA) { + a_transposed = std::make_shared(); + a_transposed->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(aShape), 1, aType)); + a_to_use = a_transposed.get(); + + a_permute = std::make_shared(); + a_permute->configure(a.get(), a_transposed.get(), {1, 0, 2}); + } + + arm_compute::ITensor* b_to_use = b.get(); + if (config.transB) { + if (bIsConst) { + workspace.persistent_tensors.emplace_back(std::make_unique()); + b_transposed = workspace.persistent_tensors.back().get(); + } else { + workspace.temporary_tensors.emplace_back(std::make_unique()); + b_transposed = workspace.temporary_tensors.back().get(); + } + + b_transposed->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(bShape), 1, bType), 128); + b_to_use = b_transposed; + + b_permute = std::make_shared(); + b_permute->configure(b.get(), b_transposed, {1, 0, 2}); + } + + a_to_use->info()->set_are_values_constant(aIsConst); + b_to_use->info()->set_are_values_constant(bIsConst); + + if (config.bShapeBroadcast[0] > 1) { + arm_compute::CpuMatMulSettings settings; + settings.fast_math(provider_->info.enable_fast_math); + + a_to_use->info()->set_are_values_constant(false); + b_to_use->info()->set_are_values_constant(false); + + const auto matmul = std::make_shared(); + matmul->configure(a_to_use->info(), b_to_use->info(), out->info(), {}, settings, {}); + layer = std::move(matmul); + } else if (config.isQuantized) { + const auto gemm = std::make_shared(); + gemm->configure(a_to_use->info(), b_to_use->info(), nullptr, out->info(), gemmInfo); + layer = std::move(gemm); + } else { + const auto gemm = std::make_shared(); + gemm->configure(a_to_use->info(), b_to_use->info(), nullptr, out->info(), config.alpha, 0.f, gemmInfo); + layer = std::move(gemm); + } + + memory_group = arm_compute::MemoryGroup(provider_->memory_manager); + run_pack = {{arm_compute::ACL_SRC_0, a_to_use}, {arm_compute::ACL_SRC_1, b_to_use}, {arm_compute::ACL_DST, out.get()}}; + prep_pack = {{arm_compute::ACL_SRC_1, b_to_use}}; + + PopulateWorkspace(layer->workspace(), workspace, memory_group, run_pack, prep_pack); +} + +Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (input_idx != 1 || outShape.Size() == 0) { + return Status::OK(); + } + + const uint8_t* data = (uint8_t*)tensor.DataRaw(); + + ORT_RETURN_IF_ERROR(ACLImportMemory(b->allocator(), (void*)data, 0)); + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + auto buffSize = packedSize + alignment; + + pbRaw = IAllocator::MakeUniquePtr(alloc, buffSize, true); + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, pbRaw.get(), packedSize, alignment)); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(pbRaw)); + prepacked_weights->buffer_sizes_.push_back(buffSize); + } + + is_packed = true; + } + + if (b_transposed) { + b_permute->run(); + } + + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->allocate(); + } + + layer->prepare(prep_pack); + + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->free(); + } + + return Status::OK(); +} + +Status MatMul::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (input_idx != 1) { + return Status::OK(); + } + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, prepacked_buffers[0].get(), packedSize, alignment)); + + used_shared_buffers = true; + } + + return Status::OK(); +} + +Status MatMul::Compute(OpKernelContext* context) const { + provider_->SetThreadPool(context->GetOperatorThreadPool()); + + const Tensor* A = context->Input(0); + const Tensor* B = pbRaw ? nullptr : context->Input(1); + + Tensor* outOrt = context->Output(0, outShape); + + if (outShape.Size() == 0) { + return Status::OK(); + } + + const void* a_data = A->DataRaw(); + const void* b_data = B == nullptr ? nullptr : B->DataRaw(); + void* out_data = outOrt->MutableDataRaw(); + + ORT_RETURN_IF(A->Shape().Size() != 0 && a->info()->has_padding(), "Padded ACL input tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(a->allocator(), (void*)a_data, 0)); + + if (b_data != nullptr) { + ORT_RETURN_IF_ERROR(ACLImportMemory(b->allocator(), (void*)b_data, 0)); + } + + ORT_RETURN_IF(outOrt->Shape().Size() != 0 && out->info()->has_padding(), "Padded ACL output tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(out->allocator(), (void*)out_data, 0)); + + ORT_RETURN_IF(B != nullptr && workspace.persistent_tensors.size(), "Persistent state requires pre-packing"); + + if (a_transposed) { + a_transposed->allocator()->allocate(); + a_permute->run(); + } + + { + arm_compute::MemoryGroupResourceScope scope_mg(const_cast(memory_group)); + if (b_transposed && B) { + b_permute->run(); + } + + layer->run(const_cast(run_pack)); + } + + a->allocator()->free(); + if (B != nullptr) + b->allocator()->free(); + out->allocator()->free(); + + if (a_transposed) { + a_transposed->allocator()->free(); + } + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMul, + kOnnxDomain, + 13, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMul, + kOnnxDomain, + 13, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_KERNEL_EX( + FusedMatMul, + kMSDomain, + 1, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + FusedMatMul, + kMSDomain, + 1, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulIntegerToFloat, + kMSDomain, + 1, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulIntegerToFloat, + kMSDomain, + 1, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + MatMul); + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/math/matmul.h b/onnxruntime/core/providers/acl/math/matmul.h new file mode 100644 index 0000000000000..b137e33833de9 --- /dev/null +++ b/onnxruntime/core/providers/acl/math/matmul.h @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#pragma once +#include "core/framework/op_kernel.h" +#include "core/providers/acl/acl_common.h" +#include "core/providers/acl/acl_execution_provider.h" + +// ACL +#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/Tensor.h" +#include "arm_compute/runtime/TensorAllocator.h" +#include "arm_compute/runtime/Allocator.h" +#include "arm_compute/runtime/PoolManager.h" +#include "arm_compute/runtime/BlobLifetimeManager.h" +#include "arm_compute/runtime/MemoryManagerOnDemand.h" + +// NEON +#include "arm_compute/runtime/NEON/functions/NEGEMM.h" +#include "arm_compute/runtime/NEON/functions/NEPermute.h" + +namespace onnxruntime { +namespace acl { + +Status ValidateMatMul(const onnxruntime::Node& node); + +class MatMul : public OpKernel { + public: + explicit MatMul(const OpKernelInfo& info); + + Status PrePack(const Tensor&, int, AllocatorPtr, + bool& is_packed, PrePackedWeights*) override; + + Status UseSharedPrePackedBuffers(std::vector&, + int, bool&) override; + + Status Compute(OpKernelContext* context) const override; + + protected: + ACLExecutionProvider* provider_; + std::shared_ptr a_permute; + std::shared_ptr b_permute; + std::shared_ptr layer; + + arm_compute::MemoryGroup memory_group; + arm_compute::ITensorPack run_pack; + arm_compute::ITensorPack prep_pack; + + Workspace workspace; + + std::shared_ptr a; + std::shared_ptr b; + std::shared_ptr a_transposed; + arm_compute::Tensor* b_transposed = nullptr; + std::shared_ptr out; + arm_compute::Tensor* pb; + + IAllocatorUniquePtr pbRaw; + TensorShape outShape; +}; +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/nn/batch_norm.cc b/onnxruntime/core/providers/acl/nn/batch_norm.cc index be0e57c5c0543..192bc34556eef 100755 --- a/onnxruntime/core/providers/acl/nn/batch_norm.cc +++ b/onnxruntime/core/providers/acl/nn/batch_norm.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/common/common.h" @@ -80,7 +81,6 @@ Status BatchNorm::Compute(OpKernelContext* context) const { auto layer = std::make_shared(); -#ifdef ACL_2308 arm_compute::TensorShape in_x_shape; const TensorShape& x_shape = X->Shape(); const auto& dims_vec = x_shape.GetDims(); @@ -94,9 +94,6 @@ Status BatchNorm::Compute(OpKernelContext* context) const { in_x_shape.set(2, onnxruntime::narrow(dims_vec[1])); // C tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(in_x_shape, arm_compute::Format::F32)); -#else - tbatch_norm.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape()), arm_compute::Format::F32)); -#endif tbatch_norm.out->allocator()->init(arm_compute::TensorInfo(tbatch_norm.in->info()->tensor_shape(), arm_compute::Format::F32)); tbatch_norm.scale->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(S->Shape()), arm_compute::Format::F32)); diff --git a/onnxruntime/core/providers/acl/nn/conv.cc b/onnxruntime/core/providers/acl/nn/conv.cc index 85bd0cfe96279..a62158f1c26ee 100644 --- a/onnxruntime/core/providers/acl/nn/conv.cc +++ b/onnxruntime/core/providers/acl/nn/conv.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #ifdef _WIN32 @@ -19,31 +20,81 @@ // ACL #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/ITensorPack.h" +#include "src/cpu/operators/CpuConv2d.h" // NEON -#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h" -#ifdef ACL_1902 -#include "arm_compute/core/NEON/kernels/NEDepthwiseConvolutionLayer3x3Kernel.h" -#endif -#if defined(ACL_1905) || defined(ACL_1908) -#include "arm_compute/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.h" -#endif - #define CONV_ACL #undef DEPTHWISE_CPU #define PREF_DIM 4 namespace onnxruntime { + namespace acl { -template -thread_local std::map Conv::convLayers; +struct ConvConfig { + bool isQuantized; + bool is_channels_last; + bool isDepthwise; + TensorShape inShapeIn; + TensorShape kShapeIn; + const std::string* inType; + const std::string* kType; +}; + +Status ParseConv(const onnxruntime::Node& node, ConvConfig& config) { + onnxruntime::ProtoHelperNodeContext ctx(node); + onnxruntime::OpNodeProtoHelper attrs(&ctx); + const auto inputDefs = node.InputDefs(); + + config.isQuantized = node.OpType() == "QLinearConv"; + + if (config.isQuantized) { + TensorShape scaleShape; + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[4], scaleShape)); + ORT_RETURN_IF(scaleShape.Size() > 1, "ACL execution provider does not support per-channel quantization"); + } + + config.is_channels_last = node.OpType() == "NhwcConv"; + if (!config.is_channels_last) { + int64_t cl_ret = 0; + attrs.GetAttr("channels_last", &cl_ret); + config.is_channels_last = (bool)cl_ret; + } -template -arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const { + int64_t group = 1; + attrs.GetAttr("group", &group); + + const NodeArg* kDef = inputDefs[config.isQuantized ? 3 : 1]; + + ORT_RETURN_IF_ERROR(GetArgShape(inputDefs[0], config.inShapeIn)); + ORT_RETURN_IF_ERROR(GetArgShape(kDef, config.kShapeIn)); + + ORT_RETURN_IF(config.kShapeIn.NumDimensions() > 4, "ACL execution provider supports 1D and 2D Conv only"); + + config.inType = inputDefs[0]->Type(); + config.kType = kDef->Type(); + const bool mixedType = config.inType != config.kType; + + config.isDepthwise = group > 1; + if (config.isDepthwise) { + const size_t channels = config.inShapeIn[config.is_channels_last ? config.inShapeIn.NumDimensions() - 1 : 1]; + ORT_RETURN_IF(group != channels, "ACL does not support grouping unless group == channels"); + ORT_RETURN_IF(mixedType, "ACL does not support mixed input types for depthwise Conv"); + } + + return Status::OK(); +} + +Status ValidateConv(const onnxruntime::Node& node) { + ConvConfig config; + return ParseConv(node, config); +} + +arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const { arm_compute::TensorShape shape = arm_compute::TensorShape(kernel->info()->tensor_shape()); shape[2] = shape[2] * shape[3]; shape[3] = 1; @@ -51,43 +102,89 @@ arm_compute::TensorShape Conv::ACLReshapeWeightsDepthwise(arm_compute::Tensor return shape; } -#ifdef CONV_ACL -template -Status Conv::Compute(OpKernelContext* context) const { +Conv::Conv(const OpKernelInfo& info) : onnxruntime::OpKernel(info), conv_attrs_(info) { + provider_ = (const_cast( + static_cast(info.GetExecutionProvider()))); + + ConvConfig config; + ORT_THROW_IF_ERROR(ParseConv(OpKernel::Node(), config)); + isQuantized = config.isQuantized; + is_channels_last = config.is_channels_last; + size_t num_inputs = OpKernel::Node().InputDefs().size(); + has_bias = isQuantized ? (num_inputs == 9) : (num_inputs == 3); - ACLNEConv* pConv; - ConvLayersIterator it = Conv::convLayers.find((OpKernel*)this); - if (it != Conv::convLayers.end()) { - pConv = &it->second; - if (pConv->isDepthwiseCPU == true) { - Status s = onnxruntime::Conv::Compute(context); - return s; - } + const Tensor* tmp = nullptr; + const bool kIsConst = info.TryGetConstantInput(1, &tmp); + ORT_ENFORCE(kIsConst, "ACL does not support Conv with mutable weights"); + + in = std::make_shared(); + k = std::make_shared(); + if (has_bias) + b = std::make_shared(); + out = std::make_shared(); + + const arm_compute::DataLayout data_layout = is_channels_last ? arm_compute::DataLayout::NHWC : arm_compute::DataLayout::NCHW; + + TensorShape inShape = config.inShapeIn; + if (is_channels_last && config.inShapeIn.NumDimensions() < 4) { + inShape = TensorShape({config.inShapeIn[0], config.inShapeIn[1], 1, config.inShapeIn[2]}); } - const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; + arm_compute::DataType inType = ACLDataType(*config.inType); + in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(inShape, PREF_DIM), 1, inType, data_layout)); + + arm_compute::DataType kType = ACLDataType(*config.kType); + + TensorShapeVector kShapeVec = config.kShapeIn.AsShapeVector(); + while (kShapeVec.size() < 4) { + kShapeVec.push_back(1); + } + + const TensorShape kShape = is_channels_last ? TensorShape({kShapeVec[0], kShapeVec[2], kShapeVec[3], kShapeVec[1]}) : TensorShape(kShapeVec); + + k->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(kShape), 1, kType, data_layout)); + + TensorShape bShape; + if (has_bias) { + const Tensor* bias = nullptr; + const bool biasIsConst = info.TryGetConstantInput(isQuantized ? 8 : 2, &bias); + ORT_ENFORCE(biasIsConst, "ACL does not support Conv with mutable bias"); + + const auto bDef = OpKernel::Node().InputDefs()[isQuantized ? 8 : 2]; + ORT_THROW_IF_ERROR(GetArgShape(bDef, bShape)); + arm_compute::DataType bType = ACLDataType(*bDef->Type()); + b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(bShape), 1, bType, data_layout)); + + const void* b_data = bias->DataRaw(); + ORT_THROW_IF_ERROR(ACLImportMemory(b->allocator(), (void*)b_data, 0)); + } + + ORT_THROW_IF_ERROR(GetArgShape(OpKernel::Node().OutputDefs()[0], outShape)); + TensorShape outShapeACL = outShape; + if (is_channels_last && outShape.NumDimensions() < 4) { + outShapeACL = TensorShape({outShape[0], outShape[1], 1, outShape[2]}); + } + + out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(outShapeACL, PREF_DIM), 1, inType, data_layout)); - const int64_t N = X->Shape()[0]; - const int64_t M = W->Shape()[0]; + if (isQuantized) { + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, in.get(), 1, 2, false)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, k.get(), 4, 5, false)); + ORT_THROW_IF_ERROR(LoadQuantizationInfo(info, out.get(), 6, 7, false)); + } LOGS_DEFAULT(VERBOSE) << "Conv ACL:"; - LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str(); - LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str(); - if (B != nullptr) LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); - - if (X->Shape().NumDimensions() != PREF_DIM) { - LOGS_DEFAULT(WARNING) << "ACL does not have support for tensors with 4 or more dimensions; defaulting to cpu implementation"; - Status s = onnxruntime::Conv::Compute(context); - return s; + LOGS_DEFAULT(VERBOSE) << "X " << inShape.ToString().c_str(); + LOGS_DEFAULT(VERBOSE) << "W " << config.kShapeIn.ToString().c_str(); + if (has_bias) { + LOGS_DEFAULT(VERBOSE) << "B " << bShape.ToString().c_str(); } - ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W)); + ORT_THROW_IF_ERROR(conv_attrs_.ValidateInputShape(config.inShapeIn, config.kShapeIn, config.is_channels_last)); TensorShapeVector kernel_shape; - ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape)); + ORT_THROW_IF_ERROR(conv_attrs_.ComputeKernelShape(config.kShapeIn, kernel_shape)); ConvAttributes::ConvPadVector pads(conv_attrs_.pads); if (pads.empty()) { @@ -102,16 +199,13 @@ Status Conv::Compute(OpKernelContext* context) const { strides.resize(kernel_shape.size(), 1); } - TensorShapeVector Y_dims; - Y_dims.insert(Y_dims.begin(), {N, M}); - TensorShape input_shape = X->Shape().Slice(2); -#ifdef ACL_2308 - ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); -#else - ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShape(input_shape, kernel_shape, strides, dilations, pads, Y_dims)); -#endif - Tensor* Y = context->Output(0, TensorShape(Y_dims)); - LOGS_DEFAULT(VERBOSE) << "Y " << Y->Shape().ToString().c_str(); + TensorShape input_shape = config.inShapeIn.Slice(2); + TensorShapeVector out_shape; + ORT_THROW_IF_ERROR(conv_attrs_.InferPadsAndOutputShape( + input_shape, kernel_shape, strides, dilations, + pads, out_shape)); + + LOGS_DEFAULT(VERBOSE) << "Y " << outShape.ToString().c_str(); arm_compute::ActivationLayerInfo::ActivationFunction acl_activ_func; bool acl_activ_enabled = false; @@ -136,243 +230,274 @@ Status Conv::Compute(OpKernelContext* context) const { ORT_NOT_IMPLEMENTED("Not implemented fused activation: ", activation_type); } - if (it == Conv::convLayers.end()) { - auto mm_layer = ACLCreateMemoryManager(); - - ACLNEConv tconv; - tconv.mm_layer = std::move(mm_layer); + const size_t idx_channel = arm_compute::get_data_layout_dimension_index(data_layout, arm_compute::DataLayoutDimension::CHANNEL); + isDepthwiseCPU = config.isDepthwise; - tconv.in = std::make_shared(); - tconv.k = std::make_shared(); - if (B != nullptr) - tconv.b = std::make_shared(); - tconv.out = std::make_shared(); + std::vector aclStrides(2); + aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; + aclStrides[1] = strides[0]; - tconv.in->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(X->Shape(), PREF_DIM), arm_compute::Format::F32)); - tconv.k->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(W->Shape()), arm_compute::Format::F32)); - if (B != nullptr) { - tconv.b->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(B->Shape()), arm_compute::Format::F32)); - } - tconv.out->allocator()->init(arm_compute::TensorInfo(ACLTensorShape(Y->Shape(), PREF_DIM), arm_compute::Format::F32)); - - const arm_compute::DataLayout data_layout = tconv.in->info()->data_layout(); - const int idx_channel = arm_compute::get_data_layout_dimension_index(data_layout, arm_compute::DataLayoutDimension::CHANNEL); - bool isDepthwise = (conv_attrs_.group > 1 && conv_attrs_.group == tconv.in->info()->tensor_shape()[idx_channel]); - tconv.isDepthwiseCPU = isDepthwise; - - std::vector aclStrides(2); - aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; - aclStrides[1] = strides[0]; - - std::vector aclPads(4); - // The pad order in acl is: pad_left, pad_right, pad_top, pad_bottom - if (pads.size() == 2) { - if (strides.size() == 1) { - aclPads[0] = 0; - aclPads[1] = 0; - aclPads[2] = pads[1]; - aclPads[3] = pads[0]; - } else { - aclPads[0] = pads[1]; - aclPads[1] = pads[0]; - aclPads[2] = pads[1]; - aclPads[3] = pads[0]; - } + std::vector aclPads(4); + // The pad order in acl is: pad_left, pad_right, pad_top, pad_bottom + if (pads.size() == 2) { + if (strides.size() == 1) { + aclPads[0] = 0; + aclPads[1] = 0; + aclPads[2] = pads[0]; + aclPads[3] = pads[1]; } else { aclPads[0] = pads[1]; - aclPads[1] = pads[3]; - aclPads[2] = pads[0]; - aclPads[3] = pads[2]; + aclPads[1] = pads[0]; + aclPads[2] = pads[1]; + aclPads[3] = pads[0]; } + } else { + aclPads[0] = pads[1]; + aclPads[1] = pads[3]; + aclPads[2] = pads[0]; + aclPads[3] = pads[2]; + } - arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo(aclStrides[0], aclStrides[1], - aclPads[0], aclPads[1], aclPads[2], aclPads[3], arm_compute::DimensionRoundingType::FLOOR); - unsigned int aclDilation0 = (dilations.size() == 2) ? dilations[1] : 1; - - LOGS_DEFAULT(VERBOSE) << "padding: {" << aclPads[0] << "," << aclPads[1] << "," << aclPads[2] << "," << aclPads[3] << "}"; - LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; - - if (isDepthwise) { - LOGS_DEFAULT(VERBOSE) << "Depthwise convolution"; -#ifdef DEPTHWISE_CPU - Status s = onnxruntime::Conv::Compute(context); - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - return s; -#else - tconv.k->info()->set_tensor_shape(ACLReshapeWeightsDepthwise(tconv.k.get())); - - // in the configure function for NEDepthwiseConvolutionLayer3x3, there is a separation based on the optimization -#ifdef ACL_1902 - bool optimizable = - arm_compute::NEDepthwiseConvolutionLayer3x3Kernel::is_optimized_execution_possible(tconv.in->info()->tensor_shape(), - aclPadStride, - tconv.in->info()->data_type(), - 1 /* depth multiplier */, - tconv.in->info()->data_layout()); -#elif defined(ACL_1905) || defined(ACL_1908) - bool optimizable = - arm_compute::NEDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(tconv.in->info(), - tconv.k->info(), - aclPadStride, - 1 /* depth multiplier */, - arm_compute::Size2D(aclDilation0, dilations[0])); -#elif defined(ACL_2002) - bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(tconv.in->info(), - tconv.k->info(), - (B != nullptr) ? tconv.b->info() : nullptr, - tconv.out->info(), - aclPadStride, - 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0]))); -#elif defined(ACL_2308) - bool optimizable = bool(arm_compute::NEDepthwiseConvolutionLayer::validate(tconv.in->info(), - tconv.k->info(), - (B != nullptr) ? tconv.b->info() : nullptr, - tconv.out->info(), - aclPadStride, - 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0]))); -#endif + arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo( + (unsigned int)aclStrides[0], (unsigned int)aclStrides[1], + (unsigned int)aclPads[0], (unsigned int)aclPads[1], + (unsigned int)aclPads[2], (unsigned int)aclPads[3], arm_compute::DimensionRoundingType::FLOOR); + size_t aclDilation0 = (dilations.size() == 2) ? dilations[1] : 1; + + LOGS_DEFAULT(VERBOSE) << "padding: {" << aclPads[0] << "," << aclPads[1] << "," << aclPads[2] << "," << aclPads[3] << "}"; + LOGS_DEFAULT(VERBOSE) << "strides: {" << aclStrides[0] << "," << aclStrides[1] << "}"; + + if (config.isDepthwise) { + LOGS_DEFAULT(VERBOSE) << "Depthwise convolution"; + k->info()->set_tensor_shape(ACLReshapeWeightsDepthwise(k.get())); + auto dl = std::make_shared(); + dl->configure(in.get(), k.get(), (has_bias) ? b.get() : nullptr, out.get(), + aclPadStride, 1 /* depth multiplier */, + acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + arm_compute::Size2D(aclDilation0, dilations[0])); + depthwise_layer = std::move(dl); + isDepthwiseCPU = false; + } else { + LOGS_DEFAULT(VERBOSE) << "ACL 2D convolution"; + auto cl = std::make_shared(); + cl->configure(in->info(), k->info(), (has_bias) ? b->info() : nullptr, out->info(), + aclPadStride, + arm_compute::WeightsInfo(), arm_compute::Size2D(aclDilation0, dilations[0]), + acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), + provider_->info.enable_fast_math, (unsigned int)conv_attrs_.group); + conv_layer = std::move(cl); + + memory_group = arm_compute::MemoryGroup(provider_->memory_manager); + run_pack = {{arm_compute::ACL_SRC_0, in.get()}, {arm_compute::ACL_SRC_1, k.get()}, {arm_compute::ACL_SRC_2, b.get()}, {arm_compute::ACL_DST, out.get()}}; + prep_pack = {{arm_compute::ACL_SRC_1, k.get()}, {arm_compute::ACL_SRC_2, b.get()}}; + + PopulateWorkspace(conv_layer->workspace(), workspace, memory_group, run_pack, prep_pack); + } - if (optimizable) { - LOGS_DEFAULT(VERBOSE) << "ACL optimized depthwise convolution"; -#if defined(ACL_1902) || defined(ACL_1905) - auto layer = std::make_shared(); -#elif defined(ACL_1908) - auto layer = std::make_shared(); -#elif defined(ACL_2002) || defined(ACL_2308) - auto layer = std::make_shared(); -#endif + ACLPrintTensorShape("X", *in.get()); + ACLPrintTensorShape("Y", *out.get()); +} -#ifdef ACL_1902 - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo()); -#elif defined(ACL_1905) || defined(ACL_1908) || defined(ACL_2002) || defined(ACL_2308) - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, 1 /* depth multiplier */, - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - arm_compute::Size2D(aclDilation0, dilations[0])); -#endif - tconv.layer = std::move(layer); - tconv.isDepthwiseCPU = false; - } else { - LOGS_DEFAULT(VERBOSE) << "CPU depthwise convolution"; - Status s = onnxruntime::Conv::Compute(context); - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - return s; - } -#endif // DEPTHWISE_CPU - } else { - if (tconv.k->info()->tensor_shape()[0] == 1 && tconv.k->info()->tensor_shape()[1] == 1) { - LOGS_DEFAULT(VERBOSE) << "CPU pointwise convolution"; - Status s = onnxruntime::Conv::Compute(context); - return s; - } else { - if (tconv.k->info()->tensor_shape()[0] == 9 && tconv.k->info()->tensor_shape()[1] == 9) { - LOGS_DEFAULT(WARNING) << "9x9 DirectConvolution does not have an implementation in NCHW layout; defaulting to cpu implementation"; - Status s = onnxruntime::Conv::Compute(context); - return s; - } - LOGS_DEFAULT(VERBOSE) << "ACL 2D convolution"; - auto layer = std::make_shared(mm_layer); - layer->configure(tconv.in.get(), tconv.k.get(), (B != nullptr) ? tconv.b.get() : nullptr, tconv.out.get(), - aclPadStride, - arm_compute::WeightsInfo(), arm_compute::Size2D(aclDilation0, dilations[0]), - acl_activ_enabled ? arm_compute::ActivationLayerInfo(acl_activ_func, conv_attrs_.alpha) : arm_compute::ActivationLayerInfo(), - false, conv_attrs_.group); - tconv.layer = std::move(layer); - } +#ifdef CONV_ACL +Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + if (isQuantized ? (input_idx != 3) : (input_idx != 1)) { + return Status::OK(); + } + + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); + auto buffSize = packedSize + alignment; + + pkRaw = IAllocator::MakeUniquePtr(alloc, buffSize, true); + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, pkRaw.get(), packedSize, alignment)); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(pkRaw)); + prepacked_weights->buffer_sizes_.push_back(buffSize); } - tconv.out->info()->set_format(tconv.in->info()->format()); + is_packed = true; + } - std::pair ret; - ret = Conv::convLayers.insert(std::pair((OpKernel*)this, tconv)); - pConv = &ret.first->second; + bool free_k = false; + const void* k_data = tensor.DataRaw(); + if (is_channels_last) { + TensorShape shape = tensor.Shape(); + if (shape.NumDimensions() < 4) { + shape = TensorShape({shape[0], shape[1], shape[2], 1}); + } - ACLPrintTensorShape("X", *tconv.in.get()); - ACLPrintTensorShape("Y", *tconv.out.get()); + arm_compute::Tensor kIn; + kIn.allocator()->init(arm_compute::TensorInfo(ACLTensorShape(shape), 1, + k->info()->data_type(), arm_compute::DataLayout::NCHW)); + kIn.info()->set_quantization_info(k->info()->quantization_info()); + ORT_RETURN_IF_ERROR(ACLImportMemory(kIn.allocator(), (void*)k_data, 0)); + k->allocator()->allocate(); + free_k = is_packed; + is_packed = true; + + arm_compute::NEPermute perm_layer; + perm_layer.configure(&kIn, k.get(), {2, 0, 1, 3}); + perm_layer.run(); } else { - // TODO: valildate shapes - pConv = &it->second; + ORT_RETURN_IF_ERROR(ACLImportMemory(k->allocator(), (void*)k_data, 0)); } - const T* x_data = X->Data(); - if (X->Shape().Size() != 0 && pConv->in->info()->has_padding()) { - pConv->in->allocator()->allocate(); - importDataToTensor(pConv->in.get(), x_data); - } else { - ACLImportMemory(pConv->in->allocator(), (void*)x_data, X->Shape().Size() * 4); + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->allocate(); } - const T* k_data = W->Data(); - ACLImportMemory(pConv->k->allocator(), (void*)k_data, W->Shape().Size() * 4); + if (conv_layer) { + conv_layer->prepare(prep_pack); + } else { + depthwise_layer->prepare(); + } - if (B != nullptr) { - const T* b_data = B->Data(); - ACLImportMemory(pConv->b->allocator(), (void*)b_data, B->Shape().Size() * 4); + for (std::unique_ptr& prep_tensor : workspace.prepare_tensors) { + prep_tensor->allocator()->free(); } - T* y_data = Y->MutableData(); - if (Y->Shape().Size() != 0 && pConv->out->info()->has_padding()) { - pConv->out->allocator()->allocate(); - } else { - ACLImportMemory(pConv->out->allocator(), (void*)y_data, Y->Shape().Size() * 4); + if (free_k) { + k->allocator()->free(); } - arm_compute::Allocator alloc_mm{}; - pConv->mm_layer->populate(alloc_mm, 1); - pConv->layer->run(); - pConv->mm_layer->clear(); + return Status::OK(); +} - if (Y->Shape().Size() != 0 && pConv->out->info()->has_padding()) { - importDataFromTensor(pConv->out.get(), y_data); +Status Conv::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, + int input_idx, /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + if (isQuantized ? (input_idx != 3) : (input_idx != 1)) { + return Status::OK(); } - pConv->in->allocator()->free(); - pConv->k->allocator()->free(); - if (B != nullptr) - pConv->b->allocator()->free(); - pConv->out->allocator()->free(); + if (!workspace.persistent_tensors.empty()) { + size_t packedSize = 0; + size_t alignment = 0; + GetPackingInfo(workspace.persistent_tensors, packedSize, alignment); - LOGS_DEFAULT(VERBOSE) << std::endl; + ORT_RETURN_IF_ERROR(LoadPackedTensors(workspace.persistent_tensors, prepacked_buffers[0].get(), + packedSize, alignment)); + + used_shared_buffers = true; + } return Status::OK(); } -#else -template -Status Conv::Compute(OpKernelContext* context) const { - size_t num_inputs = OpKernel::Node().InputDefs().size(); + +Status Conv::Compute(OpKernelContext* context) const { + provider_->SetThreadPool(context->GetOperatorThreadPool()); const Tensor* X = context->Input(0); - const Tensor* W = context->Input(1); - const Tensor* B = num_inputs == 3 ? context->Input(2) : nullptr; - LOGS_DEFAULT(VERBOSE) << "X " << X->Shape().ToString().c_str(); - LOGS_DEFAULT(VERBOSE) << "W " << W->Shape().ToString().c_str(); - if (B != nullptr) - LOGS_DEFAULT(VERBOSE) << "B " << B->Shape().ToString().c_str(); + Tensor* Y = context->Output(0, outShape); + + const void* x_data = X->DataRaw(); + ORT_RETURN_IF(X->Shape().Size() != 0 && in->info()->has_padding(), "Padded ACL input tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(in->allocator(), (void*)x_data, 0)); + + void* y_data = Y->MutableDataRaw(); + ORT_RETURN_IF(Y->Shape().Size() != 0 && out->info()->has_padding(), "Padded ACL output tensor not supported"); + ORT_RETURN_IF_ERROR(ACLImportMemory(out->allocator(), (void*)y_data, 0)); + + if (conv_layer) { + arm_compute::MemoryGroupResourceScope scope_mg(const_cast(memory_group)); + conv_layer->run(const_cast(run_pack)); + } else { + depthwise_layer->run(); + } + + in->allocator()->free(); + k->allocator()->free(); + out->allocator()->free(); LOGS_DEFAULT(VERBOSE) << std::endl; - Status s = onnxruntime::Conv::Compute(context); - return s; + return Status::OK(); } #endif ONNX_OPERATOR_KERNEL_EX( Conv, kOnnxDomain, + 11, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + MLFloat16, + kAclExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_KERNEL_EX( + NhwcConv, + kMSDomain, 1, kAclExecutionProvider, KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Conv); + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kOnnxDomain, + 10, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kOnnxDomain, + 10, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kMSDomain, + 1, + uint8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearConv, + kMSDomain, + 1, + int8_t, + kAclExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), + Conv); } // namespace acl } // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/nn/conv.h b/onnxruntime/core/providers/acl/nn/conv.h index 660d47b4172df..b05ba5363542f 100644 --- a/onnxruntime/core/providers/acl/nn/conv.h +++ b/onnxruntime/core/providers/acl/nn/conv.h @@ -1,17 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019-2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv.h" +#include "core/providers/acl/acl_common.h" #include "core/providers/acl/acl_execution_provider.h" // ACL -#ifdef ACL_2308 #include "arm_compute/runtime/Tensor.h" -#endif #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/runtime/IOperator.h" +#include "arm_compute/runtime/Tensor.h" #include "arm_compute/runtime/TensorAllocator.h" #include "arm_compute/runtime/Allocator.h" #include "arm_compute/runtime/PoolManager.h" @@ -19,45 +21,50 @@ #include "arm_compute/runtime/MemoryManagerOnDemand.h" // NEON -#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h" namespace onnxruntime { namespace acl { -typedef struct -{ - std::shared_ptr layer; - std::shared_ptr mm_layer; - std::shared_ptr in; - std::shared_ptr k; - std::shared_ptr b; - std::shared_ptr out; - bool isDepthwiseCPU; -} ACLNEConv; - -typedef std::map::iterator ConvLayersIterator; +Status ValidateConv(const onnxruntime::Node& node); -template -class Conv : public onnxruntime::Conv { +class Conv : public onnxruntime::OpKernel { public: - explicit Conv(const OpKernelInfo& info) : onnxruntime::Conv(info), conv_attrs_(info) { - provider_ = (const_cast( - static_cast(info.GetExecutionProvider()))); - } + explicit Conv(const OpKernelInfo& info); - ~Conv() { - Conv::convLayers.erase(this); - } + Status PrePack(const Tensor&, int, AllocatorPtr, + bool& is_packed, PrePackedWeights*) override; + + Status UseSharedPrePackedBuffers(std::vector&, + int, bool&) override; Status Compute(OpKernelContext* context) const override; protected: - static thread_local std::map convLayers; ConvAttributes conv_attrs_; ACLExecutionProvider* provider_; std::string activation_type; + std::shared_ptr depthwise_layer; + + std::shared_ptr conv_layer; + arm_compute::MemoryGroup memory_group; + arm_compute::ITensorPack run_pack; + arm_compute::ITensorPack prep_pack; + + Workspace workspace; + + std::shared_ptr in; + std::shared_ptr k; + IAllocatorUniquePtr pkRaw; + std::shared_ptr b; + std::shared_ptr out; + TensorShape outShape; + bool is_channels_last; + bool isQuantized; + bool isDepthwiseCPU; + bool has_bias; + arm_compute::TensorShape ACLReshapeWeightsDepthwise(arm_compute::Tensor* kernel) const; }; } // namespace acl diff --git a/onnxruntime/core/providers/acl/nn/fused_conv.cc b/onnxruntime/core/providers/acl/nn/fused_conv.cc index 3cf18394b5c4c..34e50ebdf6921 100644 --- a/onnxruntime/core/providers/acl/nn/fused_conv.cc +++ b/onnxruntime/core/providers/acl/nn/fused_conv.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #ifdef _WIN32 @@ -17,11 +18,13 @@ namespace onnxruntime { namespace acl { -class FusedConv final : public acl::Conv { +class FusedConv final : public acl::Conv { public: - explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { + explicit FusedConv(const OpKernelInfo& info) : acl::Conv(info) { ORT_ENFORCE(info.GetAttr("activation", &(this->activation_type)).IsOK()); - ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + ORT_ENFORCE(GetFusedActivationAttr(info, activation).IsOK()); } }; diff --git a/onnxruntime/core/providers/acl/nn/pool.cc b/onnxruntime/core/providers/acl/nn/pool.cc index 01d9bc0302c3a..cbbecef6bbfac 100644 --- a/onnxruntime/core/providers/acl/nn/pool.cc +++ b/onnxruntime/core/providers/acl/nn/pool.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2019, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -63,12 +64,7 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, if (pool_attrs.global_pooling) { layer->configure(tpool.in.get(), tpool.out.get(), - arm_compute::PoolingLayerInfo(pool_type -#ifdef ACL_2308 - , - arm_compute::DataLayout::NCHW -#endif - )); + arm_compute::PoolingLayerInfo(pool_type, arm_compute::DataLayout::NCHW)); } else { TensorShapeVector aclStrides(2); aclStrides[0] = (strides.size() == 2) ? strides[1] : 1; @@ -95,8 +91,11 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, aclPads[3] = pads[2]; } - arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo(aclStrides[0], aclStrides[1], - aclPads[0], aclPads[1], aclPads[2], aclPads[3], arm_compute::DimensionRoundingType::FLOOR); + arm_compute::PadStrideInfo aclPadStride = arm_compute::PadStrideInfo( + (unsigned int)aclStrides[0], (unsigned int)aclStrides[1], + (unsigned int)aclPads[0], (unsigned int)aclPads[1], + (unsigned int)aclPads[2], (unsigned int)aclPads[3], + arm_compute::DimensionRoundingType::FLOOR); TensorShapeVector aclKernelShape(2); aclKernelShape[0] = (kernel_shape.size() > 1) ? kernel_shape[1] : 1; @@ -113,9 +112,7 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, arm_compute::PoolingLayerInfo pool_info(pool_type, aclSize, -#ifdef ACL_2308 arm_compute::DataLayout::NCHW, -#endif aclPadStride, excludePadding); layer->configure(tpool.in.get(), tpool.out.get(), pool_info); @@ -133,8 +130,8 @@ ACLNEPool PoolOperation(onnxruntime::OpKernelContext* context, aclInpuWindow.use_tensor_dimensions(tpool.in->info()->tensor_shape()); arm_compute::Iterator aclInputIt(tpool.in.get(), aclInpuWindow); - const unsigned int aclWidth = tpool.in->info()->dimension(0); - const unsigned int aclHeight = tpool.in->info()->dimension(1); + const size_t aclWidth = tpool.in->info()->dimension(0); + const size_t aclHeight = tpool.in->info()->dimension(1); // copy input tensor into the larger buffer arm_compute::execute_window_loop( diff --git a/onnxruntime/core/providers/acl/scheduler.cc b/onnxruntime/core/providers/acl/scheduler.cc new file mode 100644 index 0000000000000..e1bab6adb5a1f --- /dev/null +++ b/onnxruntime/core/providers/acl/scheduler.cc @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "core/common/common.h" +#include "scheduler.h" + +using onnxruntime::concurrency::ThreadPool; + +namespace onnxruntime { +namespace acl { + +void ORTScheduler::set_num_threads(unsigned int num_threads) { + ORT_THROW("Not supported"); +} + +unsigned int ORTScheduler::num_threads() const { + // We can't check the size of the thread pool during kernel initialization, + // as required by ACL. Therefore we have to choose a fixed thread count and + // let some cores run multiple workloads if there are fewer than 32 cores. + // This doesn't seem to cause performance issues with fewer cores in practice. + return 32; +} + +void ORTScheduler::schedule(arm_compute::ICPPKernel* kernel, const Hints& hints) { + arm_compute::ITensorPack tensors; + schedule_op(kernel, hints, kernel->window(), tensors); +} + +void ORTScheduler::schedule_op(arm_compute::ICPPKernel* kernel, const Hints& hints, + const arm_compute::Window& window, arm_compute::ITensorPack& tensors) { + schedule_common(kernel, hints, window, tensors); +} + +void ORTScheduler::run_workloads(std::vector& workloads) { + ThreadPool::TrySimpleParallelFor(_provider->GetThreadPool(), workloads.size(), + [&](std::ptrdiff_t id) { + const arm_compute::ThreadInfo info{ + (int)id, (int)workloads.size(), &cpu_info()}; + workloads[id](info); + }); +} + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/scheduler.h b/onnxruntime/core/providers/acl/scheduler.h new file mode 100644 index 0000000000000..c66700a48f3d5 --- /dev/null +++ b/onnxruntime/core/providers/acl/scheduler.h @@ -0,0 +1,33 @@ +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// SPDX-License-Identifier: MIT + +#include "acl_execution_provider.h" + +#include "arm_compute/runtime/IScheduler.h" +#include "arm_compute/core/CPP/ICPPKernel.h" + +namespace onnxruntime { +namespace acl { + +class ORTScheduler : public arm_compute::IScheduler { + public: + ORTScheduler(ACLExecutionProvider* provider) : _provider(provider) { + } + + void set_num_threads(unsigned int num_threads) override; + + unsigned int num_threads() const override; + + void schedule(arm_compute::ICPPKernel* kernel, const Hints& hints) override; + + void schedule_op(arm_compute::ICPPKernel* kernel, const Hints& hints, + const arm_compute::Window& window, arm_compute::ITensorPack& tensors) override; + + void run_workloads(std::vector& workloads) override; + + private: + ACLExecutionProvider* _provider = nullptr; +}; + +} // namespace acl +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/acl/tensor/concat.cc b/onnxruntime/core/providers/acl/tensor/concat.cc index 75eedaac80aea..0cf02ab8762b9 100644 --- a/onnxruntime/core/providers/acl/tensor/concat.cc +++ b/onnxruntime/core/providers/acl/tensor/concat.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2020, NXP Semiconductor, Inc. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "core/providers/acl/tensor/concat.h" @@ -76,11 +77,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { LOGS_DEFAULT(VERBOSE) << "Concat ACL:"; arm_compute::Tensor output; -#ifdef ACL_2308 std::vector inputs_vector; -#else - std::vector inputs_vector; -#endif for (int i = 0; i < input_count; i++) { arm_compute::Tensor* input = new arm_compute::Tensor(); auto X = input_tensors[i]; @@ -101,11 +98,7 @@ Status Concat::Compute(OpKernelContext* ctx) const { for (int i = 0; i < input_count; i++) { auto X = input_tensors[i]; const T* x_data = X->Data(); -#ifdef ACL_2308 arm_compute::Tensor* in = const_cast(static_cast(inputs_vector[i])); -#else - arm_compute::Tensor* in = static_cast(inputs_vector[i]); -#endif if (X->Shape().Size() != 0 && in->info()->has_padding()) { in->allocator()->allocate(); diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index c5757095e2e1e..1319e8f6fe959 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "python/onnxruntime_pybind_state_common.h" @@ -54,7 +55,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::VitisAIProviderFactoryCreator::Create(ProviderOptions{}), #endif #ifdef USE_ACL - onnxruntime::ACLProviderFactoryCreator::Create(0), + onnxruntime::ACLProviderFactoryCreator::Create(false), #endif #ifdef USE_ARMNN onnxruntime::ArmNNProviderFactoryCreator::Create(0), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 47b8d75f22aea..e8bf61612c89b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "python/onnxruntime_pybind_exceptions.h" @@ -1141,8 +1142,25 @@ std::unique_ptr CreateExecutionProviderInstance( #endif } else if (type == kAclExecutionProvider) { #ifdef USE_ACL - return onnxruntime::ACLProviderFactoryCreator::Create( - session_options.enable_cpu_mem_arena) + bool enable_fast_math = false; + auto it = provider_options_map.find(type); + if (it != provider_options_map.end()) { + for (auto option : it->second) { + if (option.first == "enable_fast_math") { + std::set supported_values = {"true", "True", "false", "False"}; + if (supported_values.find(option.second) != supported_values.end()) { + enable_fast_math = (option.second == "true") || (option.second == "True"); + } else { + ORT_THROW( + "Invalid value for enable_fast_math. " + "Select from 'true' or 'false'\n"); + } + } else { + ORT_THROW("Unrecognized option: ", option.first); + } + } + } + return onnxruntime::ACLProviderFactoryCreator::Create(enable_fast_math) ->CreateProvider(); #endif } else if (type == kArmNNExecutionProvider) { diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index 4d6e411defae3..08e5e4f7b18fa 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once @@ -439,7 +440,7 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); #endif -std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); +std::shared_ptr CreateExecutionProviderFactory_ACL(bool enable_fast_math); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr CreateExecutionProviderFactory_Nnapi( diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 0397bba90438b..924616f49ab25 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -655,7 +656,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); } if (enable_acl) { #ifdef USE_ACL - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(sf, enable_cpu_mem_arena ? 1 : 0)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ACL(sf, false)); #else fprintf(stderr, "ACL is not supported in this build"); return -1; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7d06bbadbd645..c1c48d4945a4d 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "command_args_parser.h" @@ -66,6 +67,7 @@ namespace perftest { "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -e -i '| |'\n" "\n" + "\t [ACL only] [enable_fast_math]: Options: 'true', 'false', default: 'false', \n" "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index ae7680571ced1..3ed5eaee5a5f7 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Copyright (c) 2023 NVIDIA Corporation. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include "ort_test_session.h" @@ -519,9 +520,42 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); #endif } else if (provider_name_ == onnxruntime::kAclExecutionProvider) { #ifdef USE_ACL +#if defined(_MSC_VER) + std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); +#else + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; +#endif // defined(_MSC_VER) + std::istringstream ss(ov_string); + std::string token; + bool enable_fast_math = false; + while (ss >> token) { + if (token == "") { + continue; + } + auto pos = token.find("|"); + if (pos == std::string::npos || pos == 0 || pos == token.length()) { + ORT_THROW("[ERROR] [ACL] Use a '|' to separate the key and value for the run-time option you are trying to use.\n"); + } + + auto key = token.substr(0, pos); + auto value = token.substr(pos + 1); + + if (key == "enable_fast_math") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + enable_fast_math = (value == "true") || (value == "True"); + } else { + ORT_THROW( + "[ERROR] [ACL] You have selcted an invalid value for the key 'enable_fast_math'. " + "Select from 'true' or 'false' \n"); + } + } else { + ORT_THROW( + "[ERROR] [ACL] Unrecognized option: ", key); + } + } Ort::ThrowOnError( - OrtSessionOptionsAppendExecutionProvider_ACL(session_options, - performance_test_config.run_config.enable_cpu_mem_arena ? 1 : 0)); + OrtSessionOptionsAppendExecutionProvider_ACL(session_options, enable_fast_math)); #else ORT_THROW("Acl is not supported in this build\n"); #endif diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index a5015c18cee63..177647ab5be6b 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -254,7 +255,7 @@ TEST_P(ModelTest, Run) { #endif #ifdef USE_ACL else if (provider_name == "acl") { - ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, 0)); + ASSERT_ORT_STATUS_OK(OrtSessionOptionsAppendExecutionProvider_ACL(ortso, false)); } #endif #ifdef USE_ARMNN diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 1feba20e32bbb..6451f8ec6dce8 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include @@ -207,11 +208,11 @@ std::unique_ptr DefaultRknpuExecutionProvider() { #endif } -std::unique_ptr DefaultAclExecutionProvider(bool enable_arena) { +std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math) { #ifdef USE_ACL - return ACLProviderFactoryCreator::Create(enable_arena)->CreateProvider(); + return ACLProviderFactoryCreator::Create(enable_fast_math)->CreateProvider(); #else - ORT_UNUSED_PARAMETER(enable_arena); + ORT_UNUSED_PARAMETER(enable_fast_math); return nullptr; #endif } diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 606dfc068d399..b3a619022f79b 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #pragma once #include "core/common/optional.h" @@ -53,7 +54,7 @@ std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); std::unique_ptr DefaultVSINPUExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); -std::unique_ptr DefaultAclExecutionProvider(bool enable_arena = true); +std::unique_ptr DefaultAclExecutionProvider(bool enable_fast_math = false); std::unique_ptr DefaultArmNNExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultRocmExecutionProvider(bool test_tunable_op = false); std::unique_ptr DefaultCoreMLExecutionProvider(bool use_mlprogram = false); diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 902d15e8122b4..8535f1e8c85a0 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates # Licensed under the MIT License. import argparse @@ -651,9 +652,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--enable_transformers_tool_test", action="store_true", help="Enable transformers tool test") parser.add_argument( "--use_acl", - nargs="?", - const="ACL_1905", - choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002", "ACL_2308"], + action="store_true", help="Build with ACL for ARM architectures.", ) parser.add_argument("--acl_home", help="Path to ACL home dir") @@ -1052,11 +1051,6 @@ def generate_build_tree( "-Donnxruntime_USE_TELEMETRY=" + ("ON" if args.use_telemetry else "OFF"), "-Donnxruntime_ENABLE_LTO=" + ("ON" if args.enable_lto else "OFF"), "-Donnxruntime_USE_ACL=" + ("ON" if args.use_acl else "OFF"), - "-Donnxruntime_USE_ACL_1902=" + ("ON" if args.use_acl == "ACL_1902" else "OFF"), - "-Donnxruntime_USE_ACL_1905=" + ("ON" if args.use_acl == "ACL_1905" else "OFF"), - "-Donnxruntime_USE_ACL_1908=" + ("ON" if args.use_acl == "ACL_1908" else "OFF"), - "-Donnxruntime_USE_ACL_2002=" + ("ON" if args.use_acl == "ACL_2002" else "OFF"), - "-Donnxruntime_USE_ACL_2308=" + ("ON" if args.use_acl == "ACL_2308" else "OFF"), "-Donnxruntime_USE_ARMNN=" + ("ON" if args.use_armnn else "OFF"), "-Donnxruntime_ARMNN_RELU_USE_CPU=" + ("OFF" if args.armnn_relu else "ON"), "-Donnxruntime_ARMNN_BN_USE_CPU=" + ("OFF" if args.armnn_bn else "ON"), From 59b7b6bb7cbb7bcc86dab590f1b4d5ed50d53dec Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 13 Sep 2024 16:52:49 +0000 Subject: [PATCH 25/39] Remove training from web ci pipeline (#22082) ### Description Remove training from web ci pipeline ### Motivation and Context --- .../templates/linux-wasm-ci.yml | 21 ------------------- .../azure-pipelines/templates/win-web-ci.yml | 6 +----- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index a56eb37faef84..2ab432e94fcbd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -31,10 +31,6 @@ parameters: type: boolean default: false -- name: BuildTraining - type: boolean - default: true - - name: WithCache type: boolean default: false @@ -116,19 +112,6 @@ jobs: DisplayName: 'Build and test (browser) (simd + threads)' WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildTraining, true) }}: - - template: build-linux-wasm-step.yml - parameters: - Today: $(Today) - ${{ if eq(parameters.BuildStaticLib, true)}}: - AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} | static - ${{ else }}: - AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} - CacheDir: $(ORT_CACHE_DIR)/wasm_training - Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_training --enable_training_apis --target onnxruntime_webassembly --skip_tests' - DisplayName: 'Build (training + simd + threads)' - WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildJsep, true) }}: - template: build-linux-wasm-step.yml parameters: @@ -150,10 +133,6 @@ jobs: cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.wasm $(Build.ArtifactStagingDirectory) cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.mjs $(Build.ArtifactStagingDirectory) fi - if [ -d $(Build.BinariesDirectory)/wasm_training ]; then - cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.wasm $(Build.ArtifactStagingDirectory) - cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.mjs $(Build.ArtifactStagingDirectory) - fi displayName: 'Create Artifacts' - ${{ if eq(parameters.SkipPublish, false) }}: - task: PublishPipelineArtifact@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index c1fde93d8e640..0e8a7eb94379b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -214,11 +214,7 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'E2E package consuming test' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) - - script: | - npm run test:training:e2e - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'E2E training package test' - condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) + - task: CopyFiles@2 inputs: sourceFolder: $(Build.SourcesDirectory)\js\common From 7e2c722459a7a7015a238379acc8705d9ce5b8dc Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:21:11 -0700 Subject: [PATCH 26/39] Add Continuous Decoding support in GQA (#21523) ### Description This PR will add support for Continuous Decoding for batch_size = 1 input. From now on, GQA can take arbitrary length input using seqlens_k as total_sequence_length - 1 and the sequence length of qkv as new_sequence_length. **This change will not affect the default behavior of GQA** ### Motivation and Context Prior to this change it was impossible to support sequence_length > 1 inputs when past context was given. This use case is essential to making continuous decoding work, which is one of our current efforts in ORT-GenAI. --- docs/ContribOperators.md | 6 +- .../contrib_ops/cpu/bert/attention_common.h | 3 +- .../contrib_ops/cpu/bert/attention_helper.h | 11 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 177 +++++------ .../cpu/bert/group_query_attention.cc | 31 +- .../cpu/bert/group_query_attention_helper.h | 36 ++- .../cpu/sparse/sparse_attention_base.h | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 1 - .../cuda/bert/group_query_attention.cc | 5 +- .../cuda/bert/group_query_attention_helper.h | 298 ------------------ .../cuda/bert/group_query_attention_impl.cu | 149 ++++++--- .../cuda/bert/group_query_attention_impl.h | 4 +- .../rocm/bert/group_query_attention.cu | 10 +- .../core/graph/contrib_ops/bert_defs.cc | 8 +- .../transformers/test_flash_attn_cuda.py | 171 +++++++++- .../test/python/transformers/test_gqa_cpu.py | 79 ++++- .../transformers/test_sparse_attention.py | 7 +- 17 files changed, 498 insertions(+), 502 deletions(-) delete mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index aadf4ebe2f488..09a7e47fc9913 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2521,6 +2521,8 @@ This version of the operator has been available since version 1 of the 'com.micr Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. + Supports continuous decoding for batch_size == 1 for CPU and CUDA. + #### Version @@ -2561,9 +2563,9 @@ This version of the operator has been available since version 1 of the 'com.micr
past_value (optional) : T
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
seqlens_k : M
-
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
+
1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).
total_sequence_length : M
-
Scalar tensor of total sequence length (past + new).
+
Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for checking inputs and determining prompt vs token generation case.
cos_cache (optional) : T
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache (optional) : T
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 45acb90ba68b0..e0fa581c8071d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -114,7 +114,8 @@ struct GroupQueryAttentionParameters { int local_window_size; bool kv_share_buffer; bool is_packed_qkv; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt; // indicates whether this is first decoding step bool do_rotary; bool rotary_interleaved; bool use_smooth_softmax; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index e6c948acb0d6c..4d435f71cc195 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -236,19 +236,16 @@ T* ConcatStateChunkGQA(const T* past, size_t past_buff_chunk_length, size_t past_chunk_length, size_t new_chunk_length, - bool is_prompt, bool past_present_share_buffer, std::ptrdiff_t i) { T* start = present + i * present_buff_chunk_length; T* p = start; - if (!is_prompt) { - if (!past_present_share_buffer) { - const T* src_past = past + i * past_buff_chunk_length; - memcpy(p, src_past, past_chunk_length * sizeof(T)); - } - p += past_chunk_length; + if (!past_present_share_buffer && past_chunk_length > 0) { + const T* src_past = past + i * past_buff_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); } + p += past_chunk_length; memcpy(p, chunk, new_chunk_length * sizeof(T)); return start; diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 2bf0aa0915c2d..bfec9aef56727 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -59,6 +59,7 @@ class GQAAttentionBase { GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int head_size = parameters.head_size; @@ -88,14 +89,14 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - tp); + is_prompt, tp); return Status::OK(); } @@ -105,35 +106,35 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int past_buffer_sequence_length, // sequence length of past state - int present_buffer_sequence_length, // sequence length of present state - int head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed - ThreadPool* tp) const { // thread pool - const bool is_prompt = sequence_length != 1; + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp) const { // thread pool const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H - const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = sequence_length * head_size; // S x H + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } - const int loop_len = batch_size * num_heads_; + const size_t loop_len = batch_size * num_heads_; const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; @@ -156,12 +157,11 @@ class GQAAttentionBase { ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i) / num_heads_; - const int head_index = static_cast(i) % num_heads_; - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; T* output = attention_probs + output_offset; @@ -174,7 +174,7 @@ class GQAAttentionBase { } if (nullptr != present_key) { k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); } @@ -189,16 +189,17 @@ class GQAAttentionBase { } else { q = Q + q_input_chunk_length * i; } + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length, - nullptr); + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, + static_cast(present_buffer_sequence_length), nullptr); // compute Softmax T* output_softmax = output; - for (int seq = 0; seq < sequence_length; seq++) { - int seq_causal_length = sequence_length == 1 ? total_seqlen : seq + 1; - if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) { - for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { + for (size_t seq = 0; seq < sequence_length; seq++) { + size_t seq_causal_length = past_seqlen + seq + 1; + if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { + for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } if (softcap_ > 0.f) { @@ -214,17 +215,17 @@ class GQAAttentionBase { } } else { if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax, seq_causal_length, softcap_); + ComputeAttentionSoftcapInplace(output_softmax, static_cast(seq_causal_length), softcap_); } if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeSmoothSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } } // set causal [seq_causal_length, total_seqlen) to 0.f - for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { + for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } @@ -235,34 +236,36 @@ class GQAAttentionBase { } template - void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT - const T* V, // V value with size BxN_kvxSxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size - int sequence_length, // sequence length - int past_buffer_sequence_length, // sequence length in past state - int present_buffer_sequence_length, // sequence length in past state - int head_size, // head size of Q, K, V - int hidden_size, // hidden size of Output - const T* past_value, // past value only - T* present_value, // present value only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxN_kvxSxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size + const size_t sequence_length, // sequence length + const size_t past_buffer_sequence_length, // sequence length in past state + const size_t present_buffer_sequence_length, // sequence length in past state + const size_t head_size, // head size of Q, K, V + const size_t hidden_size, // hidden size of Output + const T* past_value, // past value only + T* present_value, // present value only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt ThreadPool* tp) const { - const bool is_prompt = sequence_length != 1; const ptrdiff_t packed_batch_stride = packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const int kv_input_chunk_length = sequence_length * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } + const size_t loop_len = batch_size * num_heads_; + // The cost of Gemm TensorOpCost unit_cost; unit_cost.compute_cycles = @@ -282,37 +285,35 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor( - tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i / num_heads_); - const int head_index = static_cast(i % num_heads_); - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; - - const T* v; - if (packed_qkv) { - v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); - } else { - v = V + kv_input_chunk_length * (i / kv_num_heads_factor); - } - if (nullptr != present_value) { - v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, - i / kv_num_heads_factor); - } + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_value) { + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, + i / kv_num_heads_factor); + } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, - 1.f, /*alpha*/ - attention_probs + attention_probs_offset, present_buffer_sequence_length, v, - head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); - } - }); + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ + attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, static_cast(head_size), + 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + } + }); } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 87675255f5ba4..2a38e4a1ac636 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -45,7 +45,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); - const Tensor* total_seqlen = context->Input(6); + const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); @@ -61,7 +61,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { num_heads_, kv_num_heads_, seqlens_k, - total_seqlen, + total_seqlen_tensor, scale_, softcap_)); @@ -103,6 +103,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } if (do_rotary_) { + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; rotary_params.sequence_length = sequence_length; @@ -114,17 +115,29 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { rotary_params.seq_stride = head_size; rotary_params.head_stride = sequence_length * rotary_params.seq_stride; rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride; - rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.position_ids_format = !parameters.is_first_prompt ? 1 : 0; rotary_params.transposed = true; auto* tp = context->GetOperatorThreadPool(); - std::vector pos_ids(sequence_length == 1 ? batch_size : 1); - if (sequence_length == 1) { + // Generate position ids + const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length; + std::vector pos_ids(pos_ids_size); + if (parameters.is_first_prompt) { + pos_ids[0] = static_cast(0); + } else { + // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { - pos_ids[b] = static_cast(seqlens_k->Data()[b]); + const int total_seqlen = seqlens_k->Data()[b] + 1; + const int past_seqlen = total_seqlen - sequence_length; + for (int s = 0; s < sequence_length; s++) { + if (past_seqlen + s < total_seqlen) { + pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; + } else { + pos_ids[b * sequence_length + s] = static_cast(1); + } + } } - } else { - pos_ids[0] = static_cast(0); } + // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; T* q_rotary; @@ -149,6 +162,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Q = RotaryQ; K = RotaryK; } + // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), q_rotary, rotary_interleaved_)); @@ -161,6 +175,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), k_rotary, rotary_interleaved_)); + // Pack V into rotary QKV buffer if (packed_qkv) { const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 3342052260ff9..0bdee151d2173 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -168,14 +168,13 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); + if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "seqlens_k must be shape (batch_size)."); } - // Set present sequence length and kv_share_buffer from input total_seqlen tensor + // Set present sequence length from input total_seqlen tensor if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "total_sequence_length tensor must be of one element."); @@ -195,11 +194,11 @@ Status CheckInputs(const Tensor* query, } if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); + "cos_cache dimension 0 shall not be less than total_sequence_length."); } if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); + "sin_cache dimension 0 shall not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -219,7 +218,26 @@ Status CheckInputs(const Tensor* query, "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); } - bool is_prompt = sequence_length != 1; + bool is_subsequent_prompt = false; + if (sequence_length > 1 && sequence_length != total_sequence_length) { + if (batch_size != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "batch_size must be 1 when sequence_length > 1 and past context is given."); + } + is_subsequent_prompt = true; + } + + bool is_first_prompt; + if (is_subsequent_prompt) { + is_first_prompt = false; // irrelevant for interactive decoding + } else { + // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt + is_first_prompt = (sequence_length == total_sequence_length); + if (!is_first_prompt && sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sequence_length shall be 1 when it is not prompt."); + } + } if (parameters != nullptr) { GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); @@ -227,6 +245,7 @@ Status CheckInputs(const Tensor* query, output_parameters->sequence_length = sequence_length; // sequence length of Q output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->total_sequence_length = total_sequence_length; // total sequence length output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; output_parameters->head_size = head_size; @@ -235,7 +254,8 @@ Status CheckInputs(const Tensor* query, output_parameters->rotary_dim = rotary_dim; output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; - output_parameters->is_prompt = is_prompt; + output_parameters->is_subsequent_prompt = is_subsequent_prompt; + output_parameters->is_first_prompt = is_first_prompt; output_parameters->scale = scale; output_parameters->softcap = softcap; output_parameters->qkv_format = qkv_format; diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index cf66bd8407126..37172074e5d86 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -184,7 +184,7 @@ class SparseAttentionBase { // Concatenate past_k + k -> present_k // TODO: avoid copying mutiple times for a group. k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); // Compute Q*K' + AttentionMask @@ -365,7 +365,7 @@ class SparseAttentionBase { // Concatenate past_v + v -> present_v v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a10d2548fa7b8..7f1c3786858c8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -42,7 +42,6 @@ struct RightPaddingBatchHook { auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; - // Advance to current batch - in case of different sequence lengths if (p.seqlen_k_ptr) { p.num_keys = p.seqlen_k_ptr[batch_id]; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index d0ae812bb4fa2..6eff584cec5da 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -5,7 +5,7 @@ #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" -#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" @@ -95,7 +95,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { kv_num_heads_, seqlens_k, total_seqlen, - is_past_bsnh_, scale_, softcap_, device_prop.maxThreadsPerBlock)); @@ -253,7 +252,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } if (seqlens_k_buffer != nullptr) { - data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + data.seqlens_k_buff = reinterpret_cast(seqlens_k_buffer.get()); } // Memory Efficient Buffers if (k_buffer != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h deleted file mode 100644 index e65827e4ccdd5..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/common.h" -#include "contrib_ops/cpu/bert/attention_common.h" - -namespace onnxruntime { -namespace contrib { -namespace group_query_attention_helper { - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap) { - // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // no packing for q/k/v: - // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) - // key (K) : (B, S, D_kv) or nullptr - // value (V) : (B, S, D_kv) or nullptr - AttentionQkvFormat qkv_format = Q_K_V_BSNH; - AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - const bool is_packed_qkv = key == nullptr; - const auto& query_dims = query->Shape().GetDims(); - - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - - int batch_size = static_cast(query_dims[0]); - int sequence_length = static_cast(query_dims[1]); - int q_hidden_size = static_cast(query_dims[2]); - int head_size = 0; - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - int kv_hidden_size = 0; - // Check key and value when not packed - if (!is_packed_qkv) { - head_size = static_cast(q_hidden_size) / num_heads; - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. Got head_size % 8 == ", - head_size % 8); - } - if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } else if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != key_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 1 (sequence length)"); - } - kv_hidden_size = static_cast(key_dims[2]); - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } else if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 1 (sequence length)"); - } else if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { - // Check packed qkv - head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. Got head_size % 8 == ", - head_size % 8); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - q_hidden_size = head_size * num_heads; - kv_hidden_size = head_size * kv_num_heads; - } - - // Check past-present KV - int32_t past_sequence_length = 0; - if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - // BNSH - if (!is_past_bsnh) { - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[2]); - // BSNH - } else { - if (past_key_dims[1] != past_value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[1]); - } - - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - } else if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be both present or both absent."); - } - - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "seqlens_k must be shape (batch_size)."); - } - - // Set present sequence length and kv_share_buffer from input total_seqlen tensor - if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "total_sequence_length tensor must be of one element."); - } - int total_sequence_length = *((*total_seqlen).template Data()); - int present_sequence_length = std::max(total_sequence_length, past_sequence_length); - - int rotary_dim = 0; - if (cos_cache != nullptr && sin_cache != nullptr) { - const auto& cos_dims = cos_cache->Shape().GetDims(); - const auto& sin_dims = sin_cache->Shape().GetDims(); - - if (head_size % 16 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size shall be a multiple of 16. Got head_size % 16 == ", - head_size % 16); - } - if (cos_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); - } - if (sin_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); - } - if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (cos_dims[1] != sin_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache dimension 1 must be the same."); - } - rotary_dim = static_cast(cos_dims[1] * 2); - } else if (cos_cache != nullptr || sin_cache != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); - } - - bool is_prompt = (sequence_length == total_sequence_length); - if (!is_prompt && sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sequence_length shall be 1 when it is not prompt."); - } - - if (parameters != nullptr) { - GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = batch_size; - output_parameters->sequence_length = sequence_length; // sequence length of Q - output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors - output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors - output_parameters->total_sequence_length = total_sequence_length; // total sequence length - output_parameters->hidden_size = q_hidden_size; - output_parameters->num_heads = num_heads; - output_parameters->head_size = head_size; - output_parameters->kv_hidden_size = kv_hidden_size; - output_parameters->kv_num_heads = kv_num_heads; - output_parameters->rotary_dim = rotary_dim; - output_parameters->is_packed_qkv = is_packed_qkv; - output_parameters->is_prompt = is_prompt; - output_parameters->scale = scale; - output_parameters->softcap = softcap; - output_parameters->qkv_format = qkv_format; - output_parameters->past_kv_format = past_kv_format; - } - - return Status::OK(); -} - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap, - int max_threads_per_block) { - if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); - } - - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale, softcap); -} - -} // namespace group_query_attention_helper -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index be94f26ec298f..8bf9848245ec7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -71,6 +71,8 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, + // const int* seqlens_q, const bool is_bsnh) { // refers to past; otherwise bnsh const int h = threadIdx.x; const int n = threadIdx.y; @@ -88,7 +90,9 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -96,7 +100,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -116,6 +120,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, const bool is_bsnh) { int i = threadIdx.x + (blockDim.x * blockIdx.x); if (i < H * num_heads) { @@ -132,7 +137,9 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -140,7 +147,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; const int new_head_stride = H; @@ -160,13 +167,12 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter const int max_threads_per_block, const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); - + const int* seqlens_k = parameters.is_first_prompt ? nullptr : reinterpret_cast(data.seqlens_k); AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -180,6 +186,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, @@ -187,6 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { int steps = (H * kv_num_heads + 255) / 256; @@ -200,6 +208,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKVLarge<<>>(kv_sequence_length, past_sequence_length, @@ -209,6 +218,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } return CUDA_CALL(cudaGetLastError()); @@ -219,7 +229,7 @@ template __global__ void ConcatKVInPlace(const int max_seqlen, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { @@ -234,7 +244,7 @@ __global__ void ConcatKVInPlace(const int max_seqlen, const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -253,7 +263,7 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int kv_num_heads, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh @@ -264,9 +274,10 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int s = blockIdx.y; const int b = blockIdx.z; const int new_seqlen = gridDim.y; + const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -286,15 +297,15 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const T* new_key, const T* new_value, T* present_key, T* present_value, - bool is_past_kv_bnsh_format, - bool is_new_kv_bnsh_format, + const bool is_past_kv_bnsh_format, + const bool is_new_kv_bnsh_format, cudaStream_t stream, const int max_threads_per_block) { static_assert(sizeof(T) == 2); @@ -307,14 +318,14 @@ Status LaunchConcatKVInPlace(int batch_size, ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -327,7 +338,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -336,7 +347,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -354,7 +365,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, cudaStream_t stream, const int max_threads_per_block) { const int max_sequence_length = parameters.seqlen_present_kv_cache; - const int* past_seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + const int* seqlens_k = (parameters.is_first_prompt && !parameters.is_subsequent_prompt) ? nullptr + : reinterpret_cast(data.seqlens_k); assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -364,8 +376,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, parameters.kv_num_heads, parameters.head_size, max_sequence_length, - past_seqlens_k, - nullptr, // total_seqlens_k is not available + seqlens_k, + nullptr, // total_seqlens_k would be wrong to use here parameters.sequence_length, reinterpret_cast(new_key), reinterpret_cast(new_value), @@ -495,23 +507,33 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; } -// Convert Past to Total sequence length tensor -Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, - int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int /*threads_per_block*/) { - if (parameters.is_prompt) { - return Status::OK(); - } - const int batch_size = parameters.batch_size; - const int add_seqlen = is_total ? parameters.sequence_length : 0; - +// Calculate total sequence length from seqlens_k +Status LaunchGetSeqlensTotal(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int batch_size, cudaStream_t stream, + const int /*threads_per_block*/) { const dim3 grid(1, 1, 1); // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads const dim3 block(batch_size, 1, 1); + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, 1); + return CUDA_CALL(cudaGetLastError()); +} - // TODO(aciddelgado): small version - PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); +// Currently, interactive decoding only works for batch_size 1 +__global__ void GetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + seqlens_k_buff[tid] = seqlens_k[tid] + 1 - sequence_length; + } +} +// Calculate past sequence length for each batch entry for flash attention kernel +Status LaunchGetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length, cudaStream_t stream, + const int max_threads_per_block) { + const int threads = std::min(batch_size, max_threads_per_block); + const int blocks = (threads / max_threads_per_block) + 1; + GetSeqlensInteractive<<>>(seqlens_k, seqlens_k_buff, batch_size, + sequence_length); return CUDA_CALL(cudaGetLastError()); } @@ -576,7 +598,22 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsInteractive(const int32_t* seqlens_k, int64_t* position_ids, + const int seqlen, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + const int total_seqlen = seqlens_k[b] + 1; + const int past_seqlen = total_seqlen - seqlen; + if (past_seqlen + s < total_seqlen) { + position_ids[tid] = past_seqlen + s; + } else { + position_ids[tid] = 1; + } + } +} + __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -591,7 +628,6 @@ __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* positio } } -// Kernel to convert seqlens_k to position_ids __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { @@ -601,12 +637,15 @@ __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position // Convert seqlens_k to position_ids Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + int64_t* position_ids, cudaStream_t stream, + const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + SeqlensToPosIdsInteractive<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -650,7 +689,12 @@ Status FlashAttention( } void* seqlens_k = reinterpret_cast(data.seqlens_k); - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensInteractive(reinterpret_cast(data.seqlens_k), + reinterpret_cast(data.seqlens_k_buff), batch_size, + sequence_length, stream, max_threads_per_block)); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); + } else if (parameters.is_first_prompt) { // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value // user should use seqlens_k to index into output to get new tokens if (batch_size <= parameters.zeros_count) { @@ -659,10 +703,12 @@ Status FlashAttention( // Launch kernel to create larger seqlen tensor when batch_size > 256 constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); - seqlens_k = data.seqlens_k_total; + repeat_seqlen<<>>(data.seqlens_k_buff, 0, batch_size); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); } - } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + } + + if (!parameters.kv_share_buffer || parameters.is_first_prompt) { // copy past kv to present kv ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, true)); } @@ -682,7 +728,7 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); - // if (parameters.left_padding && parameters.is_prompt) { + // if (parameters.left_padding && parameters.is_first_prompt) { // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); // } @@ -766,15 +812,16 @@ Status EfficientAttention( key = reinterpret_cast(k_buffer); } - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt || !parameters.is_first_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensTotal(data.seqlens_k, data.seqlens_k_buff, batch_size, stream, 256)); + } else { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, + repeat_seqlen<<>>(data.seqlens_k_buff, parameters.sequence_length, batch_size); - } else { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } + int* seqlens_k = data.seqlens_k_buff; if (parameters.kv_share_buffer) { // Share buffer case @@ -815,7 +862,7 @@ Status EfficientAttention( } DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + DUMP_TENSOR("seqlens_k", seqlens_k, batch_size, 1); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -823,14 +870,14 @@ Status EfficientAttention( p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; - p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.kv_sequence_length = present_sequence_length; // maybe remove p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; p.causal = true; p.scale = scale; p.softcap = parameters.softcap; - p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqlen_k_ptr = seqlens_k; // Note: seqlens_k is total sequence length for efficient p.seqstart_q_ptr = nullptr; p.seqstart_k_ptr = nullptr; p.query = query; @@ -912,7 +959,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const half* new_key, @@ -928,7 +975,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const BFloat16* new_key, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index e8dc69188b95f..8593ecede2bab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -27,7 +27,7 @@ struct GroupQueryAttentionData { T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; - int* seqlens_k_total = nullptr; + int* seqlens_k_buff = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; T* unpacked_qkv_buffer = nullptr; @@ -61,7 +61,7 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, // max sequence length of present_key or present_value. - const int* past_seqlens_k, // it is not used when total_seqlens_k is available. + const int* seqlens_k, // it is not used when total_seqlens_k is available. const int* total_seqlens_k, // optional, nullptr means it is not available. int new_seq_len, const T* new_key, diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 7a16eb38181aa..e644b7e903138 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -5,7 +5,7 @@ #include "core/providers/rocm/rocm_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/rocm/bert/rotary_embedding_impl.h" #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" @@ -115,7 +115,7 @@ Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -325,7 +325,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // build present kv cache auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { // copy prompt kv to present kv ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); @@ -383,7 +383,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { return ret; } - if (parameters.is_prompt && is_unidirectional_) { + if (parameters.is_first_prompt && is_unidirectional_) { return mask_info::decode("t", sequence_length, kv_sequence_length); } @@ -496,7 +496,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.head_size, parameters.head_size, // v head size GetCkFmhaDataTypeString(), - !parameters.is_prompt, // true, // is_group_mode + !parameters.is_first_prompt, // true, // is_group_mode true, // is_v_rowmajor ? dim is fastest : seq is fastest mask.type, bias_type, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 5185205f1dde1..c706c6fc5ff5f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1049,6 +1049,8 @@ Supports different number of heads for q and kv for CPU and CUDA. Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. +Supports continuous decoding for batch_size == 1 for CPU and CUDA. + )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1110,12 +1112,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(5, "seqlens_k", - // For prompt, the value is number of tokens (excluding padding) - 1. - "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).", "M") .Input(6, "total_sequence_length", - "Scalar tensor of total sequence length (past + new).", + "Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for " + "checking inputs and determining prompt vs token generation case.", "M") .Input(7, "cos_cache", diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index c04929a3b603e..46ab905977f48 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -223,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + interactive=False, softcap=0.0, use_smooth_softmax=False, ): @@ -1224,7 +1225,7 @@ def parity_check_gqa_prompt( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1422,7 +1423,7 @@ def parity_check_gqa_prompt_no_buff( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1597,7 +1598,7 @@ def parity_check_gqa_past( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1667,7 +1668,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1696,7 +1696,6 @@ def parity_check_gqa_past( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1730,6 +1729,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1783,15 +1784,14 @@ def parity_check_gqa_past( numpy.testing.assert_allclose( present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) def parity_check_gqa_past_no_buff( config, - causal=False, + causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1864,7 +1864,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1896,7 +1895,6 @@ def parity_check_gqa_past_no_buff( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1930,6 +1928,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1976,6 +1976,23 @@ def parity_check_gqa_past_no_buff( f" with {config}, causal={causal}, local={local}, past_format={past_format}," f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) + for b in range(config.batch_size): + numpy.testing.assert_allclose( + present_k[b, :, : (cache_seqlens + 1)[b]], + k_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + numpy.testing.assert_allclose( + present_v[b, :, : (cache_seqlens + 1)[b]], + v_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) @@ -2229,6 +2246,86 @@ def gqa_past_flash_attention_test_cases(): ) +def gqa_interactive_one_batch_flash_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + ) + + +def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + ) + + class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): @@ -2350,6 +2447,60 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle use_smooth_softmax=True, ) + @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) + def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + if not has_flash_attention(): + return + print("------- FLASH ATTENTION (INTERACTIVE) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + + @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) + def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): + if not has_memory_efficient(): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") + + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index cc9d7ff51a5c6..dc21d4e4a5890 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -121,8 +121,12 @@ def rotate_tensor( else: x_rot = torch.cat((real, imag), dim=-1) else: - cos_x = cos[:, 0:seq_len, :, :] - sin_x = sin[:, 0:seq_len, :, :] + batch_size = x.shape[0] + cos_x = torch.zeros((batch_size, seq_len, 1, cos.shape[3]), device=x.device) + sin_x = torch.zeros((batch_size, seq_len, 1, sin.shape[3]), device=x.device) + for b in range(x.shape[0]): + cos_x[b] = cos[0, pos[b] : pos[b] + seq_len, :, :] + sin_x[b] = sin[0, pos[b] : pos[b] + seq_len, :, :] real = cos_x * x1 - sin_x * x2 imag = sin_x * x1 + cos_x * x2 if interleaved: @@ -716,7 +720,6 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - # TODO: do we need io binding for cpu input? io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -788,6 +791,7 @@ def gqa_past_func( softcap=0.0, use_smooth_softmax=False, ): + assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( config, past_kv_format, @@ -819,12 +823,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -867,12 +871,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -1518,7 +1522,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1576,6 +1579,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1739,7 +1744,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1800,6 +1804,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -2000,6 +2006,61 @@ def test_gqa_past(self): ) self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): + print("-------- TEST GQA INTERACTIVE ---------") + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 6a08d2101b100..5dbb9a277e45a 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -890,7 +890,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, ) yield config @@ -929,7 +929,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer. ) yield config @@ -940,7 +940,6 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot class TestSparseAttention(unittest.TestCase): - @unittest.skipUnless(has_cuda_support(), "cuda not available") def test_sparse_attention_cuda(self): major, minor = torch.cuda.get_device_capability() @@ -1056,7 +1055,7 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool): vert_stride=4, softmax_scale=None, do_rotary=do_rotary, - rotary_interleaved=(past_seq_len % 2 == 1), + rotary_interleaved=do_rotary and (past_seq_len % 2 == 1), device=device, is_packed_qkv=packed_qkv, max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer. From a89bddd5c224c045510d09537a95d32602e021cc Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Fri, 13 Sep 2024 14:55:08 -0700 Subject: [PATCH 27/39] Matmul_nbits kernel for mlas sqnbits to support Fp16 inputs (#21807) --- cmake/onnxruntime_mlas.cmake | 4 +- docs/OperatorKernels.md | 2 +- .../cpu/quantization/matmul_nbits.cc | 246 +++++++++++++----- .../cpu/quantization/matmul_nbits_impl.cc | 11 +- onnxruntime/core/mlas/inc/mlas.h | 36 ++- onnxruntime/core/mlas/lib/cast.cpp | 42 ++- onnxruntime/core/mlas/lib/mlasi.h | 11 +- onnxruntime/core/mlas/lib/platform.cpp | 4 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 45 ++++ .../core/providers/cpu/tensor/cast_op.cc | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 54 +++- 11 files changed, 341 insertions(+), 116 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b612b3ead4658..e35c83ba45952 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -580,10 +580,10 @@ message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") else() message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d57394b3e7b97..121240e6e18f9 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -488,7 +488,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| -|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)| +|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(float16), tensor(uint8)
**T4** = tensor(int32)| |MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)| |MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index bf43aca73ef3a..ccb779721d006 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -146,8 +146,15 @@ class MatMulNBits final : public OpKernel { bool all_constant_{false}; #endif // defined(ORT_NEURAL_SPEED) + + template + Status ComputeTyped(OpKernelContext* ctx) const; }; +bool IsATypeFloat16(const Tensor& tensor) { + return tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +} + Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { @@ -211,10 +218,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) ORT_UNUSED_PARAMETER(prepacked_weights); const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } if (input_idx == InputIndex::B) { - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { - return Status::OK(); - } packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); if (packed_b_size_ == 0) { return Status::OK(); @@ -226,8 +233,15 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } else if (compute_type == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + if (IsATypeFloat16(tensor)) { + auto sptr = tensor.Data(); + std::vector scales_v(static_cast(tensor.Shape().Size())); + MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), &scales_v[0], has_zp_input_, nullptr, nullptr); + } else { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + } is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); @@ -274,9 +288,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } Status MatMulNBits::Compute(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(InputIndex::A); + + if (IsATypeFloat16(*a)) { + return ComputeTyped(ctx); + } else { + return ComputeTyped(ctx); + } +} + +template +Status MatMulNBits::ComputeTyped(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); - const auto* a_data = a->Data(); + const auto* a_data = a->Data(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -289,7 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { return Status::OK(); } - auto* y_data = y->MutableData(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -297,9 +322,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), - helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + // clang-format off + const bool has_single_b_matrix = std::all_of( + helper.RightOffsets().begin(), + helper.RightOffsets().end(), + [](size_t offset) { return offset == 0; }); + // clang-format on #if defined(ORT_NEURAL_SPEED) @@ -336,9 +364,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* bias = ctx->Input(InputIndex::bias); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( @@ -349,26 +377,64 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; + if constexpr (std::is_same::value) { + InlinedVector data(batch_count); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + + auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); + MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); + + std::vector bias_data_v; + if (bias_data != nullptr) { + bias_data_v.resize((const unsigned int)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + } + std::vector C_v((const unsigned int)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type == CompInt8) { - data[i].QuantBDataWorkspace = packed_b_.get(); + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data != nullptr ? &bias_data_v[0] : nullptr; + data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].ldc = N; } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + return Status::OK(); + } else { + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } #endif - data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + return Status::OK(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); } } @@ -380,7 +446,17 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); + const float* scales_data_; + std::vector scales_data_v; + if constexpr (std::is_same::value) { + scales_data_v.resize((const unsigned int)scales->Shape().Size()); + MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); + scales_data_ = &scales_data_v[0]; + } else { + scales_data_ = scales_data; + } + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); @@ -391,12 +467,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -406,12 +482,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } else { ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! - if ((zero_points && zero_points->IsDataType())) { - DequantizeBlockwise( + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points + scales_data_, // quantization scales + static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -422,7 +498,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -436,40 +512,80 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); #endif + if constexpr (std::is_same::value) { + std::vector data(batch_count); - std::vector data(batch_count); - for (size_t i = 0; i < batch_count; i++) { - data[i].BIsPacked = false; - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = 1.f; - data[i].beta = 0.0f; - } + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - // if there is a bias input, copy bias values into C and set beta to 1.0f - if (const Tensor* bias = ctx->Input(InputIndex::bias); - bias != nullptr) { - gsl::span bias_span = bias->DataAsSpan(); - for (size_t i = 0; i < batch_count; ++i) { - float* C_row = data[i].C; - const size_t ldc = data[i].ldc; - for (size_t m = 0; m < M; ++m) { - memcpy(C_row, bias_span.data(), bias_span.size_bytes()); - C_row += ldc; + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = tmp_c_ptr.get() + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias->Data(), tmp_bias_data_ptr.get(), static_cast(bias->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + C_row += ldc; + } + data[i].beta = 1.0f; } + } - data[i].beta = 1.0f; + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + return Status::OK(); + } else { + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; } - } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + gsl::span bias_span = bias->DataAsSpan(); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + memcpy(C_row, bias_span.data(), bias_span.size_bytes()); + C_row += ldc; + } - return Status::OK(); + data[i].beta = 1.0f; + } + } + + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + + return Status::OK(); + } } ONNX_OPERATOR_KERNEL_EX( @@ -478,9 +594,9 @@ ONNX_OPERATOR_KERNEL_EX( 1, kCpuExecutionProvider, KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index b28f3758f89b5..6a19a741c3028 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -54,12 +54,12 @@ void Dequantize4BitsKernelReOrder( T scale = *(scale_data + n_idx * scales_shape_x + rid); float zp_f = 8; if (zero_points) { - if constexpr (std::is_same_v) { - zp_f = *(zero_points + n_idx * scales_shape_x + rid); - } else { + if constexpr (std::is_same_v) { uint8_t zp = 8; zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } else { + zp_f = *(zero_points + static_cast(n_idx) * static_cast(scales_shape_x) + static_cast(rid)); } } @@ -112,5 +112,10 @@ template void DequantizeBlockwise( const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 8b3156d77e57c..28ae64c4d5b3e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -20,6 +20,7 @@ Module Name: #include #include #include +#include // // Define the calling convention for Windows targets. @@ -1025,18 +1026,6 @@ MlasComputeTanh( size_t N ); -// -// Half-precision floating-point routines. -// - -void -MLASCALL -MlasConvertHalfToFloatBuffer( - const unsigned short* Source, - float* Destination, - size_t Count -); - // // Transpose routines. // @@ -1426,7 +1415,27 @@ using MLAS_FP16 = onnxruntime::MLFloat16; constexpr size_t FP16_SIZE = sizeof(uint16_t); -/** +// +// Half-precision floating-point routines. +// + +void +MLASCALL +MlasConvertHalfToFloatBuffer( + const MLAS_FP16* Source, + float* Destination, + size_t Count +); + +void +MLASCALL +MlasConvertFloatToHalfBuffer( +const float* Source, +MLAS_FP16* Destination, +size_t Count +); + + /** * @brief Whether current CPU supports FP16 acceleration. */ bool MLASCALL @@ -1787,6 +1796,7 @@ MlasTranspose( M, N); } + #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED /** * @brief Max Pooling for fp16 NHWC diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp index 24af4064bbd9b..a6138e29bd796 100644 --- a/onnxruntime/core/mlas/lib/cast.cpp +++ b/onnxruntime/core/mlas/lib/cast.cpp @@ -23,37 +23,35 @@ union fp32_bits { void MLASCALL MlasConvertHalfToFloatBuffer( - const unsigned short* Source, + const MLAS_FP16* Source, float* Destination, size_t Count ) { - if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { - // If there is no kernel use the reference implementation, adapted from mlas_float16.h. - constexpr fp32_bits magic = {113 << 23}; - constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + for (size_t i = 0; i < Count; ++i) { + Destination[i] = Source[i].ToFloat(); + } + } else { + // If the kernel is available, use it to perform the conversion. + GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast(Source), Destination, Count); + } +} +void +MLASCALL +MlasConvertFloatToHalfBuffer( + const float* Source, + MLAS_FP16* Destination, + size_t Count +) +{ + if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) { for (size_t i = 0; i < Count; ++i) { - fp32_bits o; - o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits - uint32_t exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust - - // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // renormalize - } - - o.u |= (Source[i] & 0x8000) << 16; // sign bit - Destination[i] = o.f; + Destination[i] = MLAS_FP16(Source[i]); } - } else { // If the kernel is available, use it to perform the conversion. - GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count); + GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast(Destination), Count); } } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6f5db766b7def..8e8f46b8a102e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,13 +610,19 @@ void size_t N ); -typedef +typedef void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( const unsigned short* Source, float* Destination, size_t Count ); +typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( + const float* Source, + unsigned short* Destination, + size_t Count +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -880,6 +886,8 @@ extern "C" { #if defined(MLAS_TARGET_AMD64) MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif } @@ -1165,6 +1173,7 @@ struct MLAS_PLATFORM { const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; + MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4cd7faaa9e6ff..2b4d99800c546 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -245,6 +245,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; this->CastF16ToF32Kernel = nullptr; + this->CastF32ToF16Kernel = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) @@ -387,6 +388,9 @@ Return Value: this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + // // Check if the processor supports Hybrid core architecture. diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 55d86bb9cc18e..baaa4ba1a3b1f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,6 +29,51 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +void +MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) +{ + size_t i = 0; + + // Process 16 elements at a time using AVX2 + for (; i + 15 < size; i += 16) { + // Load 16 FP16 values into an AVX2 register + __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast(src_fp16 + i)); + + // Convert FP16 values to FP32 + __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values)); + __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1)); + + // Store the converted FP32 values into the output vector + _mm256_storeu_ps(dst_fp32 + i, fp32_values1); + _mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2); + } + + // Process any remaining elements + const MLAS_FP16* fp16 = reinterpret_cast(src_fp16); + for (; i < size; ++i) { + dst_fp32[i] = fp16[i].ToFloat(); + } +} + +void +MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size) +{ + size_t i = 0; + + // Process 8 elements at a time using AVX2 + for (; i + 8 <= size; i += 8) { + __m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]); + __m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk); + } + + // Process any remaining elements + for (; i < size; ++i) { + MLAS_FP16 fp16(src_fp32[i]); + dst_fp16[i] = fp16.val; + } +} + MLAS_FORCEINLINE __m256 load_float_n_avx2(const float* data, int n) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index f2aaa75cadd8d..35f3b12aeba35 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -258,7 +258,7 @@ struct TensorCaster { auto out_data = out.MutableData(); auto in_data = in.Data(); const size_t shape_size = narrow(shape.Size()); - MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); + MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size); } }; diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 548f24e8ac69e..fa7c6bce7c23e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -262,8 +262,8 @@ void RunTest(const TestOptions& opts, } // namespace -TEST(MatMulNBits, Float32) { - // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); +template +void TestMatMulNBitsTyped() { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */ 1, 2, 32, 288}) { for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { @@ -276,30 +276,53 @@ TEST(MatMulNBits, Float32) { if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; + } else { + if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.01f; + } } { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) { TestOptions opts = base_opts; opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); + } + + { + TestOptions opts = base_opts; + opts.has_g_idx = true; + opts.has_bias = true; + if constexpr (std::is_same::value) { + if (opts.accuracy_level == 0 || opts.accuracy_level == 1) { + // CI failure (not able to repro on either local machines): + // M:100, N:288, K:1234, block_size:16, accuracy_level:0, has_zero_point:0, zp_is_4bit:1, has_g_idx:1, has_bias:1 + // The difference between cur_expected[i] and cur_actual[i] is 1.0401010513305664e-05, which exceeds tolerance, + // tolerance evaluates to 1.006456386676291e-05. + opts.output_abs_error = 0.0001f; + } + } + // only enabled for CPU EP for now + std::vector> explicit_eps; + explicit_eps.emplace_back(DefaultCpuExecutionProvider()); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + RunTest(opts); } #endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) @@ -311,7 +334,7 @@ TEST(MatMulNBits, Float32) { std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } } } @@ -320,6 +343,21 @@ TEST(MatMulNBits, Float32) { } } +TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); + TestMatMulNBitsTyped(); +} + +#ifdef MLAS_TARGET_AMD64_IX86 +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +// Actual and expected difference is over 0.01 with DmlExecutionProvider. +// Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16) { + TestMatMulNBitsTyped(); +} +#endif +#endif + #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) namespace { @@ -367,7 +405,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura } } // namespace -TEST(MatMulNBits, Float16) { +TEST(MatMulNBits, Float16Cuda) { #if defined(USE_CUDA) || defined(USE_ROCM) auto has_gidx_options = {true, false}; #else From c63dd0234b4e0236b24fabdca005bbeb75ff4eb9 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 14 Sep 2024 12:36:20 +0800 Subject: [PATCH 28/39] [WebNN EP] Use opSupportLimits to dynamically check data type support (#22025) - Remove hard code data type checks and use WebNN's opSupportLimits instead - Add HasSupportedOutputsImpl for output data type validation - Get preferred layout info from opSupportLimits - Move Not op to logical_op_builder.cc because it should be there. This avoid the inconsistent input names in `unary_op_builder.cc`. --- .../core/providers/webnn/builders/helper.cc | 61 ++++++++++++++++--- .../core/providers/webnn/builders/helper.h | 43 +++++++++---- .../builders/impl/activation_op_builder.cc | 40 ------------ .../builders/impl/argmax_min_op_builder.cc | 27 -------- .../webnn/builders/impl/base_op_builder.cc | 52 ++++++++-------- .../webnn/builders/impl/base_op_builder.h | 9 ++- .../webnn/builders/impl/binary_op_builder.cc | 36 +++-------- .../webnn/builders/impl/cast_op_builder.cc | 32 +++++----- .../webnn/builders/impl/clip_op_builder.cc | 29 --------- .../webnn/builders/impl/concat_op_builder.cc | 28 +++++++++ .../webnn/builders/impl/conv_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gather_op_builder.cc | 26 +++----- .../webnn/builders/impl/gemm_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gru_op_builder.cc | 40 ++++-------- .../webnn/builders/impl/logical_op_builder.cc | 42 +++++++------ .../webnn/builders/impl/max_min_op_builder.cc | 29 ++++----- .../builders/impl/normalization_op_builder.cc | 35 ++++------- .../webnn/builders/impl/pad_op_builder.cc | 27 -------- .../builders/impl/reduction_op_builder.cc | 52 ---------------- .../webnn/builders/impl/resize_op_builder.cc | 26 -------- .../webnn/builders/impl/shape_op_builder.cc | 27 -------- .../webnn/builders/impl/slice_op_builder.cc | 26 -------- .../webnn/builders/impl/softmax_op_builder.cc | 26 -------- .../webnn/builders/impl/ternary_op_builder.cc | 23 ++----- .../builders/impl/transpose_op_builder.cc | 27 -------- .../webnn/builders/impl/unary_op_builder.cc | 43 ------------- .../providers/webnn/builders/model_builder.cc | 7 ++- .../providers/webnn/builders/model_builder.h | 5 +- .../providers/webnn/builders/op_builder.h | 3 +- .../webnn/builders/op_builder_factory.cc | 2 +- .../webnn/webnn_execution_provider.cc | 22 ++++--- .../webnn/webnn_execution_provider.h | 1 + 32 files changed, 281 insertions(+), 635 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d3c1d06818db2..c4a633fcc92bb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector& shape, const loggin return true; } -bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, - const WebnnDeviceType device_type, const logging::Logger& logger) { +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { const auto& op_builders = GetOpBuilders(); if (Contains(op_builders, node.OpType())) { const auto* op_builder = op_builders.at(node.OpType()); - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger); } else { return false; } @@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -105,7 +106,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v // Firstly check if platform supports the WebNN op. if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; - supported = IsNodeSupported(*node, graph_viewer, device_type, logger); + supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() @@ -130,10 +131,54 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v return supported_node_groups; } -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types) { - return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != - supported_data_types.end(); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger) { + for (size_t i = 1; i < input_types.size(); i++) { + if (input_types[0] != input_types[i]) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same, but [" + << input_types[0] << "] does not match " + << input_types[i] << "]."; + return false; + } + } + return true; +} + +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { + auto it = onnx_to_webnn_data_type_map.find(static_cast(onnx_data_type)); + if (it == onnx_to_webnn_data_type_map.end()) + return false; + + std::string webnn_data_type = it->second; + + // Check if WebNN supports the data type. + emscripten::val is_supported = webnn_supported_data_types.call("includes", + emscripten::val(webnn_data_type)); + return is_supported.as(); +} + +// Check if the input or output data type of ONNX node is supported by the WebNN operator. +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + std::string webnn_op_type; + if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) + return false; + + if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type + << "] " << onnx_input_output_name + << " type: [" << onnx_data_type + << "] is not supported for now"; + return false; + } + + return true; } bool GetBidirectionalBroadcastShape(std::vector& shape_a, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index b51092619db22..257fcff9ef50c 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger); static const InlinedHashMap op_map = { {"Abs", "abs"}, @@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn return true; } -static const std::unordered_set webnn_supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, +inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) { + auto it = op_map.find(op_type); + // Returns false if the op_type is not listed in the op_map. + if (it == op_map.end()) { + return false; + } + webnn_op_type = it->second; + return true; +} + +static const InlinedHashMap onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, }; -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger); +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 626aaf5c71b74..781ddcb896155 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi return true; } -bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - // WebNN relu op supports float32, float16, int32, int8 input data types. - if (op_type == "Relu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend does not support int32 data type for relu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { // Others only support float32 and float16. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 05f3a742a3775..d61ae1a1f6be7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index fa535889299ea..8da255a288f17 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { ORT_RETURN_IF_NOT( - IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger), - "Unsupported operator ", - node.OpType()); + IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), + model_builder.GetOpSupportLimits(), logger), + "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; @@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& // Operator support related. bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, device_type, logger)) + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, wnn_limits, logger)) + return false; + + if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; // We do not support external initializers for now. @@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d } } - // WebNN CPU backend (TFLite) will enable float16 input data type soon, - // temporarily fallback float16 input data type for WebNN CPU. - if (device_type == WebnnDeviceType::CPU) { - const auto& input = *node.InputDefs()[0]; - - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) - return false; - } - - return HasSupportedInputsImpl(node, device_type, logger); + return HasSupportedInputsImpl(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, - const WebnnDeviceType /* device_type */, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. const auto& input = *node.InputDefs()[0]; - + const auto& op_type = node.OpType(); int32_t input_type; if (!GetType(input, input_type, logger)) return false; - if (!IsSupportedDataType(input_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); +} + +bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + // We only check the type of output 0 by default, specific op builder can override this. + const auto& output = *node.OutputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) return false; - } - return true; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 85e38b668cee4..584455f62cb4e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related. public: bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; protected: virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, @@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const; // ONNX Runtime only *guarantees* support for models stamped // with opset version 7 or above for opset domain 'ai.onnx'. @@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 555de68cd60fe..af82a01b14de5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice !GetType(*input_defs[1], input1_type, logger)) return false; - std::unordered_set supported_data_types; - // WebNN prelu op only supports float32, float16, int32, int8 input data types. - if (op_type == "Prelu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend doesn't support int32 for prelu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { - supported_data_types = webnn_supported_data_types; - } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; - } - - return true; + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a08e1681a8464..3c4fc822f3d01 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input_type; -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType device_type, - const logging::Logger& logger) const { - NodeAttrHelper helper(node); - // Check cast output type. - const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); - - // WebNN CPU backend doesn't support casting to uint64 data type. - if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) { - LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend."; + if (!GetType(*input_defs[0], input_type, logger)) return false; - } - if (!IsSupportedDataType(to_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << "."; + + if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger)) return false; - } - return true; + NodeAttrHelper helper(node); + // Check cast to type. + const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); + return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger); } void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index b5c3206072d50..374143c886849 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, }; } -bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index dedc76b80e978..48dd6f3beb020 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + + if (!GetType(*input_defs[0], input0_type, logger)) + return false; + + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } + + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); +} + void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 76a8a178678df..35498c2e9b8b7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Conv" || op_type == "ConvTranspose") { - // WebNN conv2d and convTranspose2d only support float32 and float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "ConvInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 23233539d34c7..ae9fe3e3f3bd1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t indices_type; + if (!GetType(input, input_type, logger) || + !GetType(indices, indices_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for gather. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index bd452b118fe3e..30e024792ed42 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -233,35 +233,18 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Gemm" || op_type == "MatMul") { - // WebNN gemm and matmul only support float32 and float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "MatMulInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 23cc7f1b11459..c92fe7366d494 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -208,37 +208,21 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp return false; } - std::unordered_set supported_data_types; - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend only support float32 input data type. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - }; - } else if (device_type == WebnnDeviceType::GPU) { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; + InlinedVector input_types = {input0_type, input1_type, input2_type}; + if (has_input3) { + input_types.push_back(input3_type); } - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input4) { + input_types.push_back(input4_type); } - - if (input0_type != input1_type || - input0_type != input2_type || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type) || - (has_input5 && input0_type != input5_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input5) { + input_types.push_back(input5_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 23f3a938fee5e..ea7f70b4598e6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder { Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { + const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input1 = emscripten::val::undefined(); + if (input_defs.size() > 1) { + input1 = model_builder.GetOperand(input_defs[1]->Name()); + } + emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); @@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); + } else if (op_type == "Not") { + output = model_builder.GetBuilder().call("logicalNot", input0, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2) { + if (input_defs.size() < 2 && op_type != "Not") { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " << input_defs.size(); return false; @@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) - return false; - - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + if (!GetType(*input_defs[0], input0_type, logger)) return false; - } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + if (op_type != "Not") { + if (!GetType(*input_defs[1], input1_type, logger)) + return false; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + std::string onnx_input_name = op_type == "Not" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { @@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& "GreaterOrEqual", "Less", "LessOrEqual", + "Not", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 5d88afda7b6a7..e111ca412c6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; - int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) + if (!GetType(*input_defs[0], input0_type, logger)) return false; - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; - } + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d068baf35e72..a3c6b8fdcea9b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn return false; } - // WebNN batchNormalization, instanceNormalization, layerNormalization - // only support float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 071155a2fb372..d8373a45e4423 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } // namespace webnn -bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for pad. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 3e6d4d9820e9a..93ad933d71c34 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return true; } -bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "ReduceL1" || op_type == "ReduceProd" || - op_type == "ReduceSum" || op_type == "ReduceSumSquare") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, - }; - - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend doesn't support uint32 and uint64 for reduceL1, - // reduceProd, reduceSum and reduceSumSquare. - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" || - op_type == "ReduceLogSumExp" || op_type == "ReduceMean") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else { // ReduceMax and ReduceMin - supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 2218c858951d3..9dc79f4f52f46 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Helper functions @@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN resample2d op only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 0eb7dafdffe4d..6b56d2c740f40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. - -bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (!IsSupportedDataType(output_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Output type: [" << output_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index bef13841c646c..3f0d633ac888b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 input data type for slice. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 798cfabae65db..b1b737b114998 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -63,30 +61,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN softmax only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 2ed8330bf25be..4b6cf312074ba 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic !GetType(*input_defs[2], input2_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 X, Y data type for where. - if (device_type == WebnnDeviceType::CPU && op_type == "Where") { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } // ONNX's condition data type is bool which is same as WebNN. // Only need to check X, Y data types. - if (!IsSupportedDataType(input1_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input1_type - << "] is not supported for now"; - return false; - } - - if (input1_type != input2_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input X, Y data types should be the same."; + std::array input_types{input1_type, input2_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 03c88ad9db88a..3a5e39f7f7a56 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for transpose. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 061404c8a9ce0..8e64e98445f03 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Add operator related. @@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { output = model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { @@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "Identity") { - supported_data_types = webnn_supported_data_types; - } else if (op_type == "Abs" || op_type == "Neg") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - } else if (op_type == "Not") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - }; - } else { // Others only support float32, float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; @@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Identity", "Log", "Neg", - "Not", "Reciprocal", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 44bec1fb6fd48..b58bf8233692e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -21,12 +21,13 @@ namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type) + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), preferred_layout_(preferred_layout), - wnn_device_type_(wnn_device_type) { + wnn_device_type_(wnn_device_type), + wnn_limits_(wnn_limits) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); @@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("dimensions", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, webnn_supported_data_types)) { + if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2d686070cdcc1..256337baeba7e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -23,7 +23,7 @@ class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type); + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -35,6 +35,8 @@ class ModelBuilder { const emscripten::val& GetBuilder() const { return wnn_builder_; } const emscripten::val& GetContext() const { return wnn_context_; } const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } + void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant(const std::string& data_type); // Use the buffers to persist WebNN allocated data like transposed weight. @@ -66,6 +68,7 @@ class ModelBuilder { emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; + emscripten::val wnn_limits_ = emscripten::val::undefined(); InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h index 6ecc5d1068963..bb69a6a545597 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -29,7 +29,8 @@ class IOpBuilder { public: // Check if an operator is supported. virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const = 0; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const = 0; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 01761290f07e3..3dc1c7966ae41 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Identity", op_registrations); CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); - CreateUnaryOpBuilder("Not", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); @@ -118,6 +117,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); + CreateLogicalOpBuilder("Not", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b918daf838c99..b729623c5d3d8 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { - preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { - preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } + + // Retrieve the level of support for different WebNN operators. + // This varies across implementations and is obtained via the WebNN's opSupportLimits() function. + // https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits + wnn_limits_ = wnn_context_.call("opSupportLimits"); + + if (wnn_limits_["preferredInputLayout"].as().compare("nhwc") == 0) { + preferred_layout_ = DataLayout::NHWC; + } else { + preferred_layout_ = DataLayout::NCHW; + } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { @@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); @@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector Date: Sun, 15 Sep 2024 18:31:55 -0400 Subject: [PATCH 29/39] [java] Adding ability to load a model from a memory mapped byte buffer (#20062) ### Description Adds support for constructing an `OrtSession` from a `java.nio.ByteBuffer`. These buffers can be memory mapped from files which means there doesn't need to be copies of the model protobuf held in Java, reducing peak memory usage during session construction. ### Motivation and Context Reduces memory usage on model construction by not requiring as many copies on the Java side. Should help with #19599. --- .../java/ai/onnxruntime/OrtEnvironment.java | 49 ++++++++++++++++++- .../main/java/ai/onnxruntime/OrtSession.java | 35 +++++++++++++ .../main/native/ai_onnxruntime_OrtSession.c | 25 +++++++++- .../java/ai/onnxruntime/InferenceTest.java | 31 ++++++++++++ 4 files changed, 138 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 26137e88478b5..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -7,6 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Objects; import java.util.logging.Logger; @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption return new OrtSession(this, modelPath, allocator, options); } + /** + * Create a session using the specified {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options) + throws OrtException { + return createSession(modelBuffer, defaultAllocator, options); + } + + /** + * Create a session using the default {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException { + return createSession(modelBuffer, new OrtSession.SessionOptions()); + } + + /** + * Create a session using the specified {@link SessionOptions} and model buffer. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param allocator The memory allocator to use. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + Objects.requireNonNull(modelBuffer, "model array must not be null"); + if (modelBuffer.remaining() == 0) { + throw new OrtException("Invalid model buffer, no elements remaining."); + } else if (!modelBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + return new OrtSession(this, modelBuffer, allocator, options); + } + /** * Create a session using the specified {@link SessionOptions}, model and the default memory * allocator. diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8fe73ff69e169..f87cbc76ef141 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -11,6 +11,7 @@ import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable { allocator); } + /** + * Creates a session reading the model from the supplied byte buffer. + * + *

Must be a direct byte buffer. + * + * @param env The environment. + * @param modelBuffer The model protobuf as a byte buffer. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. + */ + OrtSession( + OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + this( + createSession( + OnnxRuntime.ortApiHandle, + env.getNativeHandle(), + modelBuffer, + modelBuffer.position(), + modelBuffer.remaining(), + options.getNativeHandle()), + allocator); + } + /** * Private constructor to build the Java object wrapped around a native session. * @@ -514,6 +540,15 @@ private static native long createSession( private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession( + long apiHandle, + long envHandle, + ByteBuffer modelBuffer, + int bufferPos, + int bufferSize, + long optsHandle) + throws OrtException; + private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index f4d5ab080cd31..ee8cdee659296 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la return (jlong)session; } +/* + * Class: ai_onnxruntime_OrtSession + * Method: createSession + * Signature: (JJLjava/nio/ByteBuffer;IIJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*)envHandle; + OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle; + OrtSession* session = NULL; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + + // Create the session + checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session)); + + return (jlong)session; +} + /* * Class: ai_onnxruntime_OrtSession * Method: createSession diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 3340a2e5e9f3a..f76e1b3b20e19 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -20,10 +20,14 @@ import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; import java.io.File; import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.LongBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException { } } + @Test + public void createSessionFromByteBuffer() throws IOException, OrtException { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r"); + FileChannel channel = file.getChannel()) { + MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size()); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelBuffer, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); + } + } + } + } + @Test public void createSessionFromByteArray() throws IOException, OrtException { Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); From 6d7235ba5ab995e42a0e251874e65e9d7eaa2997 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 15 Sep 2024 21:55:38 -0400 Subject: [PATCH 30/39] [Java] Exposing SessionOptions.SetDeterministicCompute (#18998) ### Description Exposes `SetDeterministicCompute` in Java, added to the C API by #18944. ### Motivation and Context Parity between C and Java APIs. --- .../main/java/ai/onnxruntime/OrtSession.java | 17 +++++++++++++++++ .../ai_onnxruntime_OrtSession_SessionOptions.c | 13 +++++++++++++ .../test/java/ai/onnxruntime/InferenceTest.java | 1 + 3 files changed, 31 insertions(+) diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index f87cbc76ef141..6d146d5857d3c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -942,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } + /** + * Set whether to use deterministic compute. + * + *

Default is false. If set to true, this will enable deterministic compute for GPU kernels + * where possible. Note that this most likely will have a performance cost. + * + * @param value Should the compute be deterministic? + * @throws OrtException If there was an error in native code. + */ + public void setDeterministicCompute(boolean value) throws OrtException { + checkClosed(); + setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value); + } + /** * Disables the per session thread pools. Must be used in conjunction with an environment * containing global thread pools. @@ -1327,6 +1341,9 @@ private native void registerCustomOpsUsingFunction( private native void closeOptions(long apiHandle, long nativeHandle); + private native void setDeterministicCompute( + long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; + private native void addFreeDimensionOverrideByName( long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff9348c299e90..ff6b7fa703e6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -259,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel)); } +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setDeterministicCompute + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic)); +} + /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: registerCustomOpLibrary diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f76e1b3b20e19..11141a3a65a3e 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1263,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException { options.setLoggerId("monkeys"); options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); options.setSessionLogVerbosityLevel(5); + options.setDeterministicCompute(true); Map configEntries = options.getConfigEntries(); assertTrue(configEntries.isEmpty()); options.addConfigEntry("key", "value"); From 1a1669fe817232e7d19c6459da0fc610e0c74b0a Mon Sep 17 00:00:00 2001 From: George Wu Date: Mon, 16 Sep 2024 09:12:13 -0700 Subject: [PATCH 31/39] use node name in transpose optimizer when adding nodes rather than optype (#22084) patch from @john-dance "The main change is simple: Use the original node name rather than the original node op_type when creating new nodes. Here are my comments on the change: ------ The onnx runtime uses the op_type as the basis for a new node name, so a node claimed by QNN EP might be named Conv_token_1 with no relation to the original /conv1/Conv. This patch: 1. Adds OpName as a virtual function in NodeRef and implements it in ApiNode. 2. AddNode now takes an op_name and op_type and passes them both to CreateNodeHelper. 3. CreateNodeHelper uses the op_name rather than the op_type in GenerateNodeName 4. Direct calls to AddNode are modified to either use the NodeRef if available, or just repeat the op_type if not available. The result is that the new nodes are named something like /conv1/Conv_token_1, allowing a straight forward mapping back to the original model node (if they exist in the original graph)." --- .../onnx_transpose_optimization.cc | 18 +++++++++--------- .../transpose_optimization/optimizer_api.h | 6 +++++- .../ort_optimizer_api_impl.cc | 17 +++++++++++------ .../internal_testing/internal_testing_tests.cc | 4 ++-- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index df81367c5bbee..5d689a9d933e8 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -78,7 +78,7 @@ static std::unique_ptr MakeNode1Attr(api::GraphRef& graph, std::st std::string_view input, std::string_view attr_name, const std::vector& attr_val) { std::vector inputs{input}; - std::unique_ptr node = graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + std::unique_ptr node = graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); node->SetAttributeInts(attr_name, attr_val); return node; } @@ -102,7 +102,7 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: std::vector inputs{input, axes_initializer}; - return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + return graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); } ///

@@ -136,7 +136,7 @@ static std::unique_ptr MakeQuantizeOp(api::GraphRef& graph, std::s std::optional block_size, std::optional output_dtype, std::optional saturate) { - std::unique_ptr node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("QuantizeLinear", "QuantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -170,7 +170,7 @@ static std::unique_ptr MakeDequantizeOp(api::GraphRef& graph, std: std::vector inputs, std::optional axis, std::optional block_size) { - std::unique_ptr node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("DequantizeLinear", "DequantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -1724,7 +1724,7 @@ static bool HandleShape(HandlerArgs& args) { // X -> Shape -> Y, Gather std::vector gather_inputs{"", perm_const}; - auto gather_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; gather.SetAttributeInt("axis", 0); @@ -1767,7 +1767,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con // inputs that would never be quantized. std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); std::vector gather_inputs{input_name, gather_indices_const}; - auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0]; graph.CopyValueInfo(input_name, gather_output); @@ -2215,7 +2215,7 @@ static bool HandleTile(HandlerArgs& args) { // Case 2: Repeats is computed. Insert Gather node. std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv); std::vector gather_inputs{repeats_inp, perm_inv_const}; - auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_node_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather_node = *gather_node_ptr; std::string_view gather_output = gather_node.Outputs()[0]; args.ctx.graph.CopyValueInfo(repeats_inp, gather_output); @@ -2265,7 +2265,7 @@ static void RemoveCancelingTransposeNodes(HandlerArgs& args) { // Worst-case scenario: Both parent output and 2nd transpose/reshape output cannot be removed (both graph outputs) // despite computing the same value. Use an Identity op instead. std::vector single_empty_input{""}; - auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1); + auto identity_ptr = args.ctx.graph.AddNode("Identity", "Identity", single_empty_input, /*num_outputs*/ 1); api::NodeRef& identity = *identity_ptr; args.ctx.graph.MoveOutput(args.node, 0, identity, 0); identity.SetInput(0, transpose_input); @@ -2297,7 +2297,7 @@ static bool HandleTransposeImpl(HandlerArgs& args, const std::vector& n // replace Reshape with Transpose to simplify the logic. // use the same input as the 1st Transpose, move the output from the Reshape to the new Transpose node, // and remove the Reshape node. - new_node = args.ctx.graph.AddNode("Transpose", {args.transpose.Inputs()[0]}, 1); + new_node = args.ctx.graph.AddNode("Transpose", "Transpose", {args.transpose.Inputs()[0]}, 1); args.ctx.graph.MoveOutput(args.node, 0, *new_node, 0); args.ctx.graph.RemoveNode(args.node); } else { diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 211734f4bacc8..7122aec45e61a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -146,6 +146,9 @@ class ValueInfoRef { /// class NodeRef { public: + /// Node name + virtual std::string_view Name() const = 0; + /// Op computed by the node virtual std::string_view OpType() const = 0; @@ -361,6 +364,7 @@ class GraphRef { /// generated. Outputs of created node have unspecified shapes/dtypes. They will be populated afterwards using /// CopyValueInfo. /// + /// The new node's name /// The new node's op type /// Inputs for the node. "" for missing optional inputs. /// @@ -368,7 +372,7 @@ class GraphRef { /// /// The new node's domain. Empty string signifies default onnx domain. /// The new node - virtual std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + virtual std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain = /*kOnnxDomain*/ "") = 0; /// diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 33408474f92a6..f87df746234fa 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -80,6 +80,10 @@ class ApiNode final : public api::NodeRef { return node_; } + std::string_view Name() const override { + return node_.Name(); + } + std::string_view OpType() const override { return node_.OpType(); } @@ -134,7 +138,7 @@ class ApiGraph final : public api::GraphRef { std::unique_ptr GetNodeProducingOutput(std::string_view name) const override; void TransposeInitializer(std::string_view name, const std::vector& perm) override; void ReshapeInitializer(std::string_view name, const std::vector& shape) override; - std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs = 1, std::string_view domain = "") override; std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, @@ -621,11 +625,12 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectorSetShape(new_shape); } -static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type, +static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain, int since_version, std::string_view node_ep) { const std::string op_type_str(op_type); - std::string name = graph.GenerateNodeName(op_type_str); + const std::string op_name_str(op_name); + std::string name = graph.GenerateNodeName(op_name_str); std::vector input_args; std::vector output_args; @@ -731,11 +736,11 @@ static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view do return *since_version; } -std::unique_ptr ApiGraph::AddNode(std::string_view op_type, +std::unique_ptr ApiGraph::AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain) { int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap()); - Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs, + Node& node = CreateNodeHelper(graph_, name, op_type, inputs, num_outputs, domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : ""); return std::make_unique(node, graph_); @@ -744,7 +749,7 @@ std::unique_ptr ApiGraph::AddNode(std::string_view op_type, std::unique_ptr ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain, std::optional since_version) { const int new_node_since_version = since_version.has_value() ? *since_version : source_node.SinceVersion(); - Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(), + Node& node = CreateNodeHelper(graph_, source_node.Name(), op_type, source_node.Inputs(), source_node.Outputs().size(), domain, new_node_since_version, source_node.GetExecutionProviderType()); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 9f7be524daa34..67fb35d26e6dc 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -196,7 +196,7 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { // Error message should come from the Conv implementation with the statically registered kernel ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); } @@ -242,7 +242,7 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { std::vector fetches; ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); }; From e93f14e00d09b0c62ba0869bc87f14ee5f1cf4c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erick=20Mu=C3=B1oz?= Date: Mon, 16 Sep 2024 10:20:06 -0600 Subject: [PATCH 32/39] Check partial conversion on FP16 to FP32 AVX Cast kernel (#22091) ### Description Added checks to convert partial vectors in the early stages of the FP16 to FP32 cast using AVX NE CONVERT ISA. ### Motivation and Context Avoid storing data in sections outside of the output buffer, these checks are missing on the [original PR](https://github.com/microsoft/onnxruntime/pull/21183). This fix prevents memory corruption when the output buffer has a size [n*16 + 1, n*16 + 7] with 0< n --- onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm | 4 +++- onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm index c7f6342c527bf..800863c77a230 100644 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm @@ -54,7 +54,7 @@ HIGH_SELECTOR equ 00110001b LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT - test r8, r8 ; Check if we have any elements to convert + test r8, r8 ; Check if we have any elements to convert jz ExitRoutine cmp r8, 8 jb ConvertMaskedVectors @@ -80,6 +80,8 @@ Convert256Vectors: jz ExitRoutine ; If we are done, exit cmp r8, 16 ; If the vector is big enough, we go again jae Convert256Vectors + cmp r8, 8 ; Check if we have enough elements to convert + jb ConvertMaskedVectors diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S index 1a70061460e50..a4d730fa513ab 100644 --- a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S @@ -51,8 +51,6 @@ FUNCTION_ENTRY MlasCastF16ToF32KernelAvx test rdx, rdx // Check if we have any elements to convert jz ExitRoutine - -AVX_NE_CONVERT: cmp rdx, 8 jb ConvertMaskedVectors cmp rdx, 16 @@ -75,6 +73,8 @@ Convert256Vectors: jz ExitRoutine // If we are done, exit cmp rdx, 16 // If the vector is big enough, we go again jae Convert256Vectors + cmp rdx, 8 // Check if we have enough elements to convert + jb ConvertMaskedVectors From 291a5352b27ded5714e5748b381f2efb88f28fb9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:56:22 -0700 Subject: [PATCH 33/39] [js/web] remove training release (#22103) ### Description Remove training from onnxruntime-web Following up of #22082 --- js/web/lib/backend-wasm-inference.ts | 5 - js/web/lib/backend-wasm-training.ts | 29 - js/web/lib/backend-wasm.ts | 2 + js/web/lib/index.ts | 4 +- js/web/lib/wasm/session-handler-training.ts | 198 ------ js/web/lib/wasm/wasm-core-impl.ts | 9 +- js/web/lib/wasm/wasm-training-core-impl.ts | 631 ------------------ js/web/lib/wasm/wasm-types.ts | 76 +-- js/web/lib/wasm/wasm-utils-import.ts | 16 +- js/web/package.json | 7 - js/web/script/build.ts | 13 +- js/web/script/pull-prebuilt-wasm-artifacts.ts | 2 - js/web/test/training/e2e/browser-test-wasm.js | 21 - js/web/test/training/e2e/common.js | 248 ------- js/web/test/training/e2e/data/model.onnx | 16 - js/web/test/training/e2e/karma.conf.js | 54 -- js/web/test/training/e2e/package.json | 14 - js/web/test/training/e2e/run.js | 143 ---- .../test/training/e2e/simple-http-server.js | 67 -- js/web/types.d.ts | 4 - 20 files changed, 15 insertions(+), 1544 deletions(-) delete mode 100644 js/web/lib/backend-wasm-inference.ts delete mode 100644 js/web/lib/backend-wasm-training.ts delete mode 100644 js/web/lib/wasm/session-handler-training.ts delete mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts delete mode 100644 js/web/test/training/e2e/browser-test-wasm.js delete mode 100644 js/web/test/training/e2e/common.js delete mode 100644 js/web/test/training/e2e/data/model.onnx delete mode 100644 js/web/test/training/e2e/karma.conf.js delete mode 100644 js/web/test/training/e2e/package.json delete mode 100644 js/web/test/training/e2e/run.js delete mode 100644 js/web/test/training/e2e/simple-http-server.js diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts deleted file mode 100644 index 7dfe7ee05a1d3..0000000000000 --- a/js/web/lib/backend-wasm-inference.ts +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts deleted file mode 100644 index 7332b3f97eba0..0000000000000 --- a/js/web/lib/backend-wasm-training.ts +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common'; - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training'; - -class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); - await handler.createTrainingSession( - checkpointStateUriOrBuffer, - trainModelUriOrBuffer, - evalModelUriOrBuffer, - optimizerModelUriOrBuffer, - options, - ); - return Promise.resolve(handler); - } -} - -export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 7bef538b26063..766937dc4c4cf 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } + +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 321394466b365..776c0d026bc97 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.DISABLE_TRAINING - ? require('./backend-wasm-inference').wasmBackend - : require('./backend-wasm-training').wasmBackend; + const wasmBackend = require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { registerBackend('webgpu', wasmBackend, 5); registerBackend('webnn', wasmBackend, 5); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts deleted file mode 100644 index 8bbfb9cf06668..0000000000000 --- a/js/web/lib/wasm/session-handler-training.ts +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference'; -import { copyFromExternalBuffer } from './wasm-core-impl'; -import { - createCheckpointHandle, - createTrainingSessionHandle, - getContiguousParameters, - getModelInputOutputNames, - getParametersSize, - lazyResetGrad, - loadParametersBuffer, - releaseTrainingSessionAndCheckpoint, - runEvalStep, - runOptimizerStep, - runTrainStep, -} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - evalInputNames: string[] = []; - evalOutputNames: string[] = []; - - async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return copyFromExternalBuffer(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ) { - const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableInternalBuffer = [0, 0]; - let optimizerModelData: SerializableInternalBuffer = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - this.sessionId = createTrainingSessionHandle( - this.checkpointId, - trainModelData, - evalModelData, - optimizerModelData, - options, - ); - [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); - if (evalModelUriOrBuffer !== '') { - [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); - } - } - - /** - * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the - * corresponding name as a number referring to the index in the list of names provided. - * - * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType - * @param names either inputNames or outputNames - * @returns a tuple of a list of values and a list of indices. - */ - convertMapIntoValuesArrayAndIndicesArray( - feeds: { [name: string]: T }, - names: string[], - mapFunc: (val: T, index: number) => U, - ): [T[], number[], U[]] { - const values: T[] = []; - const indices: number[] = []; - Object.entries(feeds).forEach((kvp) => { - const name = kvp[0]; - const tensor = kvp[1]; - const index = names.indexOf(name); - if (index === -1) { - throw new Error(`invalid input '${name}`); - } - values.push(tensor); - indices.push(index); - }); - - const uList = values.map(mapFunc); - return [values, indices, uList]; - } - - /** - * Helper method that converts the TensorMetadata that the wasm-core functions return to the - * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the - * corresponding result. - * - * @param results used to populate the resultMap if there is no value for that outputName already - * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results - * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. - * @returns a map of output names and OnnxValues. - */ - convertTensorMetadataToReturnType( - results: TensorMetadata[], - outputArray: Array, - outputIndices: number[], - ): SessionHandler.ReturnType { - const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results.length; i++) { - resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); - } - return resultMap; - } - - async lazyResetGrad(): Promise { - await lazyResetGrad(this.sessionId); - } - - async runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.inputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.outputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async runOptimizerStep(options: InferenceSession.RunOptions): Promise { - await runOptimizerStep(this.sessionId, options); - } - - async runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.evalInputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async getParametersSize(trainableOnly: boolean): Promise { - return getParametersSize(this.sessionId, trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { - await loadParametersBuffer(this.sessionId, array, trainableOnly); - } - async getContiguousParameters(trainableOnly: boolean): Promise { - const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); - return decodeTensorMetadata(tensorResult); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); - } -} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..ed001cfa90f59 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file'; * Refer to web/lib/index.ts for the backend registration. * * 2. WebAssembly artifact initialization. - * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or - * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings: + * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is + * called). In this step, onnxruntime-web does the followings: * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated * JavaScript code to initialize the WebAssembly runtime. @@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file'; * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. * * 4. Session initialization. - * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 - * steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the - * followings: + * This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once), + * this step will be done for each session. In this step, onnxruntime-web does the followings: * If the parameter is a URL: * - download the model data from the URL. * - copy the model data to the WASM heap. (proxy: 'copy-from') diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts deleted file mode 100644 index 22cd6ec30732c..0000000000000 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, Tensor } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { setRunOptions } from './run-options'; -import { setSessionOptions } from './session-options'; -import { - dataLocationStringToEnum, - tensorDataTypeEnumToString, - tensorDataTypeStringToEnum, - tensorTypeToTypedArrayConstructor, -} from './wasm-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; -import { getInstance } from './wasm-factory'; -import { checkLastError } from './wasm-utils'; - -const NO_TRAIN_FUNCS_MSG = - "Built without training API's enabled. Use the onnxruntime-web/training import for training " + - 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + - 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; - -/** - * Runs the checkLastError function which will throw an error, if the provided error code matches the specified - * pattern for an error code. - * @param errCode number to evaluated for if it's an error - * @param message message to pass into checkLastError - * @param checkNeqZero when true, treats not equal to zero as an error. - * When false, treats equal to zero as an error. - */ -const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { - if (checkNeqZero && errCode !== 0) { - checkLastError(message); - } else if (!checkNeqZero && errCode === 0) { - checkLastError(message); - } -}; - -export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { - const wasm = getInstance(); - - const [checkpointDataOffset, checkpointDataLength] = checkpointData; - let checkpointHandle = 0; - - try { - if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); - return checkpointHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { - wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); - } - throw e; - } finally { - // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); - } -}; - -const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetModelInputOutputCount) { - const errorCode = wasm._OrtTrainingGetModelInputOutputCount( - trainingSessionId, - dataOffset, - dataOffset + 4, - isEvalModel, - ); - ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getModelInputOutputNamesLoop = ( - trainingSessionId: number, - count: number, - isInput: boolean, - isEvalModel: boolean, -): string[] => { - const names = []; - const wasm = getInstance(); - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetModelInputOutputName) { - const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - - names.push(wasm.UTF8ToString(name)); - wasm._free(name); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return names; -}; - -export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { - let inputNames: string[] = []; - let outputNames: string[] = []; - - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - - inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); - outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - - return [inputNames, outputNames]; -}; - -export const createTrainingSessionHandle = ( - checkpointHandle: number, - trainModelData: SerializableInternalBuffer, - evalModelData: SerializableInternalBuffer, - optimizerModelData: SerializableInternalBuffer, - options: InferenceSession.SessionOptions, -): number => { - const wasm = getInstance(); - - let trainingSessionHandle = 0; - let sessionOptionsHandle = 0; - let allocs: number[] = []; - - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); - if (wasm._OrtTrainingCreateSession) { - trainingSessionHandle = wasm._OrtTrainingCreateSession( - sessionOptionsHandle, - checkpointHandle, - trainModelData[0], - trainModelData[1], - evalModelData[0], - evalModelData[1], - optimizerModelData[0], - optimizerModelData[1], - ); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - return trainingSessionHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { - wasm._OrtTrainingReleaseSession(trainingSessionHandle); - } - throw e; - } finally { - wasm._free(trainModelData[0]); - wasm._free(evalModelData[0]); - wasm._free(optimizerModelData[0]); - - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach((alloc) => wasm._free(alloc)); - } -}; - -/** - * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the - * WASM tensors. - * - * @param trainingSessionId - * @param indices for each tensor, the index of the input or output name that the tensor corresponds with - * @param tensors list of TensorMetaData - * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting - * handles of the allocated tensors on the heap - * @param inputOutputAllocs modified in-place by this method - * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor - */ -const createAndAllocateTensors = ( - trainingSessionId: number, - indices: number[], - tensors: Array, - tensorHandles: number[], - inputOutputAllocs: number[], - indexAdd: number, -) => { - const count = indices.length; - - // creates the tensors - for (let i = 0; i < count; i++) { - prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); - } - - // moves to heap - const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; - for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; - } - - return valuesOffset; -}; - -/** - * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information - * associated with the tensor handle. - * - * @param outputValuesOffset - * @param outputCount - * @returns list of TensorMetadata retrieved from the output handles. - */ -const moveOutputToTensorMetadataArr = ( - outputValuesOffset: number, - outputCount: number, - outputTensorHandles: number[], - outputTensors: Array, -) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - if (tensor === outputTensorHandles[i]) { - // output tensor is pre-allocated. no need to copy data. - output.push(outputTensors[i]!); - continue; - } - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type | undefined, - dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, - tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, - ); - ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), - ); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); - } - } - - return output; -}; - -export const lazyResetGrad = async (trainingSessionId: number): Promise => { - const wasm = getInstance(); - - if (wasm._OrtTrainingLazyResetGrad) { - const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); - ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } -}; - -export const runTrainStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingRunTrainStep) { - const errorCode = wasm._OrtTrainingRunTrainStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runOptimizerStep = async ( - trainingSessionId: number, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - try { - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - if (wasm._OrtTrainingOptimizerStep) { - const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); - ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runEvalStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingEvalStep) { - const errorCode = wasm._OrtTrainingEvalStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - try { - const sizeOffset = wasm.stackAlloc(4); - if (wasm._OrtTrainingGetParametersSize) { - const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); - ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); - - return wasm.HEAP32[sizeOffset / 4]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -export const getContiguousParameters = async ( - trainingSessionId: number, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - const parametersSize = getParametersSize(trainingSessionId, trainableOnly); - let tensor = 0; - - // allocates a buffer of the correct size on the WASM heap - const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm._malloc(paramsByteLength); - - // handles the dimensions-related createTensor parameters - const dims = [parametersSize]; - - const dimsOffset = wasm.stackAlloc(4); - const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; - - try { - // wraps allocated array in a tensor - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - paramsOffset, - paramsByteLength, - dimsOffset, - dims.length, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError( - tensor, - `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, - false, - ); - - if (wasm._OrtTrainingCopyParametersToBuffer) { - const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); - const data = new typedArrayConstructor(parametersSize); - const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength), - ); - output.push([tensorTypeAsString, dims, data, locationAsString]); - if (output.length !== 1) { - throw new Error(`something unexpected happened in the getContiguousParameters function. Expected output length of - one, got ${output.length}`); - } else { - return output[0]; - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm._free(paramsOffset); - wasm._free(dimsOffset); - wasm.stackRestore(stack); - } -}; - -export const loadParametersBuffer = async ( - trainingSessionId: number, - buffer: Uint8Array, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - // allocates & copies JavaScript buffer to WASM heap - const bufferByteLength = buffer.length; - const bufferCount = bufferByteLength / 4; - const bufferOffset = wasm._malloc(bufferByteLength); - wasm.HEAPU8.set(buffer, bufferOffset); - - // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; - const dimsLength = 1; - let tensor = 0; - - try { - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - bufferOffset, - bufferByteLength, - dimsOffset, - dimsLength, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); - - if (wasm._OrtTrainingCopyParametersFromBuffer) { - const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm.stackRestore(stack); - wasm._free(bufferOffset); - wasm._free(dimsOffset); - } -}; - -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { - const wasm = getInstance(); - - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } -}; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..828cd3cfd94fa 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -213,84 +213,10 @@ export interface OrtInferenceAPIs { _OrtEndProfiling(sessionHandle: number): number; } -export interface OrtTrainingAPIs { - _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - - _OrtTrainingCreateSession( - sessionOptionsHandle: number, - checkpointHandle: number, - trainOffset: number, - trainLength: number, - evalOffset: number, - evalLength: number, - optimizerOffset: number, - optimizerLength: number, - ): number; - - _OrtTrainingLazyResetGrad(trainingHandle: number): number; - - _OrtTrainingRunTrainStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - - _OrtTrainingEvalStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - - _OrtTrainingCopyParametersToBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingCopyParametersFromBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingGetModelInputOutputCount( - trainingHandle: number, - inputCount: number, - outputCount: number, - isEvalModel: boolean, - ): number; - _OrtTrainingGetModelInputOutputName( - trainingHandle: number, - index: number, - isInput: boolean, - isEvalModel: boolean, - ): number; - - _OrtTrainingReleaseSession(trainingHandle: number): void; -} - /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule - extends EmscriptenModule, - OrtInferenceAPIs, - Partial, - Partial { +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 008b9b41b1592..bd9e0ce083ef0 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT ? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires require( - !BUILD_DEFS.DISABLE_TRAINING - ? '../../dist/ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + !BUILD_DEFS.DISABLE_JSEP + ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -163,11 +161,9 @@ export const importWasmModule = async ( if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) { return [undefined, embeddedWasmModule!]; } else { - const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING - ? 'ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? 'ort-wasm-simd-threaded.jsep.mjs' - : 'ort-wasm-simd-threaded.mjs'; + const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP + ? 'ort-wasm-simd-threaded.jsep.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. diff --git a/js/web/package.json b/js/web/package.json index 94dd047915b05..d770499adada4 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -23,7 +23,6 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", - "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . --noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", @@ -101,12 +100,6 @@ "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" - }, - "./training": { - "node": null, - "import": "./dist/ort.training.wasm.min.mjs", - "require": "./dist/ort.training.wasm.min.js", - "types": "./types.d.ts" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6d1b3bdb65068..408f9e00a5cbd 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -56,7 +56,6 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_JSEP': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', - 'BUILD_DEFS.DISABLE_TRAINING': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false', 'BUILD_DEFS.IS_ESM': 'false', @@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort[-training]-wasm-simd-threaded[.jsep].mjs + * - ort-wasm-simd-threaded[.jsep].mjs */ async function buildOrt({ isProduction = false, @@ -630,16 +629,6 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, }); - // ort.training.wasm[.min].[m]js - await addAllWebBuildTasks({ - outputName: 'ort.training.wasm', - define: { - ...DEFAULT_DEFINE, - 'BUILD_DEFS.DISABLE_TRAINING': 'false', - 'BUILD_DEFS.DISABLE_JSEP': 'true', - 'BUILD_DEFS.DISABLE_WEBGL': 'true', - }, - }); } if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index b1b2fa26b2351..5b8b0d27c88db 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -149,11 +149,9 @@ downloadJson( void jszip.loadAsync(buffer).then((zip) => { extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); }); }); }, diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js deleted file mode 100644 index 05750ed149303..0000000000000 --- a/js/web/test/training/e2e/browser-test-wasm.js +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -describe('Browser E2E testing for training package', function () { - it('Check that training package encompasses inference', async function () { - ort.env.wasm.numThreads = 1; - await testInferenceFunction(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, all options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, minimum options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] }); - }); -}); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js deleted file mode 100644 index 0574ae85aabd1..0000000000000 --- a/js/web/test/training/e2e/common.js +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const DATA_FOLDER = 'data/'; -const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; -const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; -const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; -const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; - -const trainingSessionAllOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, - evalModel: TRAININGDATA_EVAL_MODEL, - optimizerModel: TRAININGDATA_OPTIMIZER_MODEL, -}; - -const trainingSessionMinOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, -}; - -// ASSERT METHODS - -function assert(cond) { - if (!cond) throw new Error(); -} - -function assertStrictEquals(actual, expected) { - if (actual !== expected) { - let strRep = actual; - if (typeof actual === 'object') { - strRep = JSON.stringify(actual); - } - throw new Error(`expected: ${expected}; got: ${strRep}`); - } -} - -function assertTwoListsUnequal(list1, list2) { - if (list1.length !== list2.length) { - return; - } - for (let i = 0; i < list1.length; i++) { - if (list1[i] !== list2[i]) { - return; - } - } - throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); -} - -// HELPER METHODS FOR TESTS - -function generateGaussianRandom(mean = 0, scale = 1) { - const u = 1 - Math.random(); - const v = Math.random(); - const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); - return z * scale + mean; -} - -function generateGaussianFloatArray(length) { - const array = new Float32Array(length); - - for (let i = 0; i < length; i++) { - array[i] = generateGaussianRandom(); - } - - return array; -} - -/** - * creates the TrainingSession and verifies that the input and output names of the training model loaded into the - * training session are correct. - * @param {} ort - * @param {*} createOptions - * @param {*} options - * @returns - */ -async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { - const trainingSession = await ort.TrainingSession.create(createOptions, options); - - assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); - assertStrictEquals(trainingSession.trainingInputNames.length, 2); - assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.trainingOutputNames.length, 1); - return trainingSession; -} - -/** - * verifies that the eval input and output names associated with the eval model loaded into the given training session - * are correct. - */ -function checkEvalModel(trainingSession) { - assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); - assertStrictEquals(trainingSession.evalInputNames.length, 2); - assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.evalOutputNames.length, 1); -} - -/** - * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if - * accessed - * @param {} trainingSession - */ -function checkNoEvalModel(trainingSession) { - try { - assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } - try { - assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } -} - -/** - * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length - * of 1 for the loss. - * @param {} trainingSession - * @param {*} feeds - * @returns - */ -var runTrainStepAndCheck = async function (trainingSession, feeds) { - const results = await trainingSession.runTrainStep(feeds); - assertStrictEquals(Object.keys(results).length, 1); - assertStrictEquals(results['onnx::loss::21273'].data.length, 1); - assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); - return results; -}; - -var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) { - // make a float32 array that is filled with the constant - const newParams = new Float32Array(paramsLength); - for (let i = 0; i < paramsLength; i++) { - newParams[i] = constant; - } - - const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); - - await trainingSession.loadParametersBuffer(newParamsUint8); - const paramsAfterLoad = await trainingSession.getContiguousParameters(); - - // check that the parameters have changed - assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); - assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); - - // check that the parameters have changed to what they should be - for (let i = 0; i < paramsLength; i++) { - // round to the same number of digits (4 decimal places) - assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); - } - - return paramsAfterLoad; -}; - -// TESTS - -var testInferenceFunction = async function (ort, options) { - const session = await ort.InferenceSession.create('data/model.onnx', options || {}); - - const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); - - const fetches = await session.run({ - a: new ort.Tensor('float32', dataA, [3, 4]), - b: new ort.Tensor('float32', dataB, [4, 3]), - }); - - const c = fetches.c; - - assert(c instanceof ort.Tensor); - assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); - assert(c.data[0] === 700); - assert(c.data[1] === 800); - assert(c.data[2] === 900); - assert(c.data[3] === 1580); - assert(c.data[4] === 1840); - assert(c.data[5] === 2100); - assert(c.data[6] === 2460); - assert(c.data[7] === 2880); - assert(c.data[8] === 3300); -}; - -var testTrainingFunctionMin = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); - checkNoEvalModel(trainingSession); - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - const feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - await runTrainStepAndCheck(trainingSession, feeds); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); -}; - -var testTrainingFunctionAll = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); - checkEvalModel(trainingSession); - - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - let feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - const results = await runTrainStepAndCheck(trainingSession, feeds); - - await trainingSession.runOptimizerStep(feeds); - feeds = { 'input-0': input0, labels: labels }; - // check getContiguousParameters after optimizerStep -- that the parameters have been updated - const optimizedParams = await trainingSession.getContiguousParameters(); - assertTwoListsUnequal(originalParams.data, optimizedParams.data); - - const results2 = await runTrainStepAndCheck(trainingSession, feeds); - - // check that loss decreased after optimizer step and training again - assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); -}; - -if (typeof module === 'object') { - module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; -} diff --git a/js/web/test/training/e2e/data/model.onnx b/js/web/test/training/e2e/data/model.onnx deleted file mode 100644 index 088124bd48624..0000000000000 --- a/js/web/test/training/e2e/data/model.onnx +++ /dev/null @@ -1,16 +0,0 @@ - backend-test:b - -a -bc"MatMultest_matmul_2dZ -a -  - -Z -b -  - -b -c -  - -B \ No newline at end of file diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js deleted file mode 100644 index 74662b67676f7..0000000000000 --- a/js/web/test/training/e2e/karma.conf.js +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const args = require('minimist')(process.argv.slice(2)); -const SELF_HOST = !!args['self-host']; -const ORT_MAIN = args['ort-main']; -const TEST_MAIN = args['test-main']; -if (typeof TEST_MAIN !== 'string') { - throw new Error('flag --test-main= is required'); -} -const USER_DATA = args['user-data']; -if (typeof USER_DATA !== 'string') { - throw new Error('flag --user-data= is required'); -} - -module.exports = function (config) { - const distPrefix = SELF_HOST ? './node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; - config.set({ - frameworks: ['mocha'], - files: [ - { pattern: distPrefix + ORT_MAIN }, - { pattern: './common.js' }, - { pattern: TEST_MAIN }, - { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }, - { pattern: './data/*', included: false }, - ], - plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], - proxies: { - '/model.onnx': '/base/model.onnx', - '/data/': '/base/data/', - }, - client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } }, - reporters: ['mocha'], - captureTimeout: 120000, - reportSlowerThan: 100, - browserDisconnectTimeout: 600000, - browserNoActivityTimeout: 300000, - browserDisconnectTolerance: 0, - browserSocketTimeout: 60000, - hostname: 'localhost', - browsers: [], - customLaunchers: { - Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA }, - Chrome_no_threads: { - base: 'ChromeHeadless', - chromeDataDir: USER_DATA, - // TODO: no-thread flags - }, - Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, - }, - }); -}; diff --git a/js/web/test/training/e2e/package.json b/js/web/test/training/e2e/package.json deleted file mode 100644 index 5f11a27de6dfc..0000000000000 --- a/js/web/test/training/e2e/package.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "devDependencies": { - "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", - "fs-extra": "^11.1.0", - "globby": "^13.1.3", - "karma": "^6.4.1", - "karma-chrome-launcher": "^3.1.1", - "karma-mocha": "^2.0.1", - "karma-mocha-reporter": "^2.2.5", - "light-server": "^2.9.1", - "minimist": "^1.2.7", - "mocha": "^10.2.0" - } -} diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js deleted file mode 100644 index d12bcc7aa66ed..0000000000000 --- a/js/web/test/training/e2e/run.js +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const path = require('path'); -const fs = require('fs-extra'); -const { spawn } = require('child_process'); -const startServer = require('./simple-http-server'); -const minimist = require('minimist'); - -// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file -// exists in its parent folder. -// here we use /build/js/e2e-training/ for the test - -const TEST_E2E_SRC_FOLDER = __dirname; -const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); -const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); -const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); -const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); -fs.emptyDirSync(TEST_E2E_RUN_FOLDER); -fs.emptyDirSync(NPM_CACHE_FOLDER); -fs.emptyDirSync(CHROME_USER_DATA_FOLDER); -fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); - -// training data to copy -const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); -const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); -const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); - -// always use a new folder as user-data-dir -let nextUserDataDirId = 0; -function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); - nextUserDataDirId++; - fs.emptyDirSync(dir); - return dir; -} - -// commandline arguments -const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; - -async function main() { - // find packed package - const { globbySync } = await import('globby'); - - const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); - - const PACKAGES_TO_INSTALL = []; - - if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { - PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); - } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { - throw new Error('multiple packages found for onnxruntime-common.'); - } - - const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); - if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { - throw new Error('cannot find exactly single package for onnxruntime-web.'); - } - PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); - - // we start here: - - // install dev dependencies - await runInShell(`npm install`); - - // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); - - // prepare training data - prepareTrainingDataByCopying(); - - console.log('==============================================================='); - console.log('Running self-hosted tests'); - console.log('==============================================================='); - // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({ hostInKarma: true }); - - console.log('==============================================================='); - console.log('Running not self-hosted tests'); - console.log('==============================================================='); - // test cases without self-host (ort hosted in cross origin) - const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081); - try { - await testAllBrowserCases({ hostInKarma: false }); - } finally { - // close the server after all tests - await server.close(); - } -} - -async function testAllBrowserCases({ hostInKarma }) { - await runKarma({ hostInKarma, main: './browser-test-wasm.js' }); -} - -async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) { - console.log('==============================================================='); - console.log(`Running karma with the following binary: ${ortMain}`); - console.log('==============================================================='); - const selfHostFlag = hostInKarma ? '--self-host' : ''; - await runInShell( - `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ - ortMain - } --test-main=${main} --user-data=${getNextUserDataDir()}`, - ); -} - -async function runInShell(cmd) { - console.log('==============================================================='); - console.log(' Running command in shell:'); - console.log(' > ' + cmd); - console.log('==============================================================='); - let complete = false; - const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); - childProcess.on('close', function (code) { - if (code !== 0) { - process.exit(code); - } else { - complete = true; - } - }); - while (!complete) { - await delay(100); - } -} - -async function delay(ms) { - return new Promise(function (resolve) { - setTimeout(function () { - resolve(); - }, ms); - }); -} - -function prepareTrainingDataByCopying() { - fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); - console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); -} - -main(); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js deleted file mode 100644 index ef9cced681cc8..0000000000000 --- a/js/web/test/training/e2e/simple-http-server.js +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -// this is a simple HTTP server that enables CORS. -// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework - -const http = require('http'); -const fs = require('fs'); -const path = require('path'); - -const getRequestData = (url, dir) => { - const pathname = new URL(url, 'http://localhost').pathname; - - let filepath; - let mimeType; - if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/')) { - filepath = path.resolve(dir, pathname.substring(1)); - } else { - return null; - } - - if (filepath.endsWith('.wasm')) { - mimeType = 'application/wasm'; - } else if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) { - mimeType = 'text/javascript'; - } else { - return null; - } - - return [filepath, mimeType]; -}; - -module.exports = function (dir, port) { - const server = http - .createServer(function (request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); - - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function (error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, { 'Content-Type': contentType }); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); - console.log(`Server running at http://localhost:${port}/`); - return server; -}; diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 735b6a89a2a86..b82248c0c83b8 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } - -declare module 'onnxruntime-web/training' { - export * from 'onnxruntime-web'; -} From 2db6b734f5e669593f2868a703c5d212df1bef03 Mon Sep 17 00:00:00 2001 From: Yang Gu Date: Tue, 17 Sep 2024 14:17:10 +0800 Subject: [PATCH 34/39] [js/webgpu] Fix issue to run model demucs (#22074) This is to fix issue #22031 to run model demucs. For conv-transpose, outputPadding.length could be 1, while spatialRank is 2. The fix is to append enough 0s to outputPadding. For conv, the issue is similar. kernelShape.length sometimes could be 1, while inputs[1].dims.length is 4. The fix is also to append enough 0s to kernelShape. --- js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/conv.ts | 5 ++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index ece2e1b7c7dcd..236f1b09a6c93 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -44,10 +44,8 @@ const calculateOutputShapeAndPads = ( ) => { const spatialRank = inputShape.length - 2; const updateOutputShape = outputShape.length === 0; - if (outputPadding.length === 0) { - for (let i = 0; i < spatialRank; ++i) { - outputPadding.push(0); - } + if (outputPadding.length < spatialRank) { + outputPadding.push(...Array(spatialRank - outputPadding.length).fill(0)); } const batchSize = inputShape[0]; const outChannels = kernelShape[isChannelLast ? 3 : 1] * group; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index fe37163a0cd08..de9f7bc8885ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -103,7 +103,10 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); - // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims + // if kernelShape is not well specified in the attributes, infer it from the weight tensor dims + if (kernelShape.length < inputs[1].dims.length - 2) { + kernelShape.push(...Array(inputs[1].dims.length - 2 - kernelShape.length).fill(0)); + } for (let i = 2; i < inputs[1].dims.length; ++i) { if (kernelShape[i - 2] === 0) { kernelShape[i - 2] = inputs[1].dims[i]; From afd642a194b39138ad891e7bb2c8bca26d37b785 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 17 Sep 2024 14:17:46 +0800 Subject: [PATCH 35/39] [js/webgpu] Replace array with string in transpose perm (#21930) Perf test data(100000 times) Array: 12.599999997764826ms String: 1.6000000014901161ms Perf test case: ``` const permFunctionBodyArray = (rank: number, input: string): string => { const reverseFunc = []; reverseFunc.push(`fn perm(i: int) -> int { var a: int};`); for (let i = 0; i < rank; ++i) { reverseFunc.push(input); } reverseFunc.push('return a;}'); return reverseFunc.join('\n'); }; const permFunctionBodyString = (rank: number, input: string): string => { let reverseFunc= `fn perm(i: int}) -> int { var a: int;`; for (let i = 0; i < rank; ++i) { reverseFunc+=input; } reverseFunc+='return a;}'; return reverseFunc;//.join('\n'); }; const count = 100000; let start, end console.time('array'); start = performance.now(); for(let i =0 ; i < count; i ++) { permFunctionBodyArray(3, 'input'); } end = performance.now(); console.timeEnd('array'); console.log("Array: "+ (end-start)); console.time('string'); start = performance.now(); for(let i =0 ; i < count; i ++) { permFunctionBodyString(3, 'input'); } end = performance.now(); console.log("String: " +(end-start)); console.timeEnd('string'); ``` ### Description ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index ee877f8f0c3f2..1fd99d085e0ed 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -26,14 +26,12 @@ const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm)); const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, output: IndicesHelper): string => { - const reverseFunc = []; - reverseFunc.push(`fn perm(i: ${output.type.indices}) -> ${input.type.indices} { - var a: ${input.type.indices};`); + let reverseFunc = `fn perm(i: ${output.type.indices}) -> ${input.type.indices} { + var a: ${input.type.indices};`; for (let i = 0; i < rank; ++i) { - reverseFunc.push(input.indicesSet('a', perm[i], `i[${i}]`)); + reverseFunc += input.indicesSet('a', perm[i], `i[${i}]`); } - reverseFunc.push('return a;}'); - return reverseFunc.join('\n'); + return (reverseFunc += 'return a;}'); }; const squeezeShape = (shape: readonly number[], adjustedPerm: number[]): { newShape: number[]; newPerm: number[] } => { From 9786909ab5918c31119aaf39fe8b1c7b666c8962 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Tue, 17 Sep 2024 23:18:47 +0800 Subject: [PATCH 36/39] [WebNN EP] Support QuantizeLinear and DequantizeLinear ops (#22097) --- js/web/docs/webnn-operators.md | 2 + .../core/providers/webnn/builders/helper.h | 1 + .../webnn/builders/impl/conv_op_builder.cc | 4 +- .../impl/dequantizeLinear_op_builder.cc | 81 ---------- .../webnn/builders/impl/gemm_op_builder.cc | 4 +- .../webnn/builders/impl/qdq_op_builder.cc | 152 ++++++++++++++++++ .../providers/webnn/builders/model_builder.cc | 49 ++++-- .../providers/webnn/builders/model_builder.h | 2 +- .../webnn/builders/op_builder_factory.cc | 5 +- .../webnn/builders/op_builder_factory.h | 2 +- 10 files changed, 199 insertions(+), 103 deletions(-) delete mode 100644 onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc create mode 100644 onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 164096b4fda9a..6fd4f9af20432 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -25,6 +25,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU backend only supports default dilations and group | | Cos | ai.onnx(7+) | cos | ✓ | ✓ | | | Div | ai.onnx(7-12, 13, 14+) | div | ✓ | ✓ | | +| DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | ✗ | ✓ | | | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | ✓ | ✓ | Only supports test mode | | Elu | ai.onnx(7+) | elu | ✓ | ✓ | WebNN CPU backend only supports 'alpha' value is 1.0 | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | ✓ | ✓ | | @@ -62,6 +63,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad | ✓ | ✓ | modes == 'wrap' is not supported | | Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow | ✓ | ✓ | | | PRelu | ai.onnx(7-8, 9-15, 16+) | prelu | ✓ | ✓ | WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) | +| QuantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | quantizeLinear | ✗ | ✓ | | | Reciprocal | ai.onnx(7-12, 13+) | reciprocal | ✓ | ✓ | | | ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 | ✓ | ✓ | Input 'axes' if present should be a constant | | ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 | ✓ | ✓ | Input 'axes' if present should be a constant | diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 257fcff9ef50c..dd4a8acc662ef 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -206,6 +206,7 @@ static const InlinedHashMap op_map = { {"Pad", "pad"}, {"Pow", "pow"}, {"PRelu", "prelu"}, + {"QuantizeLinear", "quantizeLinear"}, {"Reciprocal", "reciprocal"}, {"ReduceL1", "reduceL1"}, {"ReduceL2", "reduceL2"}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 35498c2e9b8b7..f03e5b90ff6db 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -311,12 +311,12 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - x_zero_point = model_builder.GetZeroConstant("uint8"); + x_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } if (input_defs.size() >= 4) { w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - w_zero_point = model_builder.GetZeroConstant("uint8"); + w_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } output = model_builder.GetBuilder().call("conv2dInteger", input, x_zero_point, filter, w_zero_point, options); diff --git a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc deleted file mode 100644 index 93a12a696cce1..0000000000000 --- a/onnxruntime/core/providers/webnn/builders/impl/dequantizeLinear_op_builder.cc +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) Intel Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/common/safeint.h" -#include "core/optimizer/initializer.h" -#include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/providers/webnn/builders/helper.h" -#include "core/providers/webnn/builders/model_builder.h" -#include "core/providers/webnn/builders/op_builder_factory.h" - -#include "core/providers/webnn/builders/impl/base_op_builder.h" - -namespace onnxruntime { -namespace webnn { - -class DequantizeLinearOpBuilder : public BaseOpBuilder { - // Add operator related. - private: - Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const override ORT_MUST_USE_RESULT; -}; - -Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); - emscripten::val zero_point = emscripten::val::null(); - if (input_defs.size() == 3) { - zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); - } else { - zero_point = model_builder.GetZeroConstant("uint8"); - } - emscripten::val output; - std::vector input_shape; - std::vector scale_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); - ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); - NodeAttrHelper helper(node); - int32_t axis = helper.Get("axis", 1); - // axis is valid for input shape greater than 1D. - if (input_shape.size() > 1) { - axis = static_cast(HandleNegativeAxis(axis, input_shape.size())); - } - // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor. - if (1 == scale_shape.size() && 1 < input_shape.size()) { - std::vector target_shape{static_cast(input_shape[axis])}; - target_shape.insert(target_shape.begin(), axis, 1); - target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); - emscripten::val reshape_scale_options = emscripten::val::object(); - reshape_scale_options.set("label", node.Name() + "_reshape_scale"); - scale = model_builder.GetBuilder().call("reshape", - scale, - emscripten::val::array(target_shape), - reshape_scale_options); - emscripten::val reshape_zero_point_options = emscripten::val::object(); - reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); - zero_point = model_builder.GetBuilder().call("reshape", - zero_point, - emscripten::val::array(target_shape), - reshape_zero_point_options); - } - emscripten::val options = emscripten::val::object(); - options.set("label", node.Name()); - output = model_builder.GetBuilder().call("dequantizeLinear", input, scale, zero_point, options); - - model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); - - return Status::OK(); -} - -void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { - op_registrations.builders.push_back(std::make_unique()); - op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); -} - -} // namespace webnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 30e024792ed42..1477530ce1894 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -113,12 +113,12 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N if (input_defs.size() >= 3) { a_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); } else { - a_zero_point = model_builder.GetZeroConstant("uint8"); + a_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } if (input_defs.size() >= 4) { b_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); } else { - b_zero_point = model_builder.GetZeroConstant("uint8"); + b_zero_point = model_builder.GetZeroConstant(ONNX_NAMESPACE::TensorProto_DataType_UINT8); } output = model_builder.GetBuilder().call("matmulInteger", a, diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc new file mode 100644 index 0000000000000..13dee667f6fd9 --- /dev/null +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -0,0 +1,152 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Intel Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/optimizer/initializer.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/webnn/builders/helper.h" +#include "core/providers/webnn/builders/model_builder.h" +#include "core/providers/webnn/builders/op_builder_factory.h" + +#include "core/providers/webnn/builders/impl/base_op_builder.h" + +namespace onnxruntime { +namespace webnn { + +class QDQOpBuilder : public BaseOpBuilder { + // Add operator related. + private: + Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; +}; + +Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const Node& node, + const logging::Logger& logger) const { + const auto& op_type = node.OpType(); + const auto& input_defs = node.InputDefs(); + const auto& output_defs = node.OutputDefs(); + + std::vector input_shape; + std::vector scale_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + ORT_RETURN_IF_NOT(GetShape(*input_defs[1], scale_shape, logger), "Cannot get scale shape"); + int32_t input_type = 0; + int32_t output_type = 0; + int32_t zero_point_type = 0; + ORT_RETURN_IF_NOT(GetType(*input_defs[0], input_type, logger), "Cannot get input data type"); + ORT_RETURN_IF_NOT(GetType(*output_defs[0], output_type, logger), "Cannot get output data type"); + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name()); + + emscripten::val zero_point = emscripten::val::null(); + if (input_defs.size() == 3 && input_defs[2]->Exists()) { + zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); + } else { + // DequantizeLinear: x_zero_point's data type equals to input data type + // QuantizeLinear: x_zero_point's data type equals to output data type + zero_point_type = op_type == "DequantizeLinear" ? input_type : output_type; + zero_point = model_builder.GetZeroConstant(zero_point_type); + } + + emscripten::val output; + NodeAttrHelper helper(node); + int32_t axis = helper.Get("axis", 1); + int32_t block_size = helper.Get("block_size", 0); + // axis is valid for input shape greater than 1D. + if (input_shape.size() > 1) { + axis = static_cast(HandleNegativeAxis(axis, input_shape.size())); + } + // Insert ones before and after the axis dimension for broadcasting of 1D scale tensor. + if (1 == scale_shape.size() && 1 < input_shape.size()) { + std::vector target_shape{static_cast(input_shape[axis])}; + target_shape.insert(target_shape.begin(), axis, 1); + target_shape.insert(target_shape.end(), input_shape.size() - axis - 1, 1); + emscripten::val reshape_scale_options = emscripten::val::object(); + reshape_scale_options.set("label", node.Name() + "_reshape_scale"); + scale = model_builder.GetBuilder().call("reshape", + scale, + emscripten::val::array(target_shape), + reshape_scale_options); + emscripten::val reshape_zero_point_options = emscripten::val::object(); + reshape_zero_point_options.set("label", node.Name() + "_reshape_zero_point"); + zero_point = model_builder.GetBuilder().call("reshape", + zero_point, + emscripten::val::array(target_shape), + reshape_zero_point_options); + } + + // If block_size is specified, we need to expand the scale and zero_point tensors. + if (block_size > 1) { + emscripten::val concat_scale_inputs = emscripten::val::array(); + emscripten::val concat_zero_point_inputs = emscripten::val::array(); + for (int i = 0; i < block_size; i++) { + concat_scale_inputs.call("push", scale); + concat_zero_point_inputs.call("push", zero_point); + } + + emscripten::val concat_scale_options = emscripten::val::object(); + concat_scale_options.set("label", node.Name() + "_concat_scale"); + scale = model_builder.GetBuilder().call("concat", concat_scale_inputs, axis, concat_scale_options); + + emscripten::val concat_zero_point_options = emscripten::val::object(); + concat_zero_point_options.set("label", node.Name() + "_concat_zero_point"); + zero_point = model_builder.GetBuilder().call( + "concat", concat_zero_point_inputs, axis, concat_zero_point_options); + } + + emscripten::val options = emscripten::val::object(); + options.set("label", node.Name()); + std::string webnn_op_type; + ORT_RETURN_IF_NOT(GetWebNNOpType(op_type, webnn_op_type), "Cannot get WebNN op type"); + output = model_builder.GetBuilder().call(webnn_op_type.c_str(), input, scale, zero_point, options); + + model_builder.AddOperand(output_defs[0]->Name(), std::move(output)); + + return Status::OK(); +} + +bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type = 0; // input data type + int32_t input1_type = 0; // x_scale data type + int32_t input2_type = 0; // x_zero_point data type + bool has_input2 = input_defs.size() > 2 && input_defs[2]->Exists(); + + if (!GetType(*input_defs[0], input0_type, logger) || + !GetType(*input_defs[1], input1_type, logger) || + (has_input2 && !GetType(*input_defs[2], input2_type, logger))) { + return false; + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "x", logger) && + IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "scale", "x_scale", logger) && + (!has_input2 || IsDataTypeSupportedByOp(op_type, input2_type, wnn_limits, "zeroPoint", "x_zero_point", logger)); +} + +void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + if (op_registrations.op_builder_map.count(op_type) > 0) + return; + + static std::vector op_types = + { + "DequantizeLinear", + "QuantizeLinear", + }; + + op_registrations.builders.push_back(std::make_unique()); + for (const auto& type : op_types) { + op_registrations.op_builder_map.emplace(type, op_registrations.builders.back().get()); + } +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index b58bf8233692e..f9f8264b234bb 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -354,27 +354,48 @@ void ModelBuilder::AddOperand(const std::string& name, const emscripten::val& op // https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-value-type // BTW, the spec is discussing if the builer.constant(value, type) should be dropped at // https://github.com/webmachinelearning/webnn/issues/475. Fix me according to the spec decision. -const emscripten::val& ModelBuilder::GetZeroConstant(const std::string& data_type) { - std::string name = "webnn_zero_constant_" + data_type; +const emscripten::val& ModelBuilder::GetZeroConstant(const int32_t& data_type) { + std::string name = "webnn_zero_constant_" + std::to_string(data_type); // If the operand does not exist, create it. if (wnn_operands_.find(name) == wnn_operands_.end()) { emscripten::val desc = emscripten::val::object(); emscripten::val dims = emscripten::val::array(); desc.set("dimensions", dims); emscripten::val zero_buffer = emscripten::val::undefined(); - if (data_type == "uint8") { - if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_UINT8)) { - ORT_THROW("Unsupported data type: " + data_type); - } - zero_buffer = emscripten::val::global("Uint8Array").new_(1); - } else if (data_type == "float32") { - if (!SetWebnnDataType(desc, ONNX_NAMESPACE::TensorProto_DataType_FLOAT)) { - ORT_THROW("Unsupported data type: " + data_type); - } - zero_buffer = emscripten::val::global("Float32Array").new_(1); - } else { - ORT_THROW("Unsupported data type: " + data_type); + if (!SetWebnnDataType(desc, data_type)) { + ORT_THROW("Unsupported data type: " + std::to_string(data_type)); } + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + zero_buffer = emscripten::val::global("Uint8Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + zero_buffer = emscripten::val::global("Int8Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + zero_buffer = emscripten::val::global("Uint16Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + zero_buffer = emscripten::val::global("Float32Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + zero_buffer = emscripten::val::global("Int32Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + zero_buffer = emscripten::val::global("BigInt64Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + zero_buffer = emscripten::val::global("Uint32Array").new_(1); + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + zero_buffer = emscripten::val::global("BigUint64Array").new_(1); + break; + default: + break; + } + emscripten::val zero_constant = wnn_builder_.call("constant", desc, zero_buffer); wnn_operands_.insert(std::make_pair(name, zero_constant)); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 256337baeba7e..13937933a0a9c 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -38,7 +38,7 @@ class ModelBuilder { const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } void AddOperand(const std::string& name, const emscripten::val& operand); - const emscripten::val& GetZeroConstant(const std::string& data_type); + const emscripten::val& GetZeroConstant(const int32_t& data_type); // Use the buffers to persist WebNN allocated data like transposed weight. // It ensures the validity during inference session. std::vector> mem_persist_buffers_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 3dc1c7966ae41..93a2b232a7d51 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -84,9 +84,10 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateDropoutOpBuilder("Dropout", op_registrations); } - { // Quantize/Dequantize + { // DequantizeLinear/QuantizeLinear/DynamicQuantizeLinear + CreateQDQOpBuilder("DequantizeLinear", op_registrations); + CreateQDQOpBuilder("QuantizeLinear", op_registrations); CreateDynamicQuantizeLinearOpBuilder("DynamicQuantizeLinear", op_registrations); - CreateDequantizeLinearOpBuilder("DequantizeLinear", op_registrations); } { // Expand diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h index b66218cc9a902..61fe6d936e9d1 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.h @@ -28,7 +28,6 @@ void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDropoutOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); -void CreateDequantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateFlattenOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); @@ -39,6 +38,7 @@ void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& o void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreatePoolOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateQDQOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateReshapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); From 6dcdc70aa7666325af57acd4ed2dd81353856e2e Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 17 Sep 2024 09:52:28 -0700 Subject: [PATCH 37/39] [TensorRT EP] Add supportsModelV2 (#22081) `supportsModel` is deprecated in TRT 10.1. Add `supportsModelV2 `but still keep `supportsModel` as we still need to support TRT 8.6 where `supportsModelV2 ` is not supported. --- .../tensorrt/tensorrt_execution_provider.cc | 45 ++++++++++++++++--- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index a7daa98902afb..c3d010ac9fcd7 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -2255,16 +2255,29 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); #endif network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); - auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); -#if defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4996) -#endif + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 + auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + + // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior. + auto num_subgraphs = trt_parser->getNbSubgraphs(); + parser_nodes_list.reserve(num_subgraphs); + + for (int64_t i = 0; i < num_subgraphs; ++i) { + int64_t subgraph_len = 0; + int64_t* nodes = trt_parser->getSubgraphNodes(i, subgraph_len); + parser_nodes_list.emplace_back(); + parser_nodes_list.back().first.reserve(subgraph_len); + for (int64_t j = 0; j < subgraph_len; ++j) { + parser_nodes_list.back().first.push_back(nodes[j]); + } + parser_nodes_list.back().second = is_model_supported ? true : false; + } +#else trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); -#if defined(_MSC_VER) -#pragma warning(pop) #endif SubGraphCollection_t next_nodes_list; @@ -2272,6 +2285,24 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect next_nodes_list = GetSupportedList(parser_nodes_list, iterations, max_iterations, *graph_viewer, early_termination); for (size_t i = 0, end = next_nodes_list.size(); i < end; ++i) { for (size_t j = 0, end = next_nodes_list[i].first.size(); j < end; ++j) { + /* + * Convert the supported node list returning from onnx-tensorrt parser to the node list recognized by ORT TRT. + * + * TRT EP reconstructs the graph based on the nodes in group.first and feeds this graph (converts to model proto and to string buffer) to onnx-tensorrt parser. + * The node index in the list returning from onnx-tensorrt parser might not be the same as the node index in group.first. Therefore, TRT EP needs a node index mapping table here. + * + * The order of iterating the nodes in group.first and calling graph_build.AddNode() determines the node order in the newly constructed graph (see Graph::AllocateNode() in graph.cc), + * however, once the graph is converted to model proto, the node proto order in model proto (ex: onnx-tensorrt calls model.graph().node() to iterate NodeProto in ModelProto) is decided by topo sort. + * + * The topo sort list (i.e. subgraph_node_index) acts as the node index mapping table: + * subgraph_node_index[node index from onnx-tensorrt parser] = index in group.first + * + * In the past, TRT EP uses ORT's default reversed DFS topo sort which might end up with the sorting result not sequence of 0, 1, ... n-1, ex: the subgraph_node_index = [0,2,1,3,4]. + * With the change of using ORT's priority-based topo sort (node with lower node index outputs first) the sorting result is the sequence of 0, 1, ... n-1 for most of the cases, + * therefore subgraph_node_index as a mapping table is not needed anymore. + * + * TODO: Remove the subgraph_node_index + */ next_nodes_list[i].first[j] = group.first[subgraph_node_index[next_nodes_list[i].first[j]]]; } nodes_list_output.push_back(next_nodes_list[i]); From fa68ae2deffbf3d509d93b42ed48c092f38512fb Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Tue, 17 Sep 2024 10:07:30 -0700 Subject: [PATCH 38/39] Update pool to MacOS-13 (#17361) ### Description See https://github.com/microsoft/onnxruntime-extensions/pull/476 and https://github.com/actions/runner-images/issues/7671 ### Motivation and Context ### Current issue - [ ] For default xcode 15.2, that come with the MacOS-13, We Need to update the boost container header boost/container_hash/hash.hpp version to pass the build - [x] For xcode 14.2 The Build passed but the `Run React Native Detox Android e2e Test` Failed. Possible flaky test, https://github.com/microsoft/onnxruntime/pull/21969 - [x] For xcode 14.3.1 We encountered following issue in `Build React Native Detox iOS e2e Tests` ``` ld: file not found: /Applications/Xcode_14.3.1.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/arc/libarclite_iphonesimulator.a clang: error: linker command failed with exit code 1 (use -v to see invocation) ``` Applied following code to the eof in both ios/Podfile and fixed the issue ``` post_install do |installer| installer.generated_projects.each do |project| project.targets.each do |target| target.build_configurations.each do |config| config.build_settings['IPHONEOS_DEPLOYMENT_TARGET'] = '13.0' end end end end ``` - [x] https://github.com/facebook/react-native/issues/32483 Applying changes to ios/Pofile ``` pre_install do |installer| # Custom pre-install script or commands puts "Running pre-install script..." # Recommended fix for https://github.com/facebook/react-native/issues/32483 # from https://github.com/facebook/react-native/issues/32483#issuecomment-966784501 system("sed -i '' 's/typedef uint8_t clockid_t;//' \"${SRCROOT}/Pods/RCT-Folly/folly/portability/Time.h\"") end ``` - [ ] Detox environment setting up exceeded time out of 120000ms during iso e2e test ### dependent - [x] https://github.com/microsoft/onnxruntime/pull/21159 --------- Co-authored-by: Changming Sun --- js/react_native/e2e/.detoxrc.js | 3 ++- .../project.pbxproj | 12 ++++++----- js/react_native/e2e/ios/Podfile | 18 +++++++++++++++++ .../project.pbxproj | 15 ++++++++------ js/react_native/ios/Podfile | 19 ++++++++++++++++++ .../project.pbxproj | 20 +++++++++---------- ...t_full_apple_framework_build_settings.json | 2 +- ...ndroid-x86_64-crosscompile-ci-pipeline.yml | 12 +++++------ .../mac-coreml-ci-pipeline.yml | 6 ++---- .../azure-pipelines/mac-ios-ci-pipeline.yml | 5 ++--- .../mac-ios-packaging-pipeline.yml | 2 +- .../nodejs/templates/test_macos.yml | 2 +- .../nuget/templates/test_macos.yml | 2 +- .../azure-pipelines/post-merge-jobs.yml | 8 ++------ .../py-package-test-pipeline.yml | 2 +- .../templates/android-java-api-aar-test.yml | 8 +++----- .../azure-pipelines/templates/c-api-cpu.yml | 8 +++----- .../templates/mac-cpu-packaging-pipeline.yml | 4 ++-- .../templates/mac-cpu-packing-jobs.yml | 6 ++---- .../templates/py-packaging-stage.yml | 6 ++---- .../templates/react-native-ci.yml | 20 ++++++++++--------- .../stages/mac-ios-packaging-build-stage.yml | 6 ++---- .../templates/use-xcode-version.yml | 2 +- 23 files changed, 108 insertions(+), 80 deletions(-) diff --git a/js/react_native/e2e/.detoxrc.js b/js/react_native/e2e/.detoxrc.js index e24833a1d09c9..0792c3d528585 100644 --- a/js/react_native/e2e/.detoxrc.js +++ b/js/react_native/e2e/.detoxrc.js @@ -38,7 +38,8 @@ module.exports = { simulator: { type: 'ios.simulator', device: { - type: 'iPhone 13', + type: 'iPhone 14', + os: 'iOS 16.4', }, }, attached: { diff --git a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj index f2e4a791eae57..b7c38547d67e1 100644 --- a/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj +++ b/js/react_native/e2e/ios/OnnxruntimeModuleExample.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 46; + objectVersion = 54; objects = { /* Begin PBXBuildFile section */ @@ -322,6 +322,7 @@ DEVELOPMENT_TEAM = ""; ENABLE_BITCODE = NO; INFOPLIST_FILE = OnnxruntimeModuleExample/Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_LDFLAGS = ( "$(inherited)", @@ -344,6 +345,7 @@ CURRENT_PROJECT_VERSION = 1; DEVELOPMENT_TEAM = ""; INFOPLIST_FILE = OnnxruntimeModuleExample/Info.plist; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_LDFLAGS = ( "$(inherited)", @@ -405,7 +407,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; "HEADER_SEARCH_PATHS[arch=*]" = ""; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "/usr/lib/swift $(inherited)"; LIBRARY_SEARCH_PATHS = ( "\"$(TOOLCHAIN_DIR)/usr/lib/swift/$(PLATFORM_NAME)\"", @@ -458,7 +460,7 @@ GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "/usr/lib/swift $(inherited)"; LIBRARY_SEARCH_PATHS = ( "\"$(TOOLCHAIN_DIR)/usr/lib/swift/$(PLATFORM_NAME)\"", @@ -486,7 +488,7 @@ DEBUG_INFORMATION_FORMAT = dwarf; GCC_C_LANGUAGE_STANDARD = gnu11; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 15.2; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; MARKETING_VERSION = 1.0; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; @@ -514,7 +516,7 @@ DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; GCC_C_LANGUAGE_STANDARD = gnu11; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 15.2; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; MARKETING_VERSION = 1.0; MTL_FAST_MATH = YES; PRODUCT_BUNDLE_IDENTIFIER = ai.onnxruntime.reactnative.OnnxruntimeModuleExampleUITests; diff --git a/js/react_native/e2e/ios/Podfile b/js/react_native/e2e/ios/Podfile index d31a6f50221fb..0272c23092838 100644 --- a/js/react_native/e2e/ios/Podfile +++ b/js/react_native/e2e/ios/Podfile @@ -3,6 +3,15 @@ require_relative '../node_modules/@react-native-community/cli-platform-ios/nativ platform :ios, '13.0' +pre_install do |installer| + # Custom pre-install script or commands + puts "Running pre-install script..." + + # Recommended fix for https://github.com/facebook/react-native/issues/32483 + # from https://github.com/facebook/react-native/issues/32483#issuecomment-966784501 + system("sed -i '' 's/typedef uint8_t clockid_t;//' \"./Pods/RCT-Folly/folly/portability/Time.h\"") +end + target 'OnnxruntimeModuleExample' do config = use_native_modules! @@ -19,3 +28,12 @@ target 'OnnxruntimeModuleExample' do inherit! :search_paths end +post_install do |installer| + installer.generated_projects.each do |project| + project.targets.each do |target| + target.build_configurations.each do |config| + config.build_settings['IPHONEOS_DEPLOYMENT_TARGET'] = '13.0' + end + end + end +end \ No newline at end of file diff --git a/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj b/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj index 2a093b2b89c95..f11bc73687098 100644 --- a/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj +++ b/js/react_native/ios/OnnxruntimeModule.xcodeproj/project.pbxproj @@ -3,7 +3,7 @@ archiveVersion = 1; classes = { }; - objectVersion = 46; + objectVersion = 54; objects = { /* Begin PBXBuildFile section */ @@ -189,7 +189,8 @@ 58B511D31A9E6C8500147676 /* Project object */ = { isa = PBXProject; attributes = { - LastUpgradeCheck = 0920; + BuildIndependentTargetsInParallel = YES; + LastUpgradeCheck = 1540; ORGANIZATIONNAME = Facebook; TargetAttributes = { 58B511DA1A9E6C8500147676 = { @@ -420,7 +421,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; "HEADER_SEARCH_PATHS[arch=*]" = ""; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ""; MTL_ENABLE_DEBUG_INFO = YES; ONLY_ACTIVE_ARCH = YES; @@ -465,7 +466,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; "HEADER_SEARCH_PATHS[arch=*]" = ""; - IPHONEOS_DEPLOYMENT_TARGET = 9.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ""; MTL_ENABLE_DEBUG_INFO = NO; SDKROOT = iphoneos; @@ -484,6 +485,7 @@ "$(SRCROOT)/../../react-native/React/**", ); "HEADER_SEARCH_PATHS[arch=*]" = "\"$(PODS_ROOT)/onnxruntime/onnxruntime.framework/Headers\""; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ( "$(inherited)", "$(PROJECT_DIR)", @@ -508,6 +510,7 @@ "$(SRCROOT)/../../react-native/React/**", ); "HEADER_SEARCH_PATHS[arch=*]" = "\"$(PODS_ROOT)/onnxruntime/onnxruntime.framework/Headers\""; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LIBRARY_SEARCH_PATHS = ( "$(inherited)", "$(PROJECT_DIR)", @@ -582,7 +585,7 @@ "$(PROJECT_DIR)", ); INFOPLIST_FILE = OnnxruntimeModuleTest/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks @loader_path/Frameworks"; LIBRARY_SEARCH_PATHS = "$(inherited)"; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; @@ -656,7 +659,7 @@ "$(PROJECT_DIR)", ); INFOPLIST_FILE = OnnxruntimeModuleTest/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.4; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks @loader_path/Frameworks"; LIBRARY_SEARCH_PATHS = "$(inherited)"; MTL_FAST_MATH = YES; diff --git a/js/react_native/ios/Podfile b/js/react_native/ios/Podfile index 53c83948672a3..ad8a5ce3b2d5f 100644 --- a/js/react_native/ios/Podfile +++ b/js/react_native/ios/Podfile @@ -3,6 +3,15 @@ require_relative '../node_modules/@react-native-community/cli-platform-ios/nativ platform :ios, '13.0' +pre_install do |installer| + # Custom pre-install script or commands + puts "Running pre-install script..." + + # Recommended fix for https://github.com/facebook/react-native/issues/32483 + # from https://github.com/facebook/react-native/issues/32483#issuecomment-966784501 + system("sed -i '' 's/typedef uint8_t clockid_t;//' \"./Pods/RCT-Folly/folly/portability/Time.h\"") +end + def shared config = use_native_modules! @@ -29,3 +38,13 @@ end target 'OnnxruntimeModuleTest' do shared end + +post_install do |installer| + installer.generated_projects.each do |project| + project.targets.each do |target| + target.build_configurations.each do |config| + config.build_settings['IPHONEOS_DEPLOYMENT_TARGET'] = '13.0' + end + end + end +end \ No newline at end of file diff --git a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj index eb7345be3770b..e9536bb59bbae 100644 --- a/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj +++ b/onnxruntime/test/platform/apple/apple_package_test/apple_package_test.xcodeproj/project.pbxproj @@ -456,7 +456,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 13.0; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; MTL_FAST_MATH = YES; ONLY_ACTIVE_ARCH = YES; @@ -510,7 +510,7 @@ GCC_WARN_UNUSED_FUNCTION = YES; GCC_WARN_UNUSED_VARIABLE = YES; IPHONEOS_DEPLOYMENT_TARGET = 13.0; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MTL_ENABLE_DEBUG_INFO = NO; MTL_FAST_MATH = YES; SDKROOT = iphoneos; @@ -527,7 +527,7 @@ CODE_SIGNING_STYLE = Automatic; CODE_SIGN_ENTITLEMENTS = ios_package_test/ios_package_test.entitlements; INFOPLIST_FILE = ios_package_test/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -550,7 +550,7 @@ CODE_SIGNING_STYLE = Automatic; CODE_SIGN_ENTITLEMENTS = ios_package_test/ios_package_test.entitlements; INFOPLIST_FILE = ios_package_test/Info.plist; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -571,7 +571,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -593,7 +593,7 @@ CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; GENERATE_INFOPLIST_FILE = YES; - IPHONEOS_DEPLOYMENT_TARGET = 14.0; + IPHONEOS_DEPLOYMENT_TARGET = 13.0; LD_RUNPATH_SEARCH_PATHS = ( "$(inherited)", "@executable_path/Frameworks", @@ -631,7 +631,7 @@ "@executable_path/../Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-test"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -663,7 +663,7 @@ "@executable_path/../Frameworks", ); LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-test"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -684,7 +684,7 @@ GCC_C_LANGUAGE_STANDARD = gnu17; GENERATE_INFOPLIST_FILE = YES; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-testUITests"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -706,7 +706,7 @@ GCC_C_LANGUAGE_STANDARD = gnu17; GENERATE_INFOPLIST_FILE = YES; LOCALIZATION_PREFERS_STRING_CATALOGS = YES; - MACOSX_DEPLOYMENT_TARGET = 11.0; + MACOSX_DEPLOYMENT_TARGET = 13.3; MARKETING_VERSION = 1.0; PRODUCT_BUNDLE_IDENTIFIER = "ai.onnxruntime.tests.macos-package-testUITests"; PRODUCT_NAME = "$(TARGET_NAME)"; diff --git a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json index 84d7e355ed5b4..f2d034dbfc000 100644 --- a/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json +++ b/tools/ci_build/github/apple/default_full_apple_framework_build_settings.json @@ -24,7 +24,7 @@ ], "macosx": [ "--macos=MacOSX", - "--apple_deploy_target=11.0" + "--apple_deploy_target=13.3" ], "iphoneos": [ "--ios", diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 41ff365a65d49..a0ae0441f2824 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -70,9 +70,9 @@ stages: versionSpec: $(pythonVersion) - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' @@ -131,9 +131,9 @@ stages: versionSpec: $(pythonVersion) - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' @@ -195,9 +195,9 @@ stages: versionSpec: $(pythonVersion) - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' diff --git a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml index c16adc6221ed0..a1059bb3c314e 100644 --- a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml @@ -33,9 +33,9 @@ jobs: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - MACOSX_DEPLOYMENT_TARGET: '11.0' + MACOSX_DEPLOYMENT_TARGET: '13.3' TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] CCACHE_DIR: '$(Pipeline.Workspace)/ccache' timeoutInMinutes: 120 @@ -44,8 +44,6 @@ jobs: displayName: Install coreutils and ninja - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - template: templates/mac-build-step-with-cache.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml index 74211bc5dbd7c..c61beb63b8b40 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml @@ -31,7 +31,7 @@ pr: jobs: - job: iOS_CI_on_Mac pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: PROTO_CACHE_DIR: $(Pipeline.Workspace)/proto_ccache ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache @@ -39,8 +39,7 @@ jobs: timeoutInMinutes: 150 steps: - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 + - template: templates/mac-build-step-with-cache.yml parameters: WithCache: true diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml index abd1004a830e0..c3176ec54fd79 100644 --- a/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-ios-packaging-pipeline.yml @@ -54,7 +54,7 @@ stages: displayName: "Set common variables" pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" timeoutInMinutes: 5 diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 53923e0b4432a..4518a168879a2 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -11,7 +11,7 @@ stages: clean: all timeoutInMinutes: 120 pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml index c977e17aada9d..07d21333270a8 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml @@ -11,7 +11,7 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index e13f1c20b37ce..3853bdbd1eb88 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -425,7 +425,7 @@ stages: - job: IosDynamicFramework timeoutInMinutes: 120 pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" steps: - task: UsePythonVersion@0 @@ -435,8 +435,6 @@ stages: architecture: "x64" - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt @@ -463,7 +461,7 @@ stages: - job: IosMinimalTrainingBuild timeoutInMinutes: 120 pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" steps: - task: UsePythonVersion@0 @@ -473,8 +471,6 @@ stages: architecture: "x64" - template: templates/use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - script: | pip install -r tools/ci_build/github/apple/ios_packaging/requirements.txt diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index dced61670e34a..de2677ebc6594 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -29,7 +29,7 @@ stages: parameters: job_name: Test_MAC_Wheels machine_pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - template: templates/py-package-smoking-test.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 6971488478ed0..7fa548240f865 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,9 +19,7 @@ jobs: workspace: clean: all pool: - # We need macOS-12 to run the Android emulator for now. - # https://github.com/actions/runner-images/issues/7671 - vmImage: 'macOS-12' + vmImage: 'macOS-13' variables: - name: runCodesignValidationInjection value: false @@ -39,9 +37,9 @@ jobs: targetPath: '$(Build.BinariesDirectory)/final-android-aar' - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 8b4fe66465bb1..3e90a401d4deb 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -83,14 +83,12 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' timeoutInMinutes: 300 steps: - template: set-version-number-variables-step.yml - template: use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - template: download-deps.yml @@ -782,7 +780,7 @@ stages: - template: ../nuget/templates/test_macos.yml parameters: - AgentPool : macOS-latest + AgentPool : macOS-13 ArtifactSuffix: 'CPU' - template: ../nodejs/templates/test_win.yml @@ -818,4 +816,4 @@ stages: OS: MacOS BuildId: ${{ parameters.BuildId }} SpecificArtifact: ${{ parameters.SpecificArtifact }} - PoolName: 'macOS-latest' + PoolName: 'macOS-13' diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml index 945fbb7c4a094..1b1080034948e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-pipeline.yml @@ -69,9 +69,9 @@ stages: - job: MacOS_C_API_Package_Publish pool: ${{ if eq(parameters.DoESRP, true)}}: - vmImage: 'macOS-12' + vmImage: 'macOS-13' ${{ else }}: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' steps: - checkout: none - template: flex-downloadPipelineArtifact.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index 01ec3b5a2f8ca..3b661d9eb2dc6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -31,13 +31,13 @@ jobs: workspace: clean: all variables: - MACOSX_DEPLOYMENT_TARGET: '11.0' + MACOSX_DEPLOYMENT_TARGET: '13.3' ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] PROTO_CACHE_DIR: $(Pipeline.Workspace)/ccache_proto ORT_CACHE_DIR: $(Pipeline.Workspace)/ccache_ort pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' timeoutInMinutes: 300 steps: - checkout: self @@ -61,8 +61,6 @@ jobs: - template: set-version-number-variables-step.yml - template: use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - template: download-deps.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 2701852f4601d..451e5ad2d44c7 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -398,9 +398,9 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-latest' + vmImage: 'macOS-13' variables: - MACOSX_DEPLOYMENT_TARGET: '11.0' + MACOSX_DEPLOYMENT_TARGET: '13.3' strategy: matrix: Python38: @@ -424,8 +424,6 @@ stages: versionSpec: $(PythonVersion) - template: use-xcode-version.yml - parameters: - xcodeVersion: 14.2 - template: download-deps.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 46dc867113a2e..5fea265d59392 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -41,7 +41,7 @@ stages: - job: Build_Ios_Pod_For_React_Native pool: - vmImage: 'macOS-12' + vmImage: 'macOS-13' timeoutInMinutes: 90 @@ -52,7 +52,6 @@ stages: steps: - template: use-xcode-version.yml - - task: UsePythonVersion@0 displayName: Use python 3.9 inputs: @@ -99,9 +98,7 @@ stages: jobs: - job: ReactNative_CI pool: - # We need macOS-12 to run the Android emulator for now. - # https://github.com/actions/runner-images/issues/7671 - vmImage: 'macOS-12' + vmImage: 'macOS-13' variables: runCodesignValidationInjection: false timeoutInMinutes: 90 @@ -109,7 +106,7 @@ stages: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: Clean Agent Directories condition: always() - + - template: use-xcode-version.yml - task: UsePythonVersion@0 displayName: Use python 3.9 inputs: @@ -118,9 +115,9 @@ stages: architecture: "x64" - task: JavaToolInstaller@0 - displayName: Use jdk 11 + displayName: Use jdk 17 inputs: - versionSpec: '11' + versionSpec: '17' jdkArchitectureOption: 'x64' jdkSourceOption: 'PreInstalled' @@ -304,7 +301,7 @@ stages: scheme: 'OnnxruntimeModuleTest' packageApp: false destinationPlatformOption: 'iOS' - destinationSimulators: 'iPhone 13,OS=latest' + destinationSimulators: 'iPhone 14,OS=16.4' workingDirectory: '$(Build.SourcesDirectory)/js/react_native/ios' xcprettyArgs: '--output build/reports/test-results.xml' publishJUnitResults: true @@ -319,6 +316,11 @@ stages: condition: succeededOrFailed() displayName: Publish React Native iOS Instrumented Test Results + - script: | + xcrun simctl list devices + displayName: List iOS Simulators + continueOnError: true + - script: | JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/ios-test-results.xml \ detox test --record-logs all \ diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index d43f277739d99..e27de27036130 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -15,10 +15,10 @@ stages: displayName: "Build iOS package for variant: ${{ parameters.packageVariant}}" pool: - vmImage: "macOS-latest" + vmImage: "macOS-13" variables: - xcodeVersion: "14.2" + xcodeVersion: "14.3.1" ortPodVersion: $[stageDependencies.IosPackaging_SetCommonVariables.j.outputs['SetCommonVariables.ORT_POD_VERSION']] ${{ if eq(parameters.packageVariant, 'Full') }}: @@ -62,8 +62,6 @@ stages: architecture: "x64" - template: ../use-xcode-version.yml - parameters: - xcodeVersion: ${{ variables.xcodeVersion }} - template: ../install-appcenter.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml b/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml index ec4398fe31fc5..2cf698aefa8bd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml +++ b/tools/ci_build/github/azure-pipelines/templates/use-xcode-version.yml @@ -3,7 +3,7 @@ parameters: - name: xcodeVersion type: string - default: "14.2" + default: "14.3.1" steps: - bash: | From b94ba09e4fea03288d48f41380d25499cb9b2a7a Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 18 Sep 2024 01:12:16 +0800 Subject: [PATCH 39/39] Upgrade XNNPACK to latest version (#22012) ### Description Update XNNPack to latest version (Sep 4) - Some op outputs are changed, channel or stride paras are moved into reshape func. e.g. https://github.com/google/XNNPACK/commit/96962a602d56dc73b345b5b42aabf7a594eceab9 - input params of xnnpack's resize related function are changed a lot - KleidiAI is added as a dependency in ARM64 - The latest XNNPACK includes 2 static libs microkernels-prod and xnnpack. Without microkernels-prod, it throws the exception of Undefined symbols. - Add ORT_TARGET_PROCESSOR to get the real processor target in CMake --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 4 +- cmake/external/xnnpack.cmake | 62 ++++++++++++++++++- .../xnnpack/AddEmscriptenAndIosSupport.patch | 34 +++++----- .../core/providers/xnnpack/math/softmax.cc | 25 ++++---- .../core/providers/xnnpack/math/softmax.h | 1 + .../core/providers/xnnpack/nn/average_pool.cc | 8 +-- .../core/providers/xnnpack/nn/max_pool.cc | 5 +- .../core/providers/xnnpack/tensor/resize.cc | 13 ++-- .../core/providers/xnnpack/xnnpack_kernel.h | 4 +- .../templates/download-deps.yml | 4 +- 11 files changed, 108 insertions(+), 54 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index f8589598c7571..654099958b21b 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -146,7 +146,7 @@ "component": { "type": "git", "git": { - "commitHash": "0da379fc4808f9601faef392352018c741c0f297", + "commitHash": "309b75c9e56e0a674bf78d59872ce131f814dfb6", "repositoryUrl": "https://github.com/google/XNNPACK.git" }, "comments": "googlexnnpack" diff --git a/cmake/deps.txt b/cmake/deps.txt index 597c051b5f477..342184bda2f0e 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -29,7 +29,8 @@ fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 -googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 +#xnnpack 2024.09.04 +googlexnnpack;https://github.com/google/XNNPACK/archive/309b75c9e56e0a674bf78d59872ce131f814dfb6.zip;39FA5259EAEACE0547284B63D5CEDC4F05553F5A json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 @@ -60,3 +61,4 @@ composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/arch directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029 dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99 +kleidiai;https://gitlab.arm.com/kleidi/kleidiai/-/archive/v0.2.0/kleidiai-v0.2.0.zip;B1E3173992FD91F20DB904AB77D6E901778C2681 diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index 41f02ce6f22bc..9519e4e6a7796 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -5,6 +5,8 @@ set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") +set(KLEIDIAI_BUILD_TESTS OFF CACHE INTERNAL "") +set(KLEIDIAI_BUILD_BENCHMARK OFF CACHE INTERNAL "") if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") set(XNNPACK_USE_SYSTEM_LIBS OFF) @@ -30,6 +32,60 @@ set(FXDIV_SOURCE_DIR ${fxdiv_SOURCE_DIR}) FetchContent_Declare(pthreadpool URL ${DEP_URL_pthreadpool} URL_HASH SHA1=${DEP_SHA1_pthreadpool}) onnxruntime_fetchcontent_makeavailable(pthreadpool) +# --- Determine target processor +# Why ORT_TARGET_PROCESSOR is only for XNNPACK +# So far, only Onnxruntime + XNNPack only allow one target processor. +# And we support Mac universal package, so, +# CMAKE_OSX_ARCHITECTURES_COUNT greater than 1 is allowed in other places. +IF(CMAKE_OSX_ARCHITECTURES) + LIST(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_COUNT) + IF(CMAKE_OSX_ARCHITECTURES_COUNT GREATER 1) + MESSAGE(STATUS "Building ONNX Runtime with XNNPACK and multiple OSX architectures is not supported. Got:(${CMAKE_OSX_ARCHITECTURES}). " + "Please specify a single architecture in CMAKE_OSX_ARCHITECTURES and re-configure. ") + ENDIF() + IF(NOT CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64|arm64e|arm64_32)$") + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_OSX_ARCHITECTURES value \"${CMAKE_OSX_ARCHITECTURES}\"") + ENDIF() + SET(ORT_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") + ADD_COMPILE_OPTIONS("-Wno-shorten-64-to-32") +ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_GENERATOR_PLATFORM) + IF(CMAKE_GENERATOR_PLATFORM MATCHES "^Win32") + SET(ORT_TARGET_PROCESSOR "x86") + ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^x64") + SET(ORT_TARGET_PROCESSOR "x86_64") + ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^ARM64") + SET(ORT_TARGET_PROCESSOR "arm64") + ELSEIF(CMAKE_GENERATOR_PLATFORM MATCHES "^ARM64EC") + SET(ORT_TARGET_PROCESSOR "arm64") + ELSE() + MESSAGE(FATAL_ERROR "Unsupported Visual Studio architecture \"${CMAKE_GENERATOR_PLATFORM}\"") + ENDIF() +ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^i[3-7]86$") + SET(ORT_TARGET_PROCESSOR "x86") +ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + SET(ORT_TARGET_PROCESSOR "x86_64") +ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") + SET(ORT_TARGET_PROCESSOR "arm") +ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + SET(ORT_TARGET_PROCESSOR "arm64") +ELSEIF(CMAKE_SYSTEM_PROCESSOR STREQUAL "ppc64le") + SET(ORT_TARGET_PROCESSOR "ppc64") +ELSEIF(NOT ORT_TARGET_PROCESSOR MATCHES "^(x86(_64)?|arm64|riscv(32|64|128)|Hexagon|ppc64)$") + SET(ORT_TARGET_PROCESSOR "${CMAKE_SYSTEM_PROCESSOR}") +ELSE() + MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_PROCESSOR value \"${CMAKE_SYSTEM_PROCESSOR}\"") +ENDIF() +MESSAGE(STATUS "Building for ORT_TARGET_PROCESSOR: ${ORT_TARGET_PROCESSOR}") + +# KleidiAI is only used in Arm64 platform and not supported by MSVC, the details can be seen in +# https://github.com/google/XNNPACK/blob/3b3f7b8a6668f6ab3b6ce33b9f1d1fce971549d1/CMakeLists.txt#L206C82-L206C117 +if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") + FetchContent_Declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai}) + onnxruntime_fetchcontent_makeavailable(kleidiai) + set(KLEIDIAI_SOURCE_DIR ${kleidiai_SOURCE_DIR}) +endif() + + FetchContent_Declare(googlexnnpack URL ${DEP_URL_googlexnnpack} URL_HASH SHA1=${DEP_SHA1_googlexnnpack} PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/xnnpack/AddEmscriptenAndIosSupport.patch ) @@ -37,8 +93,10 @@ onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) -set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK pthreadpool) - +set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK microkernels-prod pthreadpool) +if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") + list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai) +endif() # the XNNPACK CMake setup doesn't include the WASM kernels so we have to manually set those up if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index 736fffb1e384c..3abf2d3afec42 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index dba9b4687..a4345898d 100755 +index 1ff85b538..c3ef2183f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -122,7 +122,7 @@ ENDIF() +@@ -253,7 +253,7 @@ ENDIF() # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") @@ -11,29 +11,27 @@ index dba9b4687..a4345898d 100755 MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() IF(CMAKE_SYSTEM_NAME MATCHES "Windows") -@@ -534,7 +534,12 @@ IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(operator-utils PRIVATE logging) - TARGET_LINK_LIBRARIES(post-operation PRIVATE logging) - TARGET_LINK_LIBRARIES(subgraph PRIVATE allocator logging memory mutex operators operator-run) -- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) +@@ -763,7 +763,12 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(operator-run PRIVATE xnnpack-base logging) + TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging) + TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph) + IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") -+ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation subgraph) ++ # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing subgraph) + ELSE() -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection jit logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing post-operation microkernels-prod subgraph) -+ ENDIF() ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE allocator cache hardware-config indirection logging memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph) ++ ENDIF() + TARGET_LINK_LIBRARIES(XNNPACK PUBLIC xnnpack-base) SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() - IF(NOT MSVC) -@@ -543,8 +548,9 @@ ENDIF() +@@ -772,7 +777,8 @@ IF(NOT MSVC) + ENDIF() IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") - SET_PROPERTY(SOURCE ${PROD_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") - SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") -- SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv6 -mfpu=vfp -munaligned-access ") + # set this to armv7-a to workaround build issue. we don't target armv6 so it shouldn't matter -+ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") -+ SET_PROPERTY(SOURCE ${PROD_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") ++ SET_PROPERTY(SOURCE ${ALL_ARMSIMD32_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=vfp -munaligned-access ") SET_PROPERTY(SOURCE ${ALL_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") - SET_PROPERTY(SOURCE ${PROD_NEON_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon ") SET_PROPERTY(SOURCE ${ALL_NEONFP16_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -march=armv7-a -mfpu=neon-fp16 ") + # GCC requires -mfp16-format=ieee to define __fp16 type, but Clang doesn't support this option at all. diff --git a/onnxruntime/core/providers/xnnpack/math/softmax.cc b/onnxruntime/core/providers/xnnpack/math/softmax.cc index 87440b7814176..43e3ac193de5d 100644 --- a/onnxruntime/core/providers/xnnpack/math/softmax.cc +++ b/onnxruntime/core/providers/xnnpack/math/softmax.cc @@ -166,24 +166,21 @@ Softmax::Softmax(const OpKernelInfo& info) : XnnpackKernel{info} { if (op_type_ == OpComputeType::op_compute_type_qu8) { // the order of input tensor, x,x_scale, x_zp, y_scale, y_zp OpQuantParam quant_param = ParseQuantParamForOp(info, x_dtype, 1); - xstatus = xnn_create_softmax_nc_qu8(channels, - channels, - channels, - quant_param[0].first[0], // x_scale - quant_param[1].second, // y_zp - quant_param[1].first[0], // y_scale - 0, // flags, - &p); + xstatus = xnn_create_softmax_nc_qu8( + quant_param[0].first[0], // x_scale, input scale + quant_param[1].second, // y_zp, output zero point + quant_param[1].first[0], // y_scale, output scale + 0, // flags, + &p); } else if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_softmax_nc_f32(channels, - channels, - channels, - 0, // flags, - &p); + xstatus = xnn_create_softmax_nc_f32( + 0, // flags, + &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_softmax_nc_", OpTypeToString(op_type_), " failed. Status:", xstatus); + channel_dim_ = channels; op0_.reset(p); } @@ -205,7 +202,7 @@ Status Softmax::Compute(OpKernelContext* ctx) const { auto reshape_fn = op_type_ == OpComputeType::op_compute_type_qu8 ? xnn_reshape_softmax_nc_qu8 : xnn_reshape_softmax_nc_f32; - status = reshape_fn(op0_.get(), N, threadpool); + status = reshape_fn(op0_.get(), channel_dim_, channel_dim_, channel_dim_, N, threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_softmax_nc_", OpTypeToString(op_type_), diff --git a/onnxruntime/core/providers/xnnpack/math/softmax.h b/onnxruntime/core/providers/xnnpack/math/softmax.h index 8c6fba6c822a1..9a8055ff34a57 100644 --- a/onnxruntime/core/providers/xnnpack/math/softmax.h +++ b/onnxruntime/core/providers/xnnpack/math/softmax.h @@ -23,6 +23,7 @@ class Softmax final : public XnnpackKernel { int opset_; OpComputeType op_type_ = OpComputeType::op_compute_type_invalid; XnnpackOperator op0_; + int64_t channel_dim_; }; } // namespace xnnpack } // namespace onnxruntime diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index 58c209a13cd0c..b31b5a94899bf 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -15,7 +15,6 @@ namespace onnxruntime { namespace xnnpack { namespace { Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, - int64_t C, const std::optional>& clip_min_max, struct xnn_operator*& p, const OpQuantParam& quant_param, @@ -42,7 +41,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride foutput_min, foutput_max, flags, &p); } else if (avgpool_type == OpComputeType::op_compute_type_qu8) { const float output_scale = quant_param[1].first[0]; @@ -53,7 +51,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, input_padding_bottom, input_padding_left, pooling_height, pooling_width, stride_height, stride_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride quant_param[0].second, quant_param[0].first[0], quant_param[1].second, @@ -209,7 +206,7 @@ AveragePool::AveragePool(const OpKernelInfo& info) ORT_THROW("unsupported AveragePool in XnnpackEP, we have FLOAT|UINT8, but got ", stype); } struct xnn_operator* p; - auto ret = CreateXnnpackKernel(pool_attrs_, C, clip_min_max_, p, + auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p, quant_param, avgpool_type_); ORT_ENFORCE(ret.IsOK(), ret.ErrorMessage()); op0_.reset(p); @@ -222,6 +219,7 @@ Status AveragePool::Compute(OpKernelContext* context) const { int64_t N = X_shape[0]; int64_t H = X_shape[1]; int64_t W = X_shape[2]; + int64_t C = X_shape[3]; // set the N dim to the correct value TensorShapeVector output_dims{output_dims_}; @@ -247,7 +245,7 @@ Status AveragePool::Compute(OpKernelContext* context) const { ? xnn_reshape_average_pooling2d_nhwc_f32 : xnn_reshape_average_pooling2d_nhwc_qu8; - auto status = reshape_fn(op0_.get(), N, H, W, + auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); diff --git a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc index 2ef9f97f77b14..0f0b827974f66 100644 --- a/onnxruntime/core/providers/xnnpack/nn/max_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/max_pool.cc @@ -172,7 +172,6 @@ MaxPool::MaxPool(const OpKernelInfo& info) pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride foutput_min, foutput_max, flags, &p); } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { maxpool_type_ = OpComputeType::op_compute_type_qu8; @@ -183,7 +182,6 @@ MaxPool::MaxPool(const OpKernelInfo& info) pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride output_min, output_max, flags, &p); } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_INT8) { maxpool_type_ = OpComputeType::op_compute_type_qs8; @@ -194,7 +192,6 @@ MaxPool::MaxPool(const OpKernelInfo& info) pooling_height, pooling_width, stride_height, stride_width, dilation_height, dilation_width, - C, C, C, // channels, input_pixel_stride, output_pixel_stride output_min, output_max, flags, &p); } else { auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto())); @@ -213,6 +210,7 @@ Status MaxPool::Compute(OpKernelContext* context) const { int64_t N = X_shape[0]; int64_t H = X_shape[1]; int64_t W = X_shape[2]; + int64_t C = X_shape[3]; // set the N dim to the correct value TensorShapeVector output_dims{output_dims_}; @@ -234,6 +232,7 @@ Status MaxPool::Compute(OpKernelContext* context) const { } auto status = reshape_fn(op0_.get(), N, H, W, + C, C, C, // channels, input_pixel_stride, output_pixel_stride /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index cf874796ba169..db5648d5d6e54 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -214,8 +214,6 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf } } - int64_t channels = x_shape->dim(3).dim_value(); - uint32_t flags = 0; ORT_ENFORCE(mode_ == UpsampleMode::LINEAR, "only support bilinear resize"); if (coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS) { @@ -227,12 +225,14 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf xnn_status xstatus = xnn_status_invalid_state; struct xnn_operator* p = nullptr; + auto out_h = output_dims_[1]; + auto out_w = output_dims_[2]; if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f32(channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_f32(out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - xstatus = xnn_create_resize_bilinear2d_nhwc_u8(channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_u8(out_h, out_w, flags, &p); } else { - xstatus = xnn_create_resize_bilinear2d_nhwc_s8(channels, channels, channels, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc_s8(out_h, out_w, flags, &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:", @@ -248,6 +248,7 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, auto N = X_shape[0]; auto H = X_shape[1]; auto W = X_shape[2]; + auto C = X_shape[3]; Tensor* output = ctx->Output(0, TensorShape(output_dims)); pthreadpool_t threadpool = GetThreadPool(); @@ -266,7 +267,7 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; } - auto status = reshape_fn(op0_.get(), N, H, W, output_dims[1], output_dims[2], + auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, &workspace_size, &workspace_alignment, threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h index 0978a88288114..31512586be19d 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h @@ -48,7 +48,7 @@ class XnnpackKernel : public OpKernel { // auto_code_cache.reset(&code_cache_); #endif // status = xnn_init_weights_cache(&weights_cache_); - xnn_weights_cache_t weights_cache = nullptr; + xnn_weights_cache_t weights_cache_provider = nullptr; status = xnn_create_weights_cache(&weights_cache, 0); ORT_ENFORCE(status == xnn_status_success, "Failed to create XNNPACK weights cache"); auto_weights_cache.reset(weights_cache); @@ -57,7 +57,7 @@ class XnnpackKernel : public OpKernel { } // std::unique_ptr auto_code_cache; - std::unique_ptr auto_weights_cache; + std::unique_ptr auto_weights_cache; // private: // #if defined(XNN_CACHE_ENABLE) && XNN_PLATFORM_JIT diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 3c74c70ed102d..cbba1cb8ba8bd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.181 + version: 1.0.184 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.181 + version: 1.0.184 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here.