From 82c8e56911900bf1e892345124c85b475787cb54 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 4 Apr 2025 00:24:03 -0700 Subject: [PATCH 01/18] upgrade action shellcheck to v1.30.0 (#24304) ### Description ### Motivation and Context --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 8d966d358de01..21baab0fd191c 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -26,7 +26,7 @@ jobs: level: info filter_mode: diff_context - name: shellcheck # Static check shell scripts - uses: reviewdog/action-shellcheck@v1 + uses: reviewdog/action-shellcheck@v1.30.0 with: github_token: ${{ secrets.github_token }} reporter: github-pr-check From 1cb53d0054774db37ed500e43988497c66599c11 Mon Sep 17 00:00:00 2001 From: minfhong-quic Date: Fri, 4 Apr 2025 23:29:54 +0800 Subject: [PATCH 02/18] [QNN-EP] Fix ONNX context model helper. (#24271) ### Description Fix the bug where the QNN EP generates an ONNX model with EP Context and fails to run. ### Motivation and Context When generating an ONNX model with QNN EP context where the input is scalar, the shape is not set, resulting in a null pointer and causing the subsequent run to fail. --- onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 93b2acb5b002c..26642459a6863 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -77,6 +77,7 @@ Status CreateNodeArgs(const std::vector& names, const OnnxTensorInfo& tensor_info = tensor_info_table.at(name); std::unique_ptr tensor_type = Factory::Create(); tensor_type->mutable_tensor_type()->set_elem_type(tensor_info.data_type_); + tensor_type->mutable_tensor_type()->mutable_shape(); for (size_t j = 0; j < tensor_info.shape_.size(); ++j) { tensor_type->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(tensor_info.shape_[j]); } From 318cc87f23d3b28a8b3d7617433f0a1cfd60e458 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 4 Apr 2025 10:36:08 -0700 Subject: [PATCH 03/18] [WebGPU] fix Pad cache key (#24305) ### Description Fix cache key of Pad operator --- .github/workflows/windows_webgpu.yml | 1 - onnxruntime/core/providers/webgpu/tensor/pad.cc | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index e8cea1c5805a3..8b3b8a2fcde54 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -126,7 +126,6 @@ jobs: dir ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log - name: Validate shader keys - continue-on-error: true uses: ./.github/actions/webgpu-validate-shader-key with: log_file_path: ${{ github.workspace }}\RelWithDebInfo\RelWithDebInfo\onnxruntime_test_all_stderr.log diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc index cb019892b006f..f24578a145aae 100644 --- a/onnxruntime/core/providers/webgpu/tensor/pad.cc +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -168,9 +168,9 @@ Status Pad::ComputeInternal(ComputeContext& context) const { PadProgram program{mode_, 
dim_value_zero, is_float16}; if (!dim_value_zero) { - program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank}); + program.AddInput({input_tensor, ProgramTensorMetadataDependency::Rank}); } - program.AddOutput({output_tensor, ProgramTensorMetadataDependency::Rank}) + program.AddOutput({output_tensor, ProgramTensorMetadataDependency::TypeAndRank}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .CacheHint(std::to_string(static_cast(mode_)), dim_value_zero) .AddUniformVariables({{gsl::span(lower_pads.data(), lower_pads.size())}, {output_size}, {value_uint32}}); From 56f101839603f2c875234b8baca4dc0cf5f8f973 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 4 Apr 2025 11:30:11 -0700 Subject: [PATCH 04/18] Bump vite from 6.2.4 to 6.2.5 in /js/web/test/e2e/exports/testcases/vite-default (#24312) Bumps [vite](https://github.com/vitejs/vite/tree/HEAD/packages/vite) from 6.2.4 to 6.2.5.
Release notes (from vite's releases and changelog): v6.2.5 (2025-04-03); see CHANGELOG.md for details.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .../e2e/exports/testcases/vite-default/package-lock.json | 8 ++++---- .../test/e2e/exports/testcases/vite-default/package.json | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index 9e4730a407d57..708e458748b3a 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -12,7 +12,7 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.4" + "vite": "^6.2.5" } }, "node_modules/@babel/helper-string-parser": { @@ -1069,9 +1069,9 @@ } }, "node_modules/vite": { - "version": "6.2.4", - "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.4.tgz", - "integrity": "sha512-veHMSew8CcRzhL5o8ONjy8gkfmFJAd5Ac16oxBUjlwgX3Gq2Wqr+qNC3TjPIpy7TPV/KporLga5GT9HqdrCizw==", + "version": "6.2.5", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.2.5.tgz", + "integrity": "sha512-j023J/hCAa4pRIUH6J9HemwYfjB5llR2Ps0CWeikOtdR8+pAURAk0DoJC5/mm9kd+UgdnIy7d6HE4EAvlYhPhA==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/test/e2e/exports/testcases/vite-default/package.json b/js/web/test/e2e/exports/testcases/vite-default/package.json index e06733f917e3f..904db7a41de9c 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package.json @@ -13,6 +13,6 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.2.4" + "vite": "^6.2.5" } } From 2e94c5a40f4a5dbd4627fa98ad5e410db171ba2e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 4 Apr 2025 12:32:47 -0700 Subject: [PATCH 05/18] [WebGPU] fix cache key of AttentionProbs/VxAttentionScore (#24309) ### Description fix the cache inconsistency of program AttentionProbs/VxAttentionScore `n_reps` is already in uniforms so do not use it from hardcoded. 
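For context, a minimal sketch of why the hardcoded value breaks caching (a hypothetical helper, not the actual ORT code; the real shader text is built in `GenerateShaderCode`): the program's cache hint does not include `n_reps`, so interpolating it into the generated WGSL lets two programs with the same cache key produce different shader text, while reading it from `uniforms.n_reps` keeps the generated text identical.

```cpp
#include <string>

// Hypothetical illustration of the two code paths.
std::string EmitKOffset(bool hardcode, int n_reps) {
  if (hardcode) {
    // Before the fix: n_reps is baked into the shader source but is not part of
    // the cache key, so a cached shader may be reused with the wrong value.
    return "let kOffset = (workgroup_id.z / " + std::to_string(n_reps) +
           ") * uniforms.kv_sequence_length * uniforms.K;\n";
  }
  // After the fix: the value comes from a uniform, so the generated text is the
  // same for every n_reps and stays consistent with the cache key.
  return "let kOffset = (workgroup_id.z / uniforms.n_reps) * "
         "uniforms.kv_sequence_length * uniforms.K;\n";
}
```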
--- .../webgpu-validate-shader-key/action.yml | 2 +- .github/workflows/windows-web-ci-workflow.yml | 1 - onnxruntime/contrib_ops/webgpu/bert/attention.cc | 16 ++++++++-------- onnxruntime/contrib_ops/webgpu/bert/attention.h | 10 ++++------ 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/.github/actions/webgpu-validate-shader-key/action.yml b/.github/actions/webgpu-validate-shader-key/action.yml index 7b341d38ea906..86406a2e91877 100644 --- a/.github/actions/webgpu-validate-shader-key/action.yml +++ b/.github/actions/webgpu-validate-shader-key/action.yml @@ -22,7 +22,7 @@ runs: working-directory: ${{ github.action_path }} - name: Validate shader keys (native log) - if: ${{ !inputs.is_chromium_log != 'true' }} + if: ${{ inputs.is_chromium_log != 'true' }} shell: cmd run: | node validate-shader-key.js < "${{ inputs.log_file_path }}" diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index ce0e5167eb0a0..57f687d8502ff 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -200,7 +200,6 @@ jobs: - name: Validate shader keys - WebGPU EP if: ${{ inputs.run_webgpu_tests == true && inputs.build_config == 'Debug' }} - continue-on-error: true uses: ./.github/actions/webgpu-validate-shader-key with: log_file_path: ${{ runner.temp }}\web\test\07\chrome_debug.log diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 0d4afc8c13f4b..abea94d2e0b50 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -108,9 +108,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.kv_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.present_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n"; } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" @@ -123,7 +123,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" - << " let pastKeyOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.past_sequence_length * uniforms.K;\n" + << " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" @@ -181,7 +181,7 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 
2 : 1); AttentionProbsProgram program{"AttentionProbs", feed_past_key, has_present_key, has_attention_bias, tile_size, - components, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + components, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, components}, {K, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_key) { @@ -331,9 +331,9 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.kv_sequence_length + n;\n"; + shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n"; if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.present_sequence_length + n;\n"; + shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n"; } shader.MainFunctionBody() << "var value = output_value_t(0);\n" @@ -346,7 +346,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" - << " let pastValueOffset = (workgroup_id.z / " << n_reps_ << ") * uniforms.N * uniforms.past_sequence_length + n;\n" + << " let pastValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n" @@ -400,7 +400,7 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int const int components = parameters.v_head_size_ % 4 == 0 ? 4 : (parameters.v_head_size_ % 2 == 0 ? 
2 : 1); constexpr int tile_size = 12; int tile_n_size = tile_size * components; - VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, parameters.n_reps, seqlen_k, parameters.past_present_share_buffer_}; + VxAttentionScoreProgram program{"VxAttentionScore", feed_past_value, has_present_value, tile_size, parameters.is_first_prompt_, seqlen_k, parameters.past_present_share_buffer_}; program.AddInputs({{probs, ProgramTensorMetadataDependency::TypeAndRank}, {V, ProgramTensorMetadataDependency::TypeAndRank, components}}); if (feed_past_value) { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 164ea72b07d9d..6123d2c47add1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -34,8 +34,8 @@ class TransferBSDToBNSHProgram final : public Program class AttentionProbsProgram final : public Program { public: AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key, - bool has_attention_bias, int tile_size, int components, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + bool has_attention_bias, int tile_size, int components, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -60,7 +60,6 @@ class AttentionProbsProgram final : public Program { bool has_attention_bias_; int tile_size_; int components_; - int n_reps_; const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; @@ -90,8 +89,8 @@ class InPlaceSoftmaxProgram final : public Program { class VxAttentionScoreProgram final : public Program { public: - VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, int n_reps = 1, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) - : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), n_reps_(n_reps), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { + VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false) + : Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -114,7 +113,6 @@ class VxAttentionScoreProgram final : public Program { bool feed_past_value_; 
bool has_present_value_; int tile_size_; - int n_reps_; const Tensor* seqlen_k_; bool past_present_share_buffer_; bool is_first_prompt_; From e944379e519fdda2c782e80d4d960483c975f604 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 4 Apr 2025 13:15:04 -0700 Subject: [PATCH 06/18] Support Gemma3 with Clip fused attention (#24280) ### Description Essentially, the vision model is traced differently (this time it's without mask.), and the input indices of op.Add and op.MatMul can be different. Also, fp16 and fp32 need different tracing patterns (op.Cast). 1. Add another traced pattern to CLIP attention to cover no attention_mask case 2. Accept different index of input on op.Add and op.MatMul (be more general) 3. fp16 and fp32 shows different pattern (op.Cast after op.Softmax) 4. Refactor test_fastgelu.py to cover torch.onnx.export(..., dynamo=True) 5. Add gemma3 vision attention (SigLip) test to cover both fp16 and fp32 ### Motivation and Context To optimize Gemma3 multi-modal model, the changes are needed. https://huggingface.co/google/gemma-3-4b-it NOTE: some related follow-ups (upstream optimizations to onnxscript-optimizer): https://github.com/microsoft/onnxscript/issues/2158 https://github.com/microsoft/onnxscript/issues/2156 --- .../transformers/fusion_attention_clip.py | 83 +++++-- .../tools/transformers/fusion_fastgelu.py | 17 +- .../models/gemma3-vision-attention_fp16.onnx | Bin 0 -> 5397 bytes .../models/gemma3-vision-attention_fp32.onnx | Bin 0 -> 8780 bytes .../python/transformers/test_gelu_fusions.py | 71 +++--- .../python/transformers/test_gemma3_vision.py | 216 ++++++++++++++++++ .../github/linux/python/requirements.txt | 3 +- .../github/windows/python/requirements.txt | 1 + .../transformers-test/requirements.txt | 1 + 9 files changed, 336 insertions(+), 56 deletions(-) create mode 100644 onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp16.onnx create mode 100644 onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp32.onnx create mode 100644 onnxruntime/test/python/transformers/test_gemma3_vision.py diff --git a/onnxruntime/python/tools/transformers/fusion_attention_clip.py b/onnxruntime/python/tools/transformers/fusion_attention_clip.py index 63bf6410f86c3..fe93f5cd358bf 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_clip.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_clip.py @@ -126,7 +126,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if node_before_layer_norm is None: continue child = self.model.find_first_child_by_type( - node_before_layer_norm, "LayerNormalization", input_name_to_nodes, False + node_before_layer_norm, + "LayerNormalization", + input_name_to_nodes, + False, ) if child is None: continue @@ -146,19 +149,26 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], - [1, 1, 0, 0, 0], + [1, None, 0, 0, 0], ) if qkv_nodes is None: logger.debug("fuse_attention: failed to match qkv path") return - - reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[2], qkv_nodes[3], qkv_nodes[-1] + reshape_qkv, transpose_qkv, matmul_qkv = ( + qkv_nodes[2], + qkv_nodes[3], + qkv_nodes[-1], + ) v_nodes = self.model.match_parent_path( - matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None] + matmul_qkv, + ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, None], ) if v_nodes 
is None: - v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1]) + v_nodes = self.model.match_parent_path( + matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None] + ) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return @@ -182,17 +192,30 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ) if qk_nodes is None: qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) - if qk_nodes is None: - qk_nodes = self.model.match_parent_path( - matmul_qkv, ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0, 0, 0] - ) - if qk_nodes is None: - logger.debug("fuse_attention: failed to match qk path") - return - else: - add_mask = qk_nodes[3] - else: + if qk_nodes is not None: add_mask = qk_nodes[1] + else: + # If attention mask is not used, we can still match the qk path. + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0]) + if qk_nodes is None: + # Cast nodes are added in the model for fp16. + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Cast", "Cast", "Softmax", "Add", "Mul", "MatMul"], + [0, 0, 0, 0, 0, 0], + ) + if qk_nodes is not None: + add_mask = qk_nodes[3] + else: + # If attention mask is not used, we can still match the qk path. + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Cast", "Cast", "Softmax", "Mul", "MatMul"], + [0, 0, 0, 0, 0], + ) + if qk_nodes is None: + logger.debug("fuse_attention: failed to match qk path") + return else: assert len(add_mask_indices) == 1 causal_mask_input_index = 1 - add_mask_indices[0] @@ -201,10 +224,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): matmul_qk = qk_nodes[-1] q_nodes = self.model.match_parent_path( - matmul_qk, ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, None, None] + matmul_qk, + ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], + [0, 0, 0, 0, None, None], ) if q_nodes is None: - q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 1]) + q_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None] + ) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return @@ -216,10 +243,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_q, matmul_q = q_nodes[-2], q_nodes[-1] k_nodes = self.model.match_parent_path( - matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, None] + matmul_qk, + ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, 0, None], ) if k_nodes is None: - k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 1]) + k_nodes = self.model.match_parent_path( + matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None] + ) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return @@ -242,7 +273,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 4D Add after Q x K' add_qk_nodes = self.model.match_parent_path( add_mask, - ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze", "Reshape", "Reshape", "Cast"], + [ + "Where", + "Sub", + "Cast", + "Expand", + "Unsqueeze", + "Unsqueeze", + "Reshape", + "Reshape", + "Cast", + ], [1, 2, 1, 0, 0, 0, 0, 0, 0], ) if add_qk_nodes is not None: diff 
--git a/onnxruntime/python/tools/transformers/fusion_fastgelu.py b/onnxruntime/python/tools/transformers/fusion_fastgelu.py index 210f10e2eadd4..728bd03244758 100644 --- a/onnxruntime/python/tools/transformers/fusion_fastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_fastgelu.py @@ -177,13 +177,12 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict return mul_after_mul_half = children[0] + # root_node could be None when root_input is graph input root_node = self.model.get_parent( mul_after_mul_half, 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1, output_name_to_node, ) - if root_node is None: - return mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) if mul_before_tanh is None: @@ -197,7 +196,13 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict if add_before_tanh is None: return - mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node]) + mul_after_pow = self.model.match_parent( + add_before_tanh, + "Mul", + None, + output_name_to_node, + exclude=[root_node] if root_node else [], + ) if mul_after_pow is None: return @@ -212,7 +217,9 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict if not self.model.has_constant_input(pow, 3.0): return - if pow.input[0] != root_node.output[0]: + root_input = mul_after_mul_half.input[0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1] + + if pow.input[0] != root_input: return subgraph_nodes = [ @@ -236,7 +243,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: dict, output_name_to_node: dict self.nodes_to_remove.extend(subgraph_nodes) fused_node = helper.make_node( "FastGelu", - inputs=[root_node.output[0]], + inputs=[root_input], outputs=mul_after_mul_half.output, name=self.model.create_node_name("FastGelu"), ) diff --git a/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp16.onnx b/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp16.onnx new file mode 100644 index 0000000000000000000000000000000000000000..49e805169ee452556aa477b3958d35b8fe437e7d GIT binary patch literal 5397 zcmeHL-A)rh6z-NhOKMdCDhwu zuC2RXK3&!-6|FS)_Z94c(h7bVi)shDvAx5gUtDe)#6w__L*2zxpz`?)>wB5+qQ1-_Q?mH9v4>ypTD$DP|!Qd|GCcfPBr|*i`)FT9070@6gRYe?0apVeb;I*^6t#Rfyb~L@^FqdoL z))9NgHYMD5b#0M4K3{Y6hPG?aCW4K%cOUWM%7=}00iV7G+!5Ud!wdHJD7C_M>G*Iv%vl7Ws{JS> zVo#Y9kP6x^XcTfnldO1loKQhhf%yzFVS~wGp5E0c7j;oe9j+&%c1-z7Tv>{%bj3Y~ zffG_M2NM(LXF_IJp1- literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp32.onnx b/onnxruntime/test/python/transformers/test_data/models/gemma3-vision-attention_fp32.onnx new file mode 100644 index 0000000000000000000000000000000000000000..7fca335f1373165e1aeb56d17d0a0c896e7305b7 GIT binary patch literal 8780 zcmeI2-%is|9LLu&NB3jmJWjx=gv=0=UDW+y5{XJ6dch)pB5>o)N!iNQu01X7DWHjo zFW{w@MjpV2;Kt`LzJZV6X;;{aXQB}k63?=ev_0qd@4L@QoBb*Snz9KwbbUf~OH-(; z5^vXbEzR``>55XVD&>Vw&tMOv!{D1SiE^On?GA-ecDSRf9s*N+t+|*mtHQO-uI{`D z&ezx}E*B?a7gSS!qf*@_rRy-NIi7CXKEo$s;LwHvChF zTtHr>lqj}OvE$lCR6Yx!nz%Aw7$cM?^-Wx>7t_5%_$V<4;gVP#9038v1B9Yx z`=a7XVKSy7UTU?LS($$$F9J)|32wWp(!^Q33HNH##8zZki zNUXz(5aeVDWg#g`Y$Vyp7oWfbr`uMj?KV4#cIenHRl?34_VhMUsqbpa3UPe8>1a)* ztCJRj?ak*e@ygo9_7*d*|AD!~TEDs%G+R+>%_iDT6JpIsZib6cJ~za=Ph%;Lg(eoX^?w9afLe{d3`v{>pa)0@3qREztJo65dx6qQIuk}r;KSx1wC*y0ePlLj$Jze zDu4=%dyomc{f_NvLl@e$5b3>P%s9zsY qal~NU_Fit_vM?1T_%mXGUHC!**RIbPf48u$#5O1&>*W+Aa_9%9Xq&YF 
literal 0 HcmV?d00001 diff --git a/onnxruntime/test/python/transformers/test_gelu_fusions.py b/onnxruntime/test/python/transformers/test_gelu_fusions.py index 94b969ad5377d..11ae1401ff8ed 100644 --- a/onnxruntime/test/python/transformers/test_gelu_fusions.py +++ b/onnxruntime/test/python/transformers/test_gelu_fusions.py @@ -3,6 +3,7 @@ import unittest import torch +from parameterized import parameterized from parity_utilities import find_transformers_source if find_transformers_source(): @@ -43,16 +44,6 @@ def forward(self, x): return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) -test_cases = [ - ("huggingface", "Gelu", HuggingfaceGelu), - ("huggingface", "FastGelu", HuggingfaceFastGelu), - ("huggingface", "QuickGelu", HuggingfaceQuickGelu), - ("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), - ("megatron", "Gelu", MegatronGelu), - ("megatron", "FastGelu", MegatronFastGelu), -] - - class TestGeluFusions(unittest.TestCase): def verify_node_count(self, bert_model, expected_node_count, test_name): for op_type, count in expected_node_count.items(): @@ -62,25 +53,47 @@ def verify_node_count(self, bert_model, expected_node_count, test_name): print(f"{op}: {len(bert_model.get_nodes_by_op_type(op))} expected={counter}") self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count) - def test_fusions(self): - for test_case in test_cases: - source, operator, model_class = test_case - model = model_class() - dummy_input = torch.ones(3, dtype=torch.float32) - test_name = f"{operator}_{source}" - onnx_path = f"{test_name}.onnx" - torch.onnx.export( - model, - (dummy_input), - onnx_path, - input_names=["input"], - output_names=["output"], - ) - optimizer = optimize_model(onnx_path, "bert") - # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx") - os.remove(onnx_path) - expected_node_count = {operator: 1} - self.verify_node_count(optimizer, expected_node_count, test_name) + @parameterized.expand( + [ + (("huggingface", "Gelu", HuggingfaceGelu), True), + (("huggingface", "FastGelu", HuggingfaceFastGelu), True), + (("huggingface", "QuickGelu", HuggingfaceQuickGelu), True), + (("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), True), + (("megatron", "Gelu", MegatronGelu), True), + (("megatron", "FastGelu", MegatronFastGelu), True), + (("huggingface", "Gelu", HuggingfaceGelu), False), + (("huggingface", "FastGelu", HuggingfaceFastGelu), False), + (("huggingface", "QuickGelu", HuggingfaceQuickGelu), False), + (("huggingface", "FastGelu", HuggingfaceTorchGeluTanh), False), + (("megatron", "Gelu", MegatronGelu), False), + (("megatron", "FastGelu", MegatronFastGelu), False), + ] + ) + def test_fusions(self, test_case, dynamo): + source, operator, model_class = test_case + model = model_class() + dummy_input = torch.ones(3, dtype=torch.float32) + test_name = f"{operator}_{source}" + onnx_path = f"{test_name}.onnx" + torch.onnx.export( + model, + (dummy_input,), + onnx_path, + input_names=["input"], + output_names=["output"], + dynamo=dynamo, + optimize=True, # Only meaningful when dynamo is True + ) + optimizer = optimize_model(onnx_path, "bert") + # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx") + os.remove(onnx_path) + # Remove the associated .data file (dynamo) + data_path = onnx_path + ".data" + if os.path.exists(data_path): + os.remove(data_path) + expected_node_count = {operator: 1} + + self.verify_node_count(optimizer, expected_node_count, test_name) if __name__ == "__main__": diff --git 
a/onnxruntime/test/python/transformers/test_gemma3_vision.py b/onnxruntime/test/python/transformers/test_gemma3_vision.py new file mode 100644 index 0000000000000..4727d2c8030d2 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_gemma3_vision.py @@ -0,0 +1,216 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest + +import onnx +import torch +from parameterized import parameterized +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from dynamo_onnx_helper import DynamoOnnxHelper + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py#L363 +class SiglipAttention(torch.nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self): + super().__init__() + self.embed_dim = 20 + self.num_heads = 2 + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + + self.k_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = torch.nn.Linear(self.embed_dim, self.embed_dim) + + self.k_proj.weight.data.fill_(1) + self.v_proj.weight.data.fill_(1) + self.q_proj.weight.data.fill_(1) + self.out_proj.weight.data.fill_(1) + self.k_proj.bias.data.fill_(1) + self.v_proj.bias.data.fill_(1) + self.q_proj.bias.data.fill_(1) + self.out_proj.bias.data.fill_(1) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class Gemma3VSIGLIPAttentionAndLayerNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.attn = SiglipAttention() + self.ln = torch.nn.LayerNorm(20, eps=1e-05) + + def forward(self, x): + # SkipLayerNorm ------+ + # | | + # Attention | + # | | + # MatMul | + # | | + # SkipLayerNorm ------+ + + # SkipLayerNorm + x = x + x + x = self.ln(x) + residual = x + + # Attention + MatMul + x, _ = self.attn(x) + + # SkipLayerNorm + x = residual + x + x = self.ln(x) + return x + + +class TestFusion(unittest.TestCase): + def verify_fusion(self, optimized_model, expected_model_filename): + optimized_model.topological_sort(is_deterministic=True) + + expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename) + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + nodes = optimized_model.model.graph.node + self.assertEqual(len(nodes), 
len(expected_model.model.graph.node)) + + for i in range(len(nodes)): + self.assertEqual(nodes[i], expected_model.model.graph.node[i]) + + for expected_initializer in expected_model.model.graph.initializer: + self.assertTrue( + OnnxModel.has_same_value( + optimized_model.get_initializer(expected_initializer.name), + expected_initializer, + ) + ) + + def export(self, model, inputs) -> onnx.ModelProto: + with torch.no_grad(): + onnx_program = torch.onnx.export( + model, + args=inputs, + # f=os.path.join(os.path.dirname(__file__), "export.onnx"), + dynamo=True, + optimize=True, + ) + return onnx_program.model_proto # type: ignore + + def tearDown(self): + paths = [ + os.path.join(os.path.dirname(__file__), "export.onnx"), + os.path.join(os.path.dirname(__file__), "export.onnx.data"), + ] + for path in paths: + if os.path.exists(path): + os.remove(path) + + @parameterized.expand( + [ + (torch.float32, "gemma3-vision-attention_fp32.onnx"), + (torch.float16, "gemma3-vision-attention_fp16.onnx"), + ] + ) + def test_gemma3_vision_attention(self, dtype, model_name): + model = Gemma3VSIGLIPAttentionAndLayerNorm().eval().to(dtype) + inputs = (torch.randn(1, 2, 20, dtype=dtype),) + original_model = self.export(model, inputs) + + # TODO(titaiwang): Upstream these processings to onnxscript pass + onnx_model_wrapper = DynamoOnnxHelper(original_model) + onnx_model_wrapper.convert_constants_to_initializers() + onnx_model_wrapper.clear_metadata() + model_path = os.path.join(os.path.dirname(__file__), "export.onnx") + onnx_model_wrapper.model.save_model_to_file( + model_path, + use_external_data_format=True, + all_tensors_to_one_file=True, + convert_attribute=True, + ) + + options = FusionOptions("clip") + optimized_model = optimize_model( + model_path, + model_type="clip", + num_heads=2, + hidden_size=20, + optimization_options=options, + opt_level=0, + ) + self.verify_fusion(optimized_model, model_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ci_build/github/linux/python/requirements.txt b/tools/ci_build/github/linux/python/requirements.txt index e51cfb38f57a3..1a580b848a55a 100644 --- a/tools/ci_build/github/linux/python/requirements.txt +++ b/tools/ci_build/github/linux/python/requirements.txt @@ -7,4 +7,5 @@ onnx==1.17.0 ; python_version < '3.13' protobuf==4.21.12 sympy==1.12 flatbuffers -psutil \ No newline at end of file +psutil +onnxscript==0.2.3 diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 200b9c2e50288..2b222c4b1d4a4 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -8,3 +8,4 @@ protobuf==4.21.12 sympy==1.12 flatbuffers psutil +onnxscript==0.2.3 diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 0fb37e3a1550a..47286c364a90f 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -11,3 +11,4 @@ parameterized>=0.8.1 sentencepiece psutil einops +onnxscript==0.2.3 From 11fda2adb93a48404fa29e090f651b5825f95f0e Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 4 Apr 2025 13:40:17 -0700 Subject: [PATCH 07/18] Update packaging pipeline for Nodejs binding (#24301) ### Description Update packaging pipeline for Nodejs binding. 
This change updates the pipeline to perform all Node.js binding builds, including: - Windows x64 ( CPU, DML, WebGPU ) - Windows arm64 ( CPU, DML, WebGPU ) - Linux x64 ( CPU, CUDA, TensorRT, WebGPU ) - Linux arm64 ( CPU ) - MacOS x64 ( CPU, CoreML, WebGPU ) - MacOS arm64 ( CPU, CoreML, WebGPU ) #### Dependencies The Node.js binding depends on the Nuget package from the same build. Because NPM has a size limit so we cannot fit libonnxruntime_provider_cuda.so into it. The Node.js binding works in a way that an installation script will try to download the Nuget package of the corresponding version. --- .../c-api-noopenmp-packaging-pipelines.yml | 23 +++ .../stages/nodejs-linux-packaging-stage.yml | 57 ++++++ .../stages/nodejs-win-packaging-stage.yml | 192 ++++++++++++++++++ .../stages/nuget-combine-cuda-stage.yml | 5 + .../nuget-linux-cuda-packaging-stage.yml | 11 +- .../azure-pipelines/templates/c-api-cpu.yml | 177 +++++++--------- .../linux-cpu-packaging-pipeline.yml | 2 +- .../github/linux/build_nodejs_package.sh | 6 + 8 files changed, 362 insertions(+), 111 deletions(-) create mode 100644 tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml create mode 100644 tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml create mode 100755 tools/ci_build/github/linux/build_nodejs_package.sh diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 722e7696ba738..8f1189b05858c 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -147,6 +147,29 @@ extends: SpecificArtifact: ${{ parameters.SpecificArtifact }} BuildId: ${{ parameters.BuildId }} + - template: stages/nodejs-win-packaging-stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + ArtifactName: 'drop-onnxruntime-nodejs-win-x64' + StageName: 'Windows_Nodejs_Packaging_x64' + BuildCommand: --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildArch: 'x64' + EnvSetupScript: 'setup_env.bat' + sln_platform: 'x64' + DoEsrp: ${{ parameters.DoEsrp }} + PublishWebGpuBuildTools: true + + - template: stages/nodejs-win-packaging-stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + ArtifactName: 'drop-onnxruntime-nodejs-win-arm64' + StageName: 'Windows_Nodejs_Packaging_arm64' + BuildCommand: --arm64 --skip_submodule_sync --build_shared_lib --enable_onnx_tests --enable_wcos --use_telemetry --use_dml --use_webgpu --build_nodejs --cmake_generator "Visual Studio 17 2022" + BuildArch: 'x64' + EnvSetupScript: 'setup_env.bat' + sln_platform: 'arm64' + DoEsrp: ${{ parameters.DoEsrp }} + DependsOnStageName: Windows_Nodejs_Packaging_x64 - template: nuget/templates/dml-vs-2022.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml new file mode 100644 index 0000000000000..e1247565d8f5b --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml @@ -0,0 +1,57 @@ +parameters: +- name: CudaVersion + type: string + default: '12.2' + +stages: +- stage: Linux_Nodejs_Packaging_x64 + dependsOn: [] + jobs: + - job: Linux_Nodejs_Packaging_x64 + dependsOn: [] + 
workspace: + clean: all + timeoutInMinutes: 180 + pool: + name: 'onnxruntime-Ubuntu2204-AMD-CPU' + os: linux + variables: + - template: ../templates/common-variables.yml + - name: CUDA_VERSION_MAJOR + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: '11' + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: '12' + - name: CUDA_VERSION + value: ${{ parameters.CudaVersion }} + - name: linux_trt_version + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: ${{ variables.linux_trt_version_cuda12 }} + steps: + - checkout: self + clean: true + submodules: recursive + - template: ../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile + Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }} + DockerBuildArgs: " + --build-arg TRT_VERSION=${{ variables.linux_trt_version }} + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}xtrt86build + - template: ../templates/set-version-number-variables-step.yml + + - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_nodejs_package.sh + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Build Node.js binding Package' + + - template: ../templates/nodejs-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'x64' + os: 'linux' + artifactName: 'drop-onnxruntime-nodejs-linux-x64' + + - template: ../templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml new file mode 100644 index 0000000000000..73e650eb07992 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-win-packaging-stage.yml @@ -0,0 +1,192 @@ +parameters: + BuildCommand: '' + StageName: 'Windows_Nodejs_Packaging' + ArtifactName: 'drop-onnxruntime-nodejs-win' + DoEsrp: 'false' + BuildArch: 'x64' # Optional. Options: x86, x64 + sln_platform: 'x64' # Options: Win32, x64, arm, arm64 + AgentDemands: [] + BuildConfigurations: ['RelWithDebInfo'] # Options: Debug, RelWithDebInfo + EnableLto: true + # Controls whether unreleased onnx opsets are allowed. Default is set to 1 + AllowReleasedOpsetOnly: '0' + IsReleaseBuild: false + PublishWebGpuBuildTools: false + WebGpuBuildToolsArtifactName: 'Windows_WebGPU_BuildTools_x64' + DependsOnStageName: '' + +stages: +- stage: ${{ parameters.StageName }} + dependsOn: + - Setup + - ${{ if ne(parameters.DependsOnStageName, '') }}: + - ${{ parameters.DependsOnStageName }} + + jobs: + - job: ${{ parameters.StageName }} + timeoutInMinutes: 200 + strategy: + maxParallel: 2 + matrix: + ${{ each BuildConfiguration in parameters.BuildConfigurations }}: + ${{ BuildConfiguration }}: + BuildConfig: ${{ BuildConfiguration }} + workspace: + clean: all + pool: + name: onnxruntime-Win-CPU-2022 + demands: ${{ parameters.AgentDemands }} + variables: + buildDirectory: '$(Build.BinariesDirectory)' + OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' + runCodesignValidationInjection: ${{ parameters. 
DoEsrp}} #For the others, code sign is in a separated job + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} + BuildDate : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime : $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + BuildCommandExtra: '' + ${{ if eq(parameters.EnableLto, true) }}: + build_py_lto_flag: --enable_lto + + steps: + - checkout: self + clean: true + submodules: none + + - powershell: | + if($env:TELEMETRYGUID) + { + $length = $env:TELEMETRYGUID.length + $fileContent = "#define TraceLoggingOptionMicrosoftTelemetry() \ + TraceLoggingOptionGroup("+$env:TELEMETRYGUID.substring(1, $length-2)+")" + New-Item -Path "$(Build.SourcesDirectory)\include\onnxruntime\core\platform\windows\TraceLoggingConfigPrivate.h" -ItemType "file" -Value "$fileContent" -Force + Write-Output "Enabling TELEMETRY" + } + displayName: 'Create TraceLoggingConfigPrivate.h For WinML Telemetry' + env: + TELEMETRYGUID: $(TELEMETRYGUID) + + - task: NodeTool@0 + inputs: + versionSpec: '20.x' + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: ${{ parameters.BuildArch }} + + # need to set PROCESSOR_ARCHITECTURE so the x86 SDK is installed correctly + - task: UseDotNet@2 + inputs: + version: 8.x + env: + PROCESSOR_ARCHITECTURE: ${{ parameters.BuildArch }} + + - task: BatchScript@1 + displayName: 'Setup VS2022 env vars' + inputs: + filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' + arguments: ${{ parameters.BuildArch }} + modifyEnvironment: true + + - ${{ if and(ne(parameters.WebGpuBuildToolsArtifactName, ''), eq(parameters.sln_platform, 'arm64')) }}: + - task: DownloadPipelineArtifact@2 + displayName: 'Download WebGPU build tools from x64 build' + inputs: + artifactName: '${{ parameters.WebGpuBuildToolsArtifactName }}' + targetPath: '$(Build.BinariesDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}' + - script: | + @echo ##vso[task.setvariable variable=LLVM_TABLEGEN_PATH]$(Build.BinariesDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}\llvm-tblgen.exe + @echo ##vso[task.setvariable variable=CLANG_TABLEGEN_PATH]$(Build.BinariesDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}\clang-tblgen.exe + displayName: 'Set tablegen paths' + - powershell: | + Write-Host "Using LLVM_TABLEGEN_PATH: $(LLVM_TABLEGEN_PATH)" + Write-Host "Using CLANG_TABLEGEN_PATH: $(CLANG_TABLEGEN_PATH)" + Write-Host "##vso[task.setvariable variable=BuildCommandExtra]--cmake_extra_defines LLVM_TABLEGEN=$(LLVM_TABLEGEN_PATH) CLANG_TABLEGEN=$(CLANG_TABLEGEN_PATH)" + displayName: 'Set build flags for WebGPU cross-compilation' + + - powershell: | + python tools\ci_build\build.py --build_dir $(Build.BinariesDirectory) ${{ parameters.BuildCommand }} $(BuildCommandExtra) --use_binskim_compliant_compile_flags --parallel --build --update --config $(BuildConfig) --msbuild_extra_options IncludeMobileTargets=false ${{ variables.build_py_lto_flag }} + + - ${{ if notIn(parameters['sln_platform'], 'Win32', 'x64') }}: + # Use cross-compiled protoc + - script: | + @echo ##vso[task.setvariable variable=ProtocDirectory]$(Build.BinariesDirectory)\installed\bin + + # The Configuration variable is required to build C# + - script: | + @echo ##vso[task.setvariable variable=Configuration]$(BuildConfig) + displayName: 'Set Configuration variable' + + # Node.js Publish + - task: BatchScript@1 + displayName: 'Setup VS env 
vars' + inputs: + filename: 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat' + arguments: ${{ parameters.BuildArch }} + modifyEnvironment: true + - task: CopyFiles@2 + displayName: 'Copy DirectML binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' + Contents: 'DirectML.dll' + TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + - powershell: | + $dxcZipUrl = "https://github.com/microsoft/DirectXShaderCompiler/releases/download/v1.8.2502/dxc_2025_02_20.zip" + $dxcZipPath = "$(Build.BinariesDirectory)\dxc.zip" + $dxcExtractPath = "$(Build.BinariesDirectory)\dxc_extracted" + $targetArch = "${{ parameters.sln_platform }}" + + # Download the DXC package + Write-Host "Downloading DXC release from $dxcZipUrl" + Invoke-WebRequest -Uri $dxcZipUrl -OutFile $dxcZipPath + + # Create extraction directory + if (-not (Test-Path $dxcExtractPath)) { + New-Item -Path $dxcExtractPath -ItemType Directory -Force + } + + # Extract the zip file + Write-Host "Extracting DXC package to $dxcExtractPath" + Expand-Archive -Path $dxcZipPath -DestinationPath $dxcExtractPath -Force + + # Copy the necessary DLLs to the target directory + $sourcePath = Join-Path $dxcExtractPath "bin\$targetArch" + $targetPath = "$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\$targetArch" + + Write-Host "Copying dxil.dll and dxcompiler.dll from $sourcePath to $targetPath" + Copy-Item -Path "$sourcePath\dxil.dll" -Destination $targetPath -Force + Copy-Item -Path "$sourcePath\dxcompiler.dll" -Destination $targetPath -Force + + Write-Host "DXC DLLs successfully copied to the target directory" + displayName: 'Download and Copy DXC Binaries' + - template: ../templates/win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + DisplayName: 'ESRP - Sign Node.js binding binaries' + DoEsrp: ${{ parameters.DoEsrp }} + Pattern: '*.dll,*.node' + + - script: | + del /Q $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}\CodeSignSummary-*.* + call npm pack + copy $(Build.SourcesDirectory)\js\node\onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) + workingDirectory: '$(Build.SourcesDirectory)\js\node' + displayName: 'Create NPM Package' + + - task: 1ES.PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' + artifactName: ${{ parameters.ArtifactName }} + + - ${{ if and(eq(parameters.PublishWebGpuBuildTools, true), eq(parameters.sln_platform, 'x64')) }}: + - script: | + mkdir $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} + copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\llvm-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} + copy $(Build.BinariesDirectory)\$(BuildConfig)\_deps\dawn-build\third_party\dxc\RelWithDebInfo\bin\clang-tblgen.exe $(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }} + displayName: 'Copy WebGPU build tools' + - task: 1ES.PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)\${{ parameters.WebGpuBuildToolsArtifactName }}' + artifactName: ${{ parameters.WebGpuBuildToolsArtifactName }} diff --git 
a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 893bf3f1ec394..a4fe78a7088e3 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -39,6 +39,11 @@ stages: buildJava: ${{ parameters.buildJava }} buildNodejs: ${{ parameters.buildNodejs }} +- ${{ if eq(parameters.buildNodejs, 'true') }}: + - template: nodejs-linux-packaging-stage.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + - template: nuget-win-cuda-packaging-stage.yml parameters: RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 8560817331475..e36fe98fe0ac2 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -32,9 +32,7 @@ stages: parameters: Dockerfile: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }}/Dockerfile Context: tools/ci_build/github/linux/docker/inference/x86_64/default/cuda${{ variables.CUDA_VERSION_MAJOR }} - DockerBuildArgs: " - --build-arg BUILD_UID=$( id -u ) - " + DockerBuildArgs: " --build-arg BUILD_UID=$( id -u )" Repository: onnxruntimecuda${{ variables.CUDA_VERSION_MAJOR }}build - script: $(Build.SourcesDirectory)/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -113,13 +111,6 @@ stages: nativeLibraryName: 'libonnxruntime4j_jni.so' is1ES: true - - ${{ if eq(parameters.buildNodejs, 'true') }}: - - template: ../templates/nodejs-artifacts-package-and-publish-steps-posix.yml - parameters: - arch: 'x64' - os: 'linux' - artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' - - template: ../templates/c-api-artifacts-package-and-publish-steps-posix.yml parameters: buildConfig: 'Release' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index bb789edc1cf21..9b1d7b705e741 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -182,10 +182,10 @@ stages: buildArch: x64 msbuildPlatform: arm64 packageName: arm64 - buildparameter: --build_nodejs --arm64 ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} + buildparameter: --arm64 ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} runTests: false buildJava: false - buildNodejs: true + buildNodejs: false - template: win-ci.yml parameters: @@ -194,10 +194,10 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64 - buildparameter: --build_java --build_nodejs ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} + buildparameter: --build_java ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true - buildNodejs: true + buildNodejs: false - stage: Jar_Packaging dependsOn: @@ -506,14 +506,11 @@ stages: - stage: Nodejs_Packaging dependsOn: - - Windows_CI_GPU_DML_Dev - - Windows_CI_GPU_DML_Dev_arm64 + - Windows_Nodejs_Packaging_x64 + - Windows_Nodejs_Packaging_arm64 + - Linux_Nodejs_Packaging_x64 - 
Linux_C_API_Packaging_CPU - - Linux_C_API_Packaging_GPU - MacOS_C_API_Package_Publish - - Windows_Packaging_CPU_x86_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} condition: succeeded() jobs: - job: Nodejs_Packaging @@ -544,74 +541,78 @@ stages: # Node.js binding artifacts preparation # # This stage prepares Node.js binding artifacts for publishing. The artifacts support the following platforms: - # - Windows x64 with DML support - # - Windows arm64 with DML support - # - Linux x64 with TensorRT support + # - Windows x64 (CPU, DML, WebGPU) + # - Windows arm64 (CPU, DML, WebGPU) + # - Linux x64 (CPU, CUDA, TensorRT, WebGPU) # - Linux arm64 (CPU only) - # - macOS x64 (CPU only) - # - macOS arm64 (CPU only) + # - macOS x64 (CPU, CoreML, WebGPU) + # - macOS arm64 (CPU, CoreML, WebGPU) + # + # File manifest: + # - Windows x64 (CPU, DML, WebGPU): + # dependency: Windows_Nodejs_Packaging_x64 (drop-onnxruntime-nodejs-win-x64) + # files: + # - onnxruntime_binding.node + # - onnxruntime.dll + # - DirectML.dll + # - dxil.dll + # - dxcompiler.dll + # + # - Windows arm64 (CPU, DML, WebGPU): + # dependency: Windows_Nodejs_Packaging_arm64 (drop-onnxruntime-nodejs-win-arm64) + # files: + # - onnxruntime_binding.node + # - onnxruntime.dll + # - DirectML.dll + # - dxil.dll + # - dxcompiler.dll # - # ORT Node.js binding artifacts contain 2 parts: - # 1. ONNX Runtime native shared libraries and their dependencies - # - Windows (x64, arm64): - # - onnxruntime.dll - # - DirectML.dll - # - Linux (x64, arm64): - # - libonnxruntime.so{.version} - # - libonnxruntime_providers_shared.so - # - libonnxruntime_providers_{provider}.so - # - macOS (x64, arm64): - # - libonnxruntime.dylib - # 2. ONNX Runtime Node.js binding - # - onnxruntime_binding.node + # - Linux x64 (CPU, CUDA, TensorRT, WebGPU): + # dependency: Linux_Nodejs_Packaging_x64 (drop-onnxruntime-nodejs-linux-x64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.so.1 + # - libonnxruntime_providers_shared.so + # - libonnxruntime_providers_cuda.so + # - libonnxruntime_providers_tensorrt.so # - # For windows platform, the artifact is named as 'onnxruntime-nodejs-win-x64-dml' for x64, and - # 'onnxruntime-nodejs-win-arm64-dml' for arm64. Each artifact contains both (1) and (2). + # - Linux arm64 (CPU only): + # dependency: Linux_C_API_Packaging_CPU_aarch64 (drop-onnxruntime-nodejs-linux-aarch64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.so.1 # - # For Linux and macOS platforms, (1) and (2) are packed into separate artifacts. 
- # The following artifacts contain (1): - # - onnxruntime-osx - # - onnxruntime-linux-x64-tensorrt - # - onnxruntime-linux-aarch64 - # The following artifacts contain (2): - # - drop-onnxruntime-nodejs-linux-x64-tensorrt - # - drop-onnxruntime-nodejs-linux-aarch64 - # - drop-onnxruntime-nodejs-osx-x86_64 - # - drop-onnxruntime-nodejs-osx-arm64 + # - macOS x64 (CPU, CoreML, WebGPU): + # dependency: MacOS_C_API_Packaging_CPU_x86_64 (drop-onnxruntime-nodejs-osx-x86_64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.{version}.dylib # - # All binary artifacts will eventually be put into folder before packaging 'onnxruntime-node': + # - macOS arm64 (CPU, CoreML, WebGPU): + # dependency: MacOS_C_API_Packaging_CPU_arm64 (drop-onnxruntime-nodejs-osx-arm64) + # files: + # - onnxruntime_binding.node + # - libonnxruntime.{version}.dylib + # + # The following files will be excluded from the further packaging because they are too large to be included in the + # NPM package: + # - linux/x64/libonnxruntime_providers_cuda.so + # + # Rest binary artifacts will eventually be put into folder before packaging 'onnxruntime-node': # $(Build.SourcesDirectory)\js\node\bin\napi-v3\{os}\{cpu_arch}\ # # {os} is one of 'win32', 'darwin', 'linux' and {cpu_arch} is one of 'x64', 'arm64'. - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (OSX)' - inputs: - artifactName: 'onnxruntime-osx' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Linux x64)' - inputs: - artifactName: 'onnxruntime-linux-x64-tensorrt' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Linux aarch64)' - inputs: - artifactName: 'onnxruntime-linux-aarch64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-x64-dml' + artifactName: 'drop-onnxruntime-nodejs-win-x64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/x64/' - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win ARM64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-arm64-dml' + artifactName: 'drop-onnxruntime-nodejs-win-arm64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/arm64/' - task: DownloadPipelineArtifact@0 @@ -629,7 +630,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Linux x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' + artifactName: 'drop-onnxruntime-nodejs-linux-x64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/x64/' - task: DownloadPipelineArtifact@0 @@ -638,15 +639,9 @@ stages: artifactName: 'drop-onnxruntime-nodejs-linux-aarch64' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/arm64/' - - task: PowerShell@2 - displayName: 'PowerShell Script' - inputs: - targetType: filePath - filePath: $(Build.SourcesDirectory)\tools\ci_build\github\windows\extract_nuget_files.ps1 - - script: | - dir - workingDirectory: '$(Build.BinariesDirectory)/nuget-artifact' + dir /S + workingDirectory: '$(Build.BinariesDirectory)/nodejs-artifacts' displayName: 'List artifacts' - script: | @@ -683,61 +678,43 @@ stages: TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' # Node.js binding linux/x64 - - 
task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-tensorrt\lib' - Contents: | - libonnxruntime.so.1 - libonnxruntime_providers_shared.so - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\linux\x64' - Contents: '*.node' + Contents: | + libonnxruntime.so.1 + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' # Node.js binding linux/arm64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-aarch64\lib' - Contents: 'libonnxruntime.so.1' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\linux\arm64' - Contents: '*.node' + Contents: | + libonnxruntime.so.1 + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\arm64' # Node.js binding darwin/x64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-osx-x86_64\lib' - Contents: 'libonnxruntime.*.dylib' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\darwin\x64' - Contents: '*.node' + Contents: | + libonnxruntime.*.dylib + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\x64' # Node.js binding darwin/arm64 - - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-osx-arm64\lib' - Contents: 'libonnxruntime.*.dylib' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\darwin\arm64' - Contents: '*.node' + Contents: | + libonnxruntime.*.dylib + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\darwin\arm64' - task: PowerShell@2 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 7ac2e3a8addb6..fb1c63e1f8a24 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -34,7 +34,7 @@ stages: PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} - PackageNodeJS: ${{ 
parameters.PackageNodeJS }} + PackageNodeJS: false - template: c-api-linux-cpu.yml parameters: diff --git a/tools/ci_build/github/linux/build_nodejs_package.sh b/tools/ci_build/github/linux/build_nodejs_package.sh new file mode 100755 index 0000000000000..29ee91a122e39 --- /dev/null +++ b/tools/ci_build/github/linux/build_nodejs_package.sh @@ -0,0 +1,6 @@ +#!/bin/bash +set -e -x +mkdir -p $HOME/.onnx +docker run -e SYSTEM_COLLECTIONURI --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ +--volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ +/bin/bash -c "/usr/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release --skip_tests --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_nodejs --use_webgpu --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60-real;70-real;75-real;80-real;90' --use_vcpkg --use_vcpkg_ms_internal_asset_cache && cd /build/Release && make install DESTDIR=/build/installed" From a4976e33ec9538c5a881f082c090a774af4a7a44 Mon Sep 17 00:00:00 2001 From: Sushanth Rajasankar <44513542+sushraja-msft@users.noreply.github.com> Date: Fri, 4 Apr 2025 15:43:38 -0700 Subject: [PATCH 08/18] Add support for uint8_t as data type for GatherBlockQuantized (#24239) ### Description This change adds support for GatherBlockQuantized to use uin8_t as data's type with the same semantics as MatMulNBits. Zero_Points and Gather Axis other than 0 are not yet supported, in order to keep the change scoped. ### Motivation and Context With the newer llama models like Phi4 trained with shared embeddings, the weights of the lm_head matrix and the embeddings table are exactly the same. These embeddings are huge, unquantized embeddings are 1.2GB in Phi4 mini instruct, at int4 quantization the weights are still 300MB. We can go a step further and have these two ops the lm_head matmulnbits and GatherBlockQuantized share the same weights, that would save 300MB on the model size. The two things that hinder that are the shape expectations for GatherBlockQuantized and the data type supported for data in GatherBlockQuantized. The shape can be solved via a simple reshape op, but the data type needs code changes and that is what this change does. Here is Phi4 modified with shared weights between lm_head and matmulnbits, this model is just 2.1GB on disk. image --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/ContribOperators.md | 5 +- docs/OperatorKernels.md | 2 +- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 4 + .../quantization/gather_block_quantized.cc | 45 +++++++++-- .../core/graph/contrib_ops/contrib_defs.cc | 21 ++++- .../gather_block_quantized_op_test.cc | 81 ++++++++++++++++++- 6 files changed, 142 insertions(+), 16 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f582abca34706..0308b5c79c508 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. 
`block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point. + If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. + 5. For uint8 data, the `gather_axis` must be 0. #### Version @@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-T1 : tensor(int4), tensor(uint4)
+T1 : tensor(int4), tensor(uint4), tensor(uint8)
 Constrain quantized types.
 T2 : tensor(float), tensor(float16), tensor(bfloat16)
 Constrain dequantized types.
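To make the uint8 packing convention above concrete: each byte stores two 4-bit weights (low nibble first) and, when no `zero_points` input is supplied, a default zero point of 8 is assumed. The stand-alone sketch below is illustrative only (not part of the patch; the helper name is invented) and mirrors the `GetDataElement` specialization for uint8_t and the default zero point used in the kernel change further down.

```cpp
// Illustrative sketch, not part of this patch: decode and dequantize one element of a
// uint8-packed GatherBlockQuantized weight buffer, following the MatMulNBits convention
// (two 4-bit values per byte, low nibble = even index, default zero point = 8).
#include <cstdint>
#include <cstdio>

float DequantizePackedUint4(const uint8_t* packed, int64_t idx, float scale) {
  const uint8_t byte = packed[idx >> 1];                    // two quantized values per byte
  const int32_t nibble = (idx & 1) ? ((byte >> 4) & 0x0F)   // odd index -> high nibble
                                   : (byte & 0x0F);         // even index -> low nibble
  const int32_t zero_point = 8;                             // default when zero_points is absent
  return static_cast<float>(nibble - zero_point) * scale;
}

int main() {
  // 0x97 packs the values 7 (low nibble) and 9 (high nibble); with scale 2.0 and the
  // default zero point of 8 they dequantize to -2.0 and 2.0.
  const uint8_t packed[] = {0x97};
  std::printf("%f %f\n", DequantizePackedUint4(packed, 0, 2.0f),
              DequantizePackedUint4(packed, 1, 2.0f));
  return 0;
}
```

This is also why the op's output shape expands the packed last dimension back to the logical element count, as handled in the shape-inference change below.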
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 60d9e8e747eeb..a20333e2340c4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -515,7 +515,7 @@ Do not modify directly.* |FusedConv|*in* X:**T**
 *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedGemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
-|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)<br> **T2** = tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
+|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br> **T2** = tensor(float), tensor(float16)<br> **Tind** = tensor(int32), tensor(int64)|
 |GatherND|*in* data:**T**<br> *in* indices:**Tind**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **Tind** = tensor(int32), tensor(int64)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br>
*out* sequences:**I**|1+|**T** = tensor(float)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 345b5e793a764..1a737f3a9d251 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized); +class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized); class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized); @@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc index 5935663f114a3..b83164d806ffc 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc @@ -16,6 +16,21 @@ namespace onnxruntime { namespace contrib { +namespace { +template +int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) { + return static_cast(data_ptr[data_idx >> 1].GetElem(narrow(data_idx & 1))); +} + +template <> +int32_t GetDataElement(const uint8_t* data_ptr, int64_t data_idx) { + const uint8_t data_val_u8 = data_ptr[data_idx >> 1]; + // Weights are stored as (nibble2)(nibble1) in uint8_t. + auto data_val = static_cast((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F)); + return data_val; +} +} // namespace + template class GatherBlockQuantized : public OpKernel { public: @@ -98,6 +113,12 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex for (int64_t i = p.gather_axis + 1; i < static_cast(data_rank); ++i) shape.push_back(data_shape[narrow(i)]); + // When data is stored as uint8_t, each element has two int4 values. + // The shape in the onnx model reflects that by having the last dimension be half the number of values. + // Ex: For a true data size of 2000x3072, the onnx model would have data of shape 2000x1536. + // However the outputs still need to be of size 2000x3072. Therefore we x2 the last dimension here. + uint32_t components = (std::is_same_v) ? 
2 : 1; + shape[shape.size() - 1] = shape.back() * components; p.output_tensor = context->Output(0, TensorShape(std::move(shape))); // validate quantization parameters @@ -106,7 +127,7 @@ Status GatherBlockQuantized::PrepareForCompute(OpKernelContext* contex "data and scales must have the same rank."); for (size_t i = 0; i < data_shape.NumDimensions(); ++i) { ORT_RETURN_IF_NOT(i == static_cast(p.quantize_axis) - ? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i] + ? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i] : data_shape[i] == scales_shape[i], "data and scales do not match shapes."); } @@ -165,16 +186,22 @@ Status GatherBlockQuantized::CopyDataAndDequantize(const T1* data_ptr, int64_t output_idx = output_idx_base; int64_t data_idx = data_idx_base; for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) { - auto data_val = static_cast(data_ptr[data_idx >> 1].GetElem(narrow(data_idx & 1))); + auto data_val = GetDataElement(data_ptr, data_idx); int64_t x = data_idx / quantize_full_block; int64_t y = data_idx % quantize_full_block / quantize_N; int64_t z = data_idx % quantize_N; int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z; auto scale_val = static_cast(scales_ptr[scale_idx]); - auto zp_val = static_cast(zero_points_ptr - ? zero_points_ptr[scale_idx >> 1].GetElem(narrow(scale_idx & 1)) - : 0); + int32_t zp_val; + if constexpr (std::is_same_v) { + // The default zero point for uint8 weights as stored by MatMulNBits op is 8. + zp_val = 8; + } else { + zp_val = static_cast(zero_points_ptr + ? zero_points_ptr[scale_idx >> 1].GetElem(narrow(scale_idx & 1)) + : 0); + } output_ptr[output_idx] = static_cast(static_cast(data_val - zp_val) * scale_val); } @@ -205,7 +232,7 @@ template Status GatherBlockQuantized::Compute(OpKernelContext* context) const { Prepare p; ORT_RETURN_IF_ERROR(PrepareForCompute(context, p)); - + auto components = (std::is_same_v) ? 
2 : 1; const auto& data_shape = p.data_tensor->Shape(); // re-shape the data tensor to [gather_M, gather_axis_dim, gather_block] // re-shape the indices tensor to [gather_N] @@ -215,7 +242,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const { // 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N], // 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :], // 4> pick the element from the block: value_i = data_blk[blk_ele_i] - const int64_t gather_block = data_shape.SizeFromDimension(SafeInt(p.gather_axis) + 1); + const int64_t gather_block = data_shape.SizeFromDimension(SafeInt(p.gather_axis) + 1) * components; const int64_t gather_axis_dim = data_shape[narrow(p.gather_axis)]; const int64_t gather_M = data_shape.SizeToDimension(narrow(p.gather_axis)); const int64_t gather_N = p.indices_tensor->Shape().Size(); @@ -229,7 +256,7 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const { // data_i % (quantize_axis_dim * quantize_N) / quantize_N, // data_i % quantize_N) // 4> get scale index: (x, y / block_size_, z) - const int64_t quantize_axis_dim = data_shape[narrow(p.quantize_axis)]; + const int64_t quantize_axis_dim = data_shape[narrow(p.quantize_axis)] * components; const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt(p.quantize_axis) + 1); concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); @@ -273,6 +300,8 @@ Status GatherBlockQuantized::Compute(OpKernelContext* context) const { .TypeConstraint("Tind", DataTypeImpl::GetTensorType()), \ GatherBlockQuantized); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t); +REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t); REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t); REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t); REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 7b4a45ce8aa0f..d87688a62040c 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h 1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`. `block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, .. 2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants. - If `zero_points` is not provided, 0 is the zero point. + If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8. 3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used to dequantize the output. 4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type. + 5. For uint8 data, the `gather_axis` must be 0. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized) @@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. 
It is similar to Gather (h .Input(2, "scales", "quantization scale", "T2") .Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional) .Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2") - .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.") + .TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.") .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.") .TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { @@ -3637,14 +3638,19 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h gather_axis = (gather_axis + r) % r; quantize_axis = (quantize_axis + r) % r; + if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) { + fail_shape_inference("gather_axis must be 0, for uint8 data"); + } + if (scales_shape.dim_size() != r) { fail_shape_inference("scales must have the same rank as data"); } + uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1; for (int i = 0; i < r; ++i) { if (!data_shape.dim(i).has_dim_value() || !scales_shape.dim(i).has_dim_value() || - (i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || + (i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) || (i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) { fail_shape_inference("data shape and scales shape do not match"); } @@ -3652,6 +3658,10 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h // validate zero point shape if (ctx.hasInput(3)) { + if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) { + fail_type_inference("zero_points are not supported for uint8_t data type"); + } + if (!hasInputShape(ctx, 3)) { fail_shape_inference("zero_points shape must be known"); } @@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape(); } for (int i = 0; i < out_rank; ++i) { + // For uint8_t data type the last dimension needs to be expanded back to actual dimension, + // because the data 2 int4s are stored packed in a single uint8_t. + auto last_dimension_components = (i == out_rank - 1) ? components : 1; *ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() = (i < gather_axis) ? data_shape.dim(i) : (i >= gather_axis && i < gather_axis + q) ? indices_shape.dim(i - gather_axis) - : data_shape.dim(i - q + 1); + : data_shape.dim(i - q + 1) * last_dimension_components; } }); diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc index c4536fc56a22f..0dfe194e893e2 100644 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -15,6 +15,27 @@ namespace onnxruntime { namespace test { +// When uint8_t data type is used GatherBlockQuantize applies MatMulNBit's conventions for storing the data. 
+// That is when no zero points are specified a default zero point of 8 is used. This convertor hence +// compensates for that by adding 8 to the data values, so that the outputs match the results that +// we be seen with non uint8_t data types. +template +void PackDataForUint8TypeIfNecessary(std::vector& data, std::vector& data_shape) { + if (!std::is_same_v) { + return; + } + // For uint8_t, we need to pack each pair of values (after adding 8) into a single uint8_t + std::vector packed_data; + for (size_t i = 0; i < data.size(); i += 2) { + int low_nibble = (data[i] + 8) & 0xF; + int high_nibble = ((i + 1) < data.size()) ? ((data[i + 1] + 8) & 0xF) : 0; + int packed = (high_nibble << 4) | low_nibble; + packed_data.push_back(packed); + } + data = packed_data; + data_shape[data_shape.size() - 1] = (data_shape[data_shape.size() - 1] + 1) / 2; +} + // Combinations: types, gather_axis, quantize_axis, block_size, indices, scale shape vs data shape template void RunGatherBlockQuantized(const std::vector& data, @@ -96,6 +117,7 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {1}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; @@ -123,7 +145,6 @@ void Test_Fail_WithZeroPoints(int64_t gather_axis, TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); - Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); @@ -134,21 +155,70 @@ TEST(GatherBlockQuantizedOpTest, UnsupportedTypes) { Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); Test_Fail_WithZeroPoints(0, 2, 16); + Test_Fail_WithZeroPoints(0, 2, 16); +} + +template +void Test_Fail_WithoutZeroPoints(int64_t gather_axis, + int64_t quantize_axis, + int64_t block_size) { + std::vector data = {-8, -7, -6, -5, + -4, -3, -2, -1, + 0, 1, 2, 3, + 4, 5, 6, 7, + 4, 5, 6, 7, + -4, -3, -2, -1}; + std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); + std::vector indices = {1}; + std::vector indices_shape = {1}; + std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; + std::vector scales_shape = {2, 3, 1}; + std::vector output = {8.f, 10.f, 12.f, 14.f, + 3.f, 4.f, 5.f, 6.f, + -6.f, -4.f, -2.f, 0.f}; + std::vector output_shape = {1, 3, 4}; + + RunGatherBlockQuantized(ToType(data), + data_shape, + ToType(indices), + indices_shape, + ToType(scales), + scales_shape, + {}, + gather_axis, + quantize_axis, + block_size, + ToType(output), + output_shape, + OpTester::ExpectResult::kExpectFailure); +} + +TEST(GatherBlockQuantizedOpTest, UnsupportedUInt8DataType) { + // T1 uint8_t with zero points is not yet supported. 
+ Test_Fail_WithZeroPoints(0, 2, 16); + Test_Fail_WithZeroPoints(0, 2, 16); + // Gather on axis other than 0 is not supported with uint8_t + Test_Fail_WithoutZeroPoints(1, 2, 16); + Test_Fail_WithoutZeroPoints(1, 2, 16); } TEST(GatherBlockQuantizedOpTest, InvalidBlockSize) { Test_Fail_WithZeroPoints(0, 2, 8); Test_Fail_WithZeroPoints(0, 2, 17); + Test_Fail_WithZeroPoints(0, 2, 17); } TEST(GatherBlockQuantizedOpTest, InvalidGatherAxis) { Test_Fail_WithZeroPoints(3, 2, 16); Test_Fail_WithZeroPoints(-4, 2, 16); + Test_Fail_WithZeroPoints(-4, 2, 16); } TEST(GatherBlockQuantizedOpTest, InvalidQuantizeAxis) { Test_Fail_WithZeroPoints(0, 3, 16); Test_Fail_WithZeroPoints(0, -4, 16); + Test_Fail_WithZeroPoints(0, -4, 16); } template @@ -160,6 +230,7 @@ void Test_ShapeMismatch_WithZeroPoints() { 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {1}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f}; @@ -188,6 +259,7 @@ void Test_ShapeMismatch_WithZeroPoints() { TEST(GatherBlockQuantizedOpTest, ShapeMismatch) { Test_ShapeMismatch_WithZeroPoints(); Test_ShapeMismatch_WithZeroPoints(); + Test_ShapeMismatch_WithZeroPoints(); } template @@ -199,6 +271,7 @@ void Test_InvalidIndices_WithZeroPoints() { 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {2}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; @@ -227,6 +300,7 @@ void Test_InvalidIndices_WithZeroPoints() { TEST(GatherBlockQuantizedOpTest, InvalidIndices) { Test_InvalidIndices_WithZeroPoints(); Test_InvalidIndices_WithZeroPoints(); + Test_InvalidIndices_WithZeroPoints(); } template @@ -298,6 +372,7 @@ void Test_GatherAxis0_NoZeroPoints() { 4, 5, 6, 7, -4, -3, -2, -1}; std::vector data_shape = {2, 3, 4}; + PackDataForUint8TypeIfNecessary(data, data_shape); std::vector indices = {1}; std::vector indices_shape = {1}; std::vector scales = {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}; @@ -340,6 +415,10 @@ TEST(GatherBlockQuantizedOpTest, GatherAxis0NoZeroPoints) { Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); + Test_GatherAxis0_NoZeroPoints(); } template From 9102aaee3db864a4b38a4f7ddca3437f1aef0efd Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Fri, 4 Apr 2025 16:46:38 -0700 Subject: [PATCH 09/18] [Native WebGPU] Add Conv, ConTranspose and FusedConv (#24186) ### Description Add Conv, ConvTranspose, and FusedConv to the WebGPU execution provider. ### Motivation and Context Required for operator coverage. 
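For context on what the new FusedConv kernel computes: when a Conv node is followed by a supported activation (e.g. Relu or HardSigmoid, as handled by the ConvActivationSelector change below), the optimizer folds the pair into a single `com.microsoft.FusedConv` node and the activation is applied to the convolution result inside the same kernel. The sketch below is only a conceptual illustration of that fused post-processing step, not the WebGPU shader code; the enum and the HardSigmoid parameter defaults are assumptions for the example.

```cpp
// Conceptual sketch (illustrative only): applying a fused activation to a Conv result
// in the same pass, instead of running a separate elementwise activation op afterwards.
#include <algorithm>
#include <vector>

enum class FusedActivation { None, Relu, HardSigmoid };

void ApplyFusedActivation(std::vector<float>& conv_output, FusedActivation act,
                          float alpha = 0.2f, float beta = 0.5f) {  // assumed HardSigmoid params
  for (float& v : conv_output) {
    switch (act) {
      case FusedActivation::Relu:
        v = std::max(v, 0.0f);                          // y = max(0, x)
        break;
      case FusedActivation::HardSigmoid:
        v = std::clamp(alpha * v + beta, 0.0f, 1.0f);   // y = clip(alpha * x + beta, 0, 1)
        break;
      case FusedActivation::None:
        break;
    }
  }
}
```

In the diff itself, the equivalent WGSL is generated by `GetActivationSnippet` and applied to the accumulated value before it is written out (see the MatMul program changes further down).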
--- onnxruntime/contrib_ops/webgpu/fused_conv.cc | 33 ++ .../webgpu/webgpu_contrib_kernels.cc | 2 +- .../core/optimizer/conv_activation_fusion.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 26 +- .../core/providers/webgpu/math/matmul.cc | 72 +++-- .../core/providers/webgpu/math/matmul.h | 12 +- .../providers/webgpu/math/matmul_packed.cc | 291 +++++++++++------- .../providers/webgpu/math/matmul_packed.h | 41 ++- .../providers/webgpu/nn/activation_util.cc | 25 ++ .../providers/webgpu/nn/activation_util.h | 15 + onnxruntime/core/providers/webgpu/nn/conv.cc | 273 ++++++++++++++++ onnxruntime/core/providers/webgpu/nn/conv.h | 34 ++ .../providers/webgpu/nn/conv2d_mm_webgpu.cc | 232 ++++++++++++++ .../providers/webgpu/nn/conv2d_mm_webgpu.h | 61 ++++ .../webgpu/nn/conv_backprop_webgpu.cc | 191 ++++++++++++ .../webgpu/nn/conv_backprop_webgpu.h | 49 +++ .../providers/webgpu/nn/conv_transpose.cc | 132 ++++++++ .../core/providers/webgpu/nn/conv_transpose.h | 27 ++ .../core/providers/webgpu/nn/conv_utils.cc | 22 ++ .../core/providers/webgpu/nn/conv_utils.h | 15 + .../core/providers/webgpu/nn/fuse_utils.cc | 79 +++++ .../core/providers/webgpu/nn/fuse_utils.h | 51 +++ .../core/providers/webgpu/nn/grouped_conv.cc | 93 ++++++ .../core/providers/webgpu/nn/grouped_conv.h | 36 +++ .../webgpu/webgpu_execution_provider.cc | 30 +- .../core/providers/webgpu/webgpu_utils.cc | 24 ++ .../core/providers/webgpu/webgpu_utils.h | 5 + onnxruntime/core/session/inference_session.cc | 7 +- .../test/contrib_ops/fused_conv_test.cc | 19 +- .../test/optimizer/graph_transform_test.cc | 19 ++ .../test/providers/cpu/nn/conv_op_test.cc | 8 +- .../cpu/nn/conv_transpose_op_test.cc | 4 +- 32 files changed, 1747 insertions(+), 183 deletions(-) create mode 100644 onnxruntime/contrib_ops/webgpu/fused_conv.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/activation_util.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/activation_util.h create mode 100644 onnxruntime/core/providers/webgpu/nn/conv.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/conv.h create mode 100644 onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h create mode 100644 onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h create mode 100644 onnxruntime/core/providers/webgpu/nn/conv_transpose.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/conv_transpose.h create mode 100644 onnxruntime/core/providers/webgpu/nn/conv_utils.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/conv_utils.h create mode 100644 onnxruntime/core/providers/webgpu/nn/fuse_utils.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/fuse_utils.h create mode 100644 onnxruntime/core/providers/webgpu/nn/grouped_conv.cc create mode 100644 onnxruntime/core/providers/webgpu/nn/grouped_conv.h create mode 100644 onnxruntime/core/providers/webgpu/webgpu_utils.cc diff --git a/onnxruntime/contrib_ops/webgpu/fused_conv.cc b/onnxruntime/contrib_ops/webgpu/fused_conv.cc new file mode 100644 index 0000000000000..e6b7ac3ec24d4 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/fused_conv.cc @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/conv.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { +using onnxruntime::webgpu::Conv; +template +class FusedConv final : public Conv { + public: + FusedConv(const OpKernelInfo& info) : Conv(info) { + ORT_ENFORCE(GetFusedActivationAttr(info, Conv::activation_).IsOK()); + } +}; + +ONNX_OPERATOR_KERNEL_EX( + FusedConv, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", onnxruntime::webgpu::WebGpuSupportedFloatTypes()), + FusedConv); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 6e63ba3a0caa4..4136477a1d88c 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -40,7 +40,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index ea9d8605e2417..71c8667a89b1d 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -121,7 +121,7 @@ class ConvActivationSelector : public NodeSelector { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } - } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { + } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider || node_ep == kWebGpuExecutionProvider) { if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 9684394da0520..eae2a464cef7e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -296,17 +296,19 @@ InlinedVector> GenerateTransformers( onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kDmlExecutionProvider}; - const InlinedHashSet cpu_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider}; - const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_eps = {onnxruntime::kCpuExecutionProvider, - onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider, - onnxruntime::kAclExecutionProvider, - onnxruntime::kArmNNExecutionProvider, - onnxruntime::kJsExecutionProvider}; + const InlinedHashSet cpu_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider, + 
onnxruntime::kWebGpuExecutionProvider}; + const InlinedHashSet cpu_cuda_rocm_acl_armnn_js_webgpu_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider, + onnxruntime::kAclExecutionProvider, + onnxruntime::kArmNNExecutionProvider, + onnxruntime::kJsExecutionProvider, + onnxruntime::kWebGpuExecutionProvider}; const InlinedHashSet cpu_dml_acl_eps = {onnxruntime::kCpuExecutionProvider, onnxruntime::kDmlExecutionProvider, onnxruntime::kAclExecutionProvider}; @@ -338,7 +340,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_dml_acl_eps)); transformers.emplace_back(std::make_unique(cpu_acl_eps)); - transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_eps)); + transformers.emplace_back(std::make_unique(cpu_rocm_acl_armnn_js_webgpu_eps)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); transformers.emplace_back(std::make_unique(cpu_acl_cuda_dml_rocm_eps, level)); diff --git a/onnxruntime/core/providers/webgpu/math/matmul.cc b/onnxruntime/core/providers/webgpu/math/matmul.cc index 9b447d5fdb59a..cdd3909874e7f 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul.cc @@ -6,8 +6,9 @@ #include "core/providers/cpu/tensor/utils.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" - +#include "core/providers/webgpu/nn/fuse_utils.h" #include "core/providers/webgpu/data_transfer.h" + namespace onnxruntime { namespace webgpu { @@ -54,11 +55,12 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { std::string process_bias; if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); - process_bias = "value += output_value_t(bias[row + i]);"; + process_bias = is_channels_last_ ? "value += output_value_t(bias[col])" : "value += output_value_t(bias[row + i]);"; } + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | - ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& batch_dims = shader.AddIndices("batch_dims"); int a_components = a.NumComponents(); @@ -90,6 +92,7 @@ Status MatMulNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { << "for (var i = 0u; i < " << output_number_ << "u; i++) {\n" << " var value = values[i];\n" << process_bias << "\n" + << apply_activation << "\n" << " let cur_indices = output_indices_t(batch, row + i, col/ " << components << ");\n" << " let offset = " << output.IndicesToOffset("cur_indices") << ";\n" << output.SetByOffset("offset", "value") @@ -127,7 +130,7 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const int64_t a_rows = a->Shape().NumDimensions() > 1 ? 
a->Shape()[a->Shape().NumDimensions() - 2] : 1; TensorShape output_shape_shader({batch_size, a_rows, helper.N() / components}); - MatMulNaiveProgram program{output_rank, output_number, has_bias}; + MatMulNaiveProgram program{Activation(), output_rank, output_number, has_bias}; program .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number)) @@ -147,11 +150,32 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { return context.RunProgram(program); } - int64_t batchA = a->Shape().SizeToDimension(a->Shape().NumDimensions() - 2); - int64_t batchB = b->Shape().SizeToDimension(b->Shape().NumDimensions() - 2); + std::vector inputs(has_bias ? 3 : 2); + inputs[0] = a; + inputs[1] = b; + if (has_bias) { + const auto* bias = context.Input(2); + inputs.push_back(bias); + } + auto program = CreateMatMulProgram(Activation(), inputs, output_tensor, false); + + return context.RunProgram(program); +} + +MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output_tensor, bool is_channels_last, + const TensorShape& input_a_reshape, + const TensorShape& input_b_reshape) { + const auto* a = inputs[0]; + const auto* b = inputs[1]; + bool has_bias = inputs.size() > 2; + TensorShape a_shape = input_a_reshape.NumDimensions() > 0 ? input_a_reshape : a->Shape(); + TensorShape b_shape = input_b_reshape.NumDimensions() > 0 ? input_b_reshape : b->Shape(); + + MatMulComputeHelper helper; + ORT_THROW_IF_ERROR(helper.Compute(a_shape, b_shape)); + int64_t batchA = a_shape.SizeToDimension(a_shape.NumDimensions() - 2); + int64_t batchB = b_shape.SizeToDimension(b_shape.NumDimensions() - 2); - TensorShape a_shape = a->Shape(); - TensorShape b_shape = b->Shape(); TensorShape output_shape = helper.OutputShape(); const int64_t dim_output_outer = output_shape[output_shape.NumDimensions() - 2]; @@ -184,9 +208,9 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { const int64_t batch_size = outer_dims.Size(); // Get dimensions for matrix multiplication from TensorShape - const int32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension - const int32_t dim_inner = narrow(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension - const int32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension + const uint32_t dim_a_outer = narrow(a_shape[a_shape.NumDimensions() - 2]); // left matrix second dimension + const uint32_t dim_inner = narrow(a_shape[a_shape.NumDimensions() - 1]); // left matrix first dimension + const uint32_t dim_b_outer = narrow(b_shape[b_shape.NumDimensions() - 1]); // right matrix first dimension const bool is_vec4 = dim_inner % 4 == 0 && dim_b_outer % 4 == 0; @@ -194,34 +218,36 @@ Status MatMul::ComputeInternal(ComputeContext& context) const { ? 
InlinedVector({4, 1, 1}) : InlinedVector({4, 4, 1}); - const uint32_t dispatch_x = narrow((dim_b_outer + MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / - (MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0])); - const uint32_t dispatch_y = narrow((dim_a_outer + MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / - (MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); - const uint32_t dispatch_z = narrow((static_cast(batch_size) + MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / - (MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); + const uint32_t dispatch_x = narrow((dim_b_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X * elements_per_thread[0])); + const uint32_t dispatch_y = narrow((dim_a_outer + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y * elements_per_thread[1])); + const uint32_t dispatch_z = narrow((static_cast(batch_size) + MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2] - 1) / + (MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z * elements_per_thread[2])); const int components = is_vec4 ? 4 : 1; const TensorShape a_shape_temp = CreateMatMulIntermediateShape(outer_dims_a, dim_a_outer, dim_inner, components); const TensorShape b_shape_temp = CreateMatMulIntermediateShape(outer_dims_b, dim_inner, dim_b_outer, components); const TensorShape output_shape_temp = TensorShape({batch_size, dim_a_outer, dim_b_outer / components}); - MatMulProgram program{has_bias, is_vec4, elements_per_thread}; + MatMulProgram program{activation, has_bias, is_vec4, elements_per_thread, is_channels_last}; program - .CacheHint(absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4)) + .CacheHint(activation.ToString(), absl::StrJoin(elements_per_thread, "-"), std::to_string(is_vec4)) .AddInputs({{a, ProgramTensorMetadataDependency::TypeAndRank, a_shape_temp, components}, {b, ProgramTensorMetadataDependency::TypeAndRank, b_shape_temp, components}}) .AddOutputs({{output_tensor, ProgramTensorMetadataDependency::Rank, output_shape_temp, components}}) .AddUniformVariables({{dim_a_outer}, {dim_b_outer}, {dim_inner}}) .AddIndices(outer_dims) .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z) - .SetWorkgroupSize(MATMUL_PACKED_WORKGROUP_SIZE_X, MATMUL_PACKED_WORKGROUP_SIZE_Y, MATMUL_PACKED_WORKGROUP_SIZE_Z); + .SetWorkgroupSize(MatMul::MATMUL_PACKED_WORKGROUP_SIZE_X, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Y, MatMul::MATMUL_PACKED_WORKGROUP_SIZE_Z); if (has_bias) { - const auto* bias = context.Input(2); - program.AddInput({bias, ProgramTensorMetadataDependency::Rank, 1}); + auto bias_components = is_channels_last ? 
components : 1; + const auto* bias = inputs[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias->Shape(), bias_components); + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, reduced_bias_shape, bias_components}); } - return context.RunProgram(program); + return program; } } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/matmul.h b/onnxruntime/core/providers/webgpu/math/matmul.h index 789e824383189..91216d8e25eec 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul.h +++ b/onnxruntime/core/providers/webgpu/math/matmul.h @@ -9,16 +9,20 @@ #include "core/providers/webgpu/math/matmul_utils.h" #include "core/providers/webgpu/math/matmul_packed.h" #include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/nn/fuse_utils.h" namespace onnxruntime { namespace webgpu { +MatMulProgram CreateMatMulProgram(const Activation& activation, std::vector& inputs, Tensor* output, bool is_channels_last, + const TensorShape& input_a_reshape = TensorShape(), + const TensorShape& input_b_reshape = TensorShape()); + class MatMul final : public WebGpuKernel { public: MatMul(const OpKernelInfo& info) : WebGpuKernel{info} {} Status ComputeInternal(ComputeContext& context) const override; - constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_X = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Y = 8; constexpr static uint32_t MATMUL_PACKED_WORKGROUP_SIZE_Z = 1; @@ -26,8 +30,8 @@ class MatMul final : public WebGpuKernel { class MatMulNaiveProgram final : public Program { public: - MatMulNaiveProgram(const size_t output_rank, int64_t output_number, bool has_bias) - : Program{"MatMulNaive"}, output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias} { + MatMulNaiveProgram(const Activation& activation, const size_t output_rank, int64_t output_number, bool has_bias, bool is_channels_last = false) + : Program{"MatMulNaive"}, activation_(activation), output_rank_(output_rank), output_number_(output_number), has_bias_{has_bias}, is_channels_last_(is_channels_last) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -38,9 +42,11 @@ class MatMulNaiveProgram final : public Program { {"K", ProgramUniformVariableDataType::Uint32}); private: + const Activation& activation_; const size_t output_rank_; const int64_t output_number_; const bool has_bias_; + const bool is_channels_last_; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc index 2e5cff923f442..36510eec0cd3b 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.cc +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.cc @@ -5,7 +5,7 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/webgpu_utils.h" - +#include namespace onnxruntime { namespace webgpu { @@ -13,7 +13,8 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, - const ShaderIndicesHelper& batch_dims) const { + const ShaderIndicesHelper& batch_dims, + std::string activation_snippet) const { int components = is_vec4_ ? 
4 : 1; const std::string data_type = "a_element_t"; const std::string type_string = MakeScalarOrVectorType(components, data_type); @@ -23,7 +24,7 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, << "fn mm_readA(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n" << " var value = " << type_string << "(0.0);\n" << " let col = colIn * " << components << ";\n" - << " if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) {\n" + << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n" << " var a_indices: a_indices_t;\n" << ConvertOutputBatchIndicesToInputBatchIndices("a", a, a.Rank() - 2, batch_dims.Rank(), "batch_indices") << a.IndicesSet("a_indices", a.Rank() - 2, "u32(row)") << "\n" @@ -38,7 +39,7 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, << "fn mm_readB(batch: i32, row: i32, colIn: i32, batch_indices: batch_dims_indices_t) -> " << type_string << " {\n" << " var value = " << type_string << "(0.0);\n" << " let col = colIn * " << components << ";\n" - << " if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) {\n" + << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" << " var b_indices: b_indices_t;\n" << ConvertOutputBatchIndicesToInputBatchIndices("b", b, b.Rank() - 2, batch_dims.Rank(), "batch_indices") << b.IndicesSet("b_indices", b.Rank() - 2, "u32(row)") << "\n" @@ -52,14 +53,16 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, shader.AdditionalImplementation() << "fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: " << type_string << ") {\n" << " let col = colIn * " << components << ";\n" - << " if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) {\n" + << " if (row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n" << " var value = valueIn;\n" << " let coords = vec3(batch, row, colIn);\n"; if (has_bias_) { - shader.AdditionalImplementation() << " value = value + " << type_string << "(bias[row]);\n"; + shader.AdditionalImplementation() << " value = value + " << (is_channels_last_ ? "bias[colIn]" : type_string + "(bias[row])") << ";\n"; } + shader.AdditionalImplementation() << " " << activation_snippet << "\n"; + shader.AdditionalImplementation() << output.SetByIndices("vec3(coords)", "value") << "\n" << " }\n" @@ -67,29 +70,36 @@ void MatMulProgram::MatMulReadWriteFnSource(ShaderHelper& shader, } Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, - const ShaderIndicesHelper& batch_dims, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y) { + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a, + uint32_t tile_inner, + bool split_k, + uint32_t splitted_dim_inner) { + ORT_UNUSED_PARAMETER(split_k); + ORT_UNUSED_PARAMETER(splitted_dim_inner); + std::string write_data_to_sub_a_vec4_snippet = + transpose_a ? std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, kStart + inputRow, globalRowStart / innerElementSize + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n" + : std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol") + (batch_dims ? 
", batchIndices" : "") + ");\n"; // elements per thread const auto elements_per_thread_x = elements_per_thread[0]; const auto elements_per_thread_y = elements_per_thread[1]; - const decltype(elements_per_thread_x) tile_inner = 32; const auto tile_a_outer = workgroup_size_y * elements_per_thread_y; const auto tile_b_outer = workgroup_size_x * elements_per_thread_x; - const auto tile_a_width = tile_inner; - - const auto tile_a_height = tile_a_outer; + const auto tile_a_width = transpose_a ? tile_a_outer : tile_inner; + const auto tile_a_height = transpose_a ? tile_inner : tile_a_outer; const auto inner_elements_size = tile_a_width / workgroup_size_x; const auto row_per_thread_b = tile_inner / workgroup_size_y; - const std::string data_type = "a_element_t"; - - if (!((inner_elements_size == 3 || inner_elements_size == 4) && - tile_a_width % workgroup_size_x == 0 && - tile_inner % workgroup_size_y == 0 && - elements_per_thread_x == 4)) { + if (!((transpose_a && inner_elements_size == 4 && elements_per_thread[1] == 4) || + (!transpose_a && (inner_elements_size == 3 || inner_elements_size == 4))) && + tile_a_width % workgroup_size_x == 0 && + tile_inner % workgroup_size_y == 0 && + elements_per_thread_x == 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid matrix multiplication configuration inner_elements_size: ", inner_elements_size, " must be 3 or 4. tile_a_width: ", tile_a_width, " must be divisible by WorkgroupSizeX: ", @@ -112,7 +122,7 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, << " let globalRow = i32(global_id.y) * rowPerThread;\n" << " let globalCol = i32(global_id.x);\n" << " let batch = i32(global_id.z);\n" - << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" @@ -121,14 +131,14 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, // Loop over shared dimension. shader.MainFunctionBody() << " let tileRowB = localRow * " << row_per_thread_b << ";\n" - << " for (var t = 0; t < num_tiles; t = t + 1) {\n"; + << " for (var t = 0; t < i32(num_tiles); t = t + 1) {\n"; // Load one tile of A into local memory. shader.MainFunctionBody() << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" << " let inputRow = tileRow + innerRow;\n" << " let inputCol = tileCol;\n" - << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRow + innerRow, kStart / innerElementSize + inputCol, batchIndices);\n" + << " " << write_data_to_sub_a_vec4_snippet << " }\n"; // Load one tile of B into local memory. @@ -136,7 +146,7 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, << " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n" << " let inputRow = tileRowB + innerRow;\n" << " let inputCol = tileCol;\n" - << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol, batchIndices);\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol" << (nullptr != batch_dims ? 
", batchIndices" : "") << ");\n" << " }\n" << " kStart = kStart + tileInner;\n" << " workgroupBarrier();\n"; @@ -152,15 +162,29 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, shader.MainFunctionBody() << " let BCached3 = mm_Bsub[k * innerElementSize + 3][tileCol];\n"; } - shader.MainFunctionBody() - << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" - << " let ACached = mm_Asub[tileRow + i][k];\n" - << " acc[i] = BCached0 * ACached.x + acc[i];\n" - << " acc[i] = BCached1 * ACached.y + acc[i];\n" - << " acc[i] = BCached2 * ACached.z + acc[i];\n" - << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n" - << " }\n"; - + if (transpose_a) { + shader.MainFunctionBody() + << " let Acached0 = mm_Asub[k * innerElementSize][localRow];\n" + << " let Acached1 = mm_Asub[k * innerElementSize + 1][localRow];\n" + << " let Acached2 = mm_Asub[k * innerElementSize + 2][localRow];\n" + << (inner_elements_size == 3 ? "" : " let Acached3 = mm_Asub[k * innerElementSize + 3][localRow];\n") + << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" + << " let ACached = mm_Asub[tileCol][i];\n" + << " acc[i] = BCached0 * ACached0[i] + acc[i];\n" + << " acc[i] = BCached1 * ACached1[i] + acc[i];\n" + << " acc[i] = BCached2 * ACached2[i] + acc[i];\n" + << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached3[i] + acc[i];") << "\n" + << " }\n"; + } else { + shader.MainFunctionBody() + << " for (var i = 0; i < rowPerThread; i = i + 1) {\n" + << " let ACached = mm_Asub[tileRow + i][k];\n" + << " acc[i] = BCached0 * ACached.x + acc[i];\n" + << " acc[i] = BCached1 * ACached.y + acc[i];\n" + << " acc[i] = BCached2 * ACached.z + acc[i];\n" + << " " << (inner_elements_size == 3 ? "" : "acc[i] = BCached3 * ACached.w + acc[i];") << "\n" + << " }\n"; + } shader.MainFunctionBody() << " workgroupBarrier();\n" << " }\n"; // main for loop @@ -174,13 +198,22 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader, return Status::OK(); } -Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderIndicesHelper& batch_dims, +Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y) { + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a, + uint32_t tile_inner, + bool split_k, + uint32_t splitted_dim_inner, + bool sequentially_access_by_threads) { + ORT_UNUSED_PARAMETER(split_k); + ORT_UNUSED_PARAMETER(splitted_dim_inner); + const auto elements_per_thread_x = elements_per_thread[0]; const auto elements_per_thread_y = elements_per_thread[1]; - const decltype(elements_per_thread_x) tile_inner = 32; const auto tile_a_outer = workgroup_size_y * elements_per_thread_y; const auto tile_b_outer = workgroup_size_x * elements_per_thread_x; @@ -194,12 +227,11 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderI ", tile_inner: ", tile_inner, " must be divisible by WorkgroupSizeY: ", workgroup_size_y); } - const std::string data_type = "a_element_t"; - const auto row_per_thread_a = tile_a_height / workgroup_size_y; const auto col_per_thread_a = tile_a_width / workgroup_size_x; const auto row_per_thread_b = tile_inner / workgroup_size_y; - + std::string write_data_to_sub_a_snippet = transpose_a ? std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, kStart + inputRow, globalRowStart + inputCol") + (batch_dims ? 
", batchIndices" : "") + ");\n" + : std::string("mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol") + (batch_dims ? ", batchIndices" : "") + ");\n"; shader.AdditionalImplementation() << "var mm_Asub: array, " << tile_a_height << ">;\n" << "var mm_Bsub: array, " << tile_inner << ">;\n" @@ -208,93 +240,142 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader, const ShaderI << "const tileInner = " << tile_inner << ";\n"; shader.MainFunctionBody() << " let batch = i32(global_id.z);\n" - << " let batchIndices = " << batch_dims.OffsetToIndices("u32(batch)") << ";\n" + << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices("u32(batch)") + ";\n" : "") << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n" << " var kStart = 0;\n" << " var acc: array, rowPerThread>;\n"; - shader.MainFunctionBody() - << "let tileRow = i32(local_id.y) * rowPerThread;\n" - << "let tileCol = i32(local_id.x) * colPerThread;\n" - << "let globalRow = i32(global_id.y) * rowPerThread;\n" - << "let globalCol = i32(global_id.x) * colPerThread;\n" - << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" - << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" - << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" - << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; - - // Loop over shared dimension. - shader.MainFunctionBody() - << "for (var t = 0; t < num_tiles; t = t + 1) {\n"; - - // Load one tile of A into local memory. - shader.MainFunctionBody() - << " for (var innerRow = 0; innerRow < " << row_per_thread_a << "; innerRow = innerRow + 1) {\n" - << " for (var innerCol = 0; innerCol < " << col_per_thread_a << "; innerCol = innerCol + 1) {\n" - << " let inputRow = tileRowA + innerRow;\n" - << " let inputCol = tileColA + innerCol;\n" - << " mm_Asub[inputRow][inputCol] = mm_readA(batch, globalRowStart + inputRow, kStart + inputCol, batchIndices);\n" - << " }\n" - << " }\n"; - - // Load one tile of B into local memory. - shader.MainFunctionBody() - << " for (var innerRow = 0; innerRow < " << row_per_thread_b << "; innerRow = innerRow + 1) {\n" - << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" - << " let inputRow = tileRowB + innerRow;\n" - << " let inputCol = tileCol + innerCol;\n" - << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol, batchIndices);\n" - << " }\n" - << " }\n" - << " kStart = kStart + tileInner;\n" - << " workgroupBarrier();\n"; - - // Compute acc values for a single thread. 
- shader.MainFunctionBody() - << "var BCached: array<" << data_type << ", colPerThread>;\n" - << " for (var k = 0; k < tileInner; k = k + 1) {\n" - << " for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n" - << " BCached[inner] = mm_Bsub[k][tileCol + inner];\n" - << " }\n" - << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" - << " let ACached = mm_Asub[tileRow + innerRow][k];\n" - << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" - << " acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n" - << " }\n" - << " }\n" - << " }\n" - << " workgroupBarrier();\n" - << "}\n"; - - // Write the results to the output buffer - shader.MainFunctionBody() - << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" - << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" - << " mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n" - << " }\n" - << "}\n"; - + if (sequentially_access_by_threads) { + shader.MainFunctionBody() << "let localRow = i32(local_id.y);\n" + << "let localCol = i32(local_id.x);\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let globalColStart = i32(workgroup_id.x) * " << tile_b_outer << ";\n" + << "\n" + << "// Loop over shared dimension.\n" + << "for (var t = 0; t < i32(num_tiles); t = t + 1) {\n" + << " // Load one tile of A into local memory.\n" + << " for (var inputRow = localRow; inputRow < " << tile_a_height << "; inputRow = inputRow + " << workgroup_size_y << ") {\n" + << " for (var inputCol = localCol; inputCol < " << tile_a_width << "; inputCol = inputCol + " << workgroup_size_x << ") {\n" + << " " << write_data_to_sub_a_snippet << "\n" + << " }\n" + << " }\n" + << " // Load one tile of B into local memory.\n" + << " for (var inputRow = localRow; inputRow < " << tile_inner << "; inputRow = inputRow + " << workgroup_size_y << ") {\n" + << " for (var inputCol = localCol; inputCol < " << tile_b_outer << "; inputCol = inputCol + " << workgroup_size_x << ") {\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch,\n" + << " kStart + inputRow,\n" + << " globalColStart + inputCol" << (batch_dims ? ", batchIndices" : "") << ");\n " + << " }\n" + << " }\n" + << " kStart = kStart + tileInner;\n" + << " workgroupBarrier();\n" + << "\n" + << " // Compute acc values for a single thread.\n" + << " var BCached : array<" << data_type << ", colPerThread>;\n" + << " for (var k = 0; k < tileInner; k = k + 1) {\n" + << " for (var inner = 0; inner < colPerThread; inner = inner + 1) {\n" + << " BCached[inner] = mm_Bsub[k][localCol + inner * " << workgroup_size_x << "];\n" + << " }\n" + << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" + << " let ACached = " << (transpose_a ? 
"mm_Asub[k][localCol + innerRow * " + std::to_string(workgroup_size_y) + "];" : "mm_Asub[localRow + innerRow * " + std::to_string(workgroup_size_y) + "][k];") << "\n" + << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" + << " acc[innerRow][innerCol] = acc[innerRow][innerCol] +\n" + << " ACached * BCached[innerCol];\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n" + << "for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n" + << " let gRow = globalRowStart + localRow + innerRow * " << workgroup_size_y << ";\n" + << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n" + << " let gCol = globalColStart + localCol + innerCol * " << workgroup_size_x << ";\n" + << " mm_write(batch, gRow, gCol, acc[innerRow][innerCol]);\n" + << " }\n" + << "}\n"; + } else { + shader.MainFunctionBody() + << "let tileRow = i32(local_id.y) * rowPerThread;\n" + << "let tileCol = i32(local_id.x) * colPerThread;\n" + << "let globalRow = i32(global_id.y) * rowPerThread;\n" + << "let globalCol = i32(global_id.x) * colPerThread;\n" + << "let globalRowStart = i32(workgroup_id.y) * " << tile_a_outer << ";\n" + << "let tileRowA = i32(local_id.y) * " << row_per_thread_a << ";\n" + << "let tileColA = i32(local_id.x) * " << col_per_thread_a << ";\n" + << "let tileRowB = i32(local_id.y) * " << row_per_thread_b << ";\n"; + + // Loop over shared dimension. + shader.MainFunctionBody() + << "for (var t = 0; t < i32(num_tiles); t = t + 1) {\n"; + + // Load one tile of A into local memory. + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < i32(" << row_per_thread_a << "); innerRow = innerRow + 1) {\n" + << " for (var innerCol = 0; innerCol < i32(" << col_per_thread_a << "); innerCol = innerCol + 1) {\n" + << " let inputRow = tileRowA + innerRow;\n" + << " let inputCol = tileColA + innerCol;\n" + << " " << write_data_to_sub_a_snippet << "\n" + << " }\n" + << " }\n"; + + // Load one tile of B into local memory. + shader.MainFunctionBody() + << " for (var innerRow = 0; innerRow < i32(" << row_per_thread_b << "); innerRow = innerRow + 1) {\n" + << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n" + << " let inputRow = tileRowB + innerRow;\n" + << " let inputCol = tileCol + innerCol;\n" + << " mm_Bsub[inputRow][inputCol] = mm_readB(batch, kStart + inputRow, globalCol + innerCol" << (nullptr != batch_dims ? ", batchIndices" : "") << ");\n" + << " }\n" + << " }\n" + << " kStart = kStart + tileInner;\n" + << " workgroupBarrier();\n"; + + // Compute acc values for a single thread. 
+ shader.MainFunctionBody() + << "var BCached: array<" << data_type << ", colPerThread>;\n" + << " for (var k = 0; k < tileInner; k = k + 1) {\n" + << " for (var inner = 0; inner < i32(colPerThread); inner = inner + 1) {\n" + << " BCached[inner] = mm_Bsub[k][tileCol + inner];\n" + << " }\n" + << " for (var innerRow = 0; innerRow < i32(rowPerThread); innerRow = innerRow + 1) {\n" + << " let ACached = mm_Asub[tileRow + innerRow][k];\n" + << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n" + << " acc[innerRow][innerCol] = acc[innerRow][innerCol] + ACached * BCached[innerCol];\n" + << " }\n" + << " }\n" + << " }\n" + << " workgroupBarrier();\n" + << "}\n"; + + // Write the results to the output buffer + shader.MainFunctionBody() + << "for (var innerRow = 0; innerRow < i32(rowPerThread); innerRow = innerRow + 1) {\n" + << " for (var innerCol = 0; innerCol < i32(colPerThread); innerCol = innerCol + 1) {\n" + << " mm_write(batch, globalRow + innerRow, globalCol + innerCol, acc[innerRow][innerCol]);\n" + << " }\n" + << "}\n"; + } return Status::OK(); } Status MatMulProgram::GenerateShaderCode(ShaderHelper& shader) const { const auto& a = shader.AddInput("a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); - const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); const auto& batch_dims = shader.AddIndices("batch_dims", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); if (has_bias_) { shader.AddInput("bias", ShaderUsage::UseUniform); } - + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); // declare the read and write functions - MatMulReadWriteFnSource(shader, a, b, output, batch_dims); - + MatMulReadWriteFnSource(shader, a, b, output, batch_dims, apply_activation); + std::string data_type = "a_element_t"; // generate the main function if (is_vec4_) { - ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY())); + ORT_RETURN_IF_ERROR(MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } else { - ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, batch_dims, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY())); + ORT_RETURN_IF_ERROR(MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, &batch_dims)); } return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/math/matmul_packed.h b/onnxruntime/core/providers/webgpu/math/matmul_packed.h index ea76468944066..d3a68ff8a57fa 100644 --- a/onnxruntime/core/providers/webgpu/math/matmul_packed.h +++ b/onnxruntime/core/providers/webgpu/math/matmul_packed.h @@ -7,38 +7,53 @@ #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/math/matmul_utils.h" +#include "core/providers/webgpu/nn/fuse_utils.h" namespace onnxruntime { namespace webgpu { class MatMulProgram final : public Program { public: - MatMulProgram(bool bias, bool 
is_vec4, const gsl::span& elements_per_thread) : Program{"MatMul"}, - has_bias_{bias}, - is_vec4_{is_vec4}, - elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()) {} + MatMulProgram(const Activation& activation, bool bias, bool is_vec4, const gsl::span& elements_per_thread, bool is_channels_last = false) : Program{"MatMul"}, + activation_(activation), + has_bias_{bias}, + is_vec4_{is_vec4}, + elements_per_thread_(elements_per_thread.begin(), elements_per_thread.end()), + is_channels_last_(is_channels_last) {} Status GenerateShaderCode(ShaderHelper& sh) const override; - WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Int32}, - {"dim_b_outer", ProgramUniformVariableDataType::Int32}, - {"dim_inner", ProgramUniformVariableDataType::Int32}); + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}); static Status MakeMatMulPackedVec4Source(ShaderHelper& shader, - const ShaderIndicesHelper& batch_dims, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y); + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a = false, + uint32_t tile_inner = 32, + bool split_k = false, + uint32_t splitted_dim_inner = 32); static Status MakeMatMulPackedSource(ShaderHelper& shader, - const ShaderIndicesHelper& batch_dims, const InlinedVector& elements_per_thread, uint32_t workgroup_size_x, - uint32_t workgroup_size_y); + uint32_t workgroup_size_y, + const std::string& data_type, + const ShaderIndicesHelper* batch_dims, + bool transpose_a = false, + uint32_t tile_inner = 32, + bool split_k = false, + uint32_t splitted_dim_inner = 32, + bool sequentially_access_by_threads = false); private: + const Activation& activation_; const bool has_bias_; const bool is_vec4_; const InlinedVector elements_per_thread_; - - void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, const ShaderIndicesHelper& batch_dims) const; + bool is_channels_last_ = false; + void MatMulReadWriteFnSource(ShaderHelper& shader, const ShaderVariableHelper& a, const ShaderVariableHelper& b, const ShaderVariableHelper& output, const ShaderIndicesHelper& batch_dims, std::string apply_activation) const; }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/nn/activation_util.cc b/onnxruntime/core/providers/webgpu/nn/activation_util.cc new file mode 100644 index 0000000000000..b5c31d98cda93 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/activation_util.cc @@ -0,0 +1,25 @@ +#include "core/providers/webgpu/nn/activation_util.h" +#include "core/common/common.h" +namespace onnxruntime { +namespace webgpu { +std::string TypeSnippet(uint32_t component, std::string data_type) { + switch (component) { + case 1: + return data_type; + case 2: + return "vec2<" + data_type + ">"; + case 3: + return "vec3<" + data_type + ">"; + case 4: + return "vec4<" + data_type + ">"; + default: + ORT_THROW("Component ", component, " is not supported."); + } +} + +std::string BiasSnippet(bool has_bias) { + return has_bias ? 
"value = value + getBiasByOutputCoords(coords);" : ""; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/activation_util.h b/onnxruntime/core/providers/webgpu/nn/activation_util.h new file mode 100644 index 0000000000000..1c9fd93e35384 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/activation_util.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace webgpu { + +extern std::string TypeSnippet(uint32_t component, std::string data_type); +extern std::string BiasSnippet(bool has_bias); + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc new file mode 100644 index 0000000000000..0edad3eebe2ea --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -0,0 +1,273 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/nn/conv.h" +#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/webgpu/nn/grouped_conv.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/math/matmul.h" +namespace onnxruntime { +namespace webgpu { + +Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm) { + // Transpose weights + auto rank = kernel_shape.NumDimensions(); + TensorShapeVector transposed_kernel_shape_vector(rank); + for (size_t i = 0; i < rank; ++i) { + transposed_kernel_shape_vector[i] = kernel_shape[perm[i]]; + } + uint32_t output_size = onnxruntime::narrow(kernel_shape.Size()); + TensorShape transposed_kernel_shape(transposed_kernel_shape_vector); + *transposed_kernel = context.CreateGPUTensor(kernel->DataType(), transposed_kernel_shape); + bool use_shared = false; + TransposeProgram program{perm, use_shared}; + program + .CacheHint(absl::StrJoin(perm, "-")) + .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1}) + .AddOutput({transposed_kernel, ProgramTensorMetadataDependency::TypeAndRank}) + .AddUniformVariable({output_size}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + return context.RunProgram(program); +} + +template +Status Conv::ComputeInternal(ComputeContext& context) const { + bool has_bias = context.InputCount() > 2; + const auto* input = context.Input(0); + const auto* kernel = context.Input(1); + const auto* bias = has_bias ? 
context.Input(2) : nullptr; + TensorShape input_shape = input->Shape(); + TensorShape kernel_shape = kernel->Shape(); + ConvAttributes::ConvPadVector local_pads(conv_attrs_.pads.begin(), conv_attrs_.pads.end()); + TensorShapeVector local_dilations(conv_attrs_.dilations.begin(), conv_attrs_.dilations.end()); + TensorShapeVector local_strides(conv_attrs_.strides.begin(), conv_attrs_.strides.end()); + TensorShapeVector kernel_spacial_shape_vector; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(kernel_shape, kernel_spacial_shape_vector, false)); + if (local_pads.empty()) { + local_pads.resize(kernel_spacial_shape_vector.size() * 2, 0); + } + if (local_dilations.empty()) { + local_dilations.resize(kernel_spacial_shape_vector.size(), 1); + } + if (local_strides.empty()) { + local_strides.resize(kernel_spacial_shape_vector.size(), 1); + } + TensorShapeVector input_shape_vector = input_shape.AsShapeVector(); + auto batch = input_shape[0]; + TensorShapeVector output_shape_vector = {batch}; + TensorShape input_spacial_shape = is_channels_last ? TensorShape(TensorShapeVector(std::next(input_shape_vector.begin()), std::prev(input_shape_vector.end()))) : input_shape.Slice(2); + ORT_RETURN_IF_ERROR(conv_attrs_.InferPadsAndOutputShape(input_spacial_shape, kernel_spacial_shape_vector, local_strides, local_dilations, local_pads, output_shape_vector)); + auto output_channels = kernel_shape[0]; + if (is_channels_last) { + output_shape_vector.push_back(output_channels); + } else { + output_shape_vector.insert(output_shape_vector.begin() + 1, output_channels); + } + auto output_shape = TensorShape(output_shape_vector); + auto* output = context.Output(0, output_shape); + std::vector strides; + std::vector pads; + std::vector dilations; + auto transform_dim = [](int64_t dim) { return static_cast(dim); }; + std::transform(local_pads.begin(), local_pads.end(), std::back_inserter(pads), transform_dim); + std::transform(local_strides.begin(), local_strides.end(), std::back_inserter(strides), transform_dim); + std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); + auto rank = input_shape.NumDimensions(); + const InlinedVector perm = {2, 3, 1, 0}; + if (rank > 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv1d and Conv2d are supported."); + } else if (rank == 4) { + // Conv2D + } else if (rank == 3) { + // Conv1D + TensorShapeVector kernel_shape_vector = kernel_shape.AsShapeVector(); + input_shape_vector.insert(input_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1); + output_shape_vector.insert(output_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1); + kernel_shape_vector.insert(kernel_shape_vector.begin() + 2, 1); + input_shape = TensorShape(input_shape_vector); + kernel_shape = TensorShape(kernel_shape_vector); + pads.insert(pads.begin(), 0); + pads.insert(pads.begin() + 2, 0); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input and kernel tensors must have at least 3 dimensions"); + } + std::vector inputs(has_bias ? 3 : 2); + inputs[0] = input; + inputs[1] = kernel; + if (has_bias) { + inputs[2] = bias; + } + std::vector modified_input_output_shapes = {input_shape, kernel_shape}; + if (has_bias) { + modified_input_output_shapes.push_back(bias->Shape()); + } + modified_input_output_shapes.push_back(TensorShape(output_shape_vector)); + uint32_t auto_pad_adjust = conv_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 
1 : 0; + auto pad0 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2; + auto pad1 = conv_attrs_.auto_pad == AutoPadType::NOTSET ? pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; + std::vector updated_pads{pad0, pad1}; + if (conv_attrs_.group > 1) { + Tensor transposed_kernel; + if (is_channels_last) { + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + inputs[1] = &transposed_kernel; + modified_input_output_shapes[1] = transposed_kernel.Shape(); + } + auto output_channels_per_group = output_channels / conv_attrs_.group; + auto components = static_cast(is_channels_last && output_channels_per_group >= 4 ? GetMaxComponents(output_channels) : 1); + auto output_size = output_shape.Size() / components; + GroupedConvProgram program(activation_, has_bias, is_channels_last); + auto reduced_kernel_shape = ReduceShapeByComponents(modified_input_output_shapes[1], components); + auto reduced_output_shape = ReduceShapeByComponents(modified_input_output_shapes[has_bias ? 3 : 2], components); + program.CacheHint(activation_.ToString(), std::to_string(components), std::to_string(is_channels_last)) + .AddInput({inputs[0], ProgramTensorMetadataDependency::TypeAndRank, modified_input_output_shapes[0], 1}) + .AddInput({inputs[1], ProgramTensorMetadataDependency::TypeAndRank, reduced_kernel_shape, components}) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, reduced_output_shape, components}) + .AddUniformVariables({{static_cast(output_size)}, {dilations}, {strides}, {updated_pads}, {static_cast(output_channels_per_group)}, {static_cast(components)}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + if (has_bias) { + auto reduced_bias_shape = ReduceShapeByComponents(modified_input_output_shapes[2], components); + program.AddInput({inputs[2], ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components}); + } + return context.RunProgram(program); + } + const auto input_height = input_shape[is_channels_last ? 1 : 2]; + const auto input_width = input_shape[is_channels_last ? 2 : 3]; + const auto input_channels = input_shape[is_channels_last ? 3 : 1]; + const auto kernel_height = kernel_shape[2]; + const auto kernel_width = kernel_shape[3]; + const auto output_height = output_shape_vector[is_channels_last ? 1 : 2]; + const auto output_width = output_shape_vector[is_channels_last ? 
2 : 3]; + + const auto same_size = is_channels_last && input_height == kernel_height && input_width == kernel_width && pads[0] == 0 && pads[1] == 0; + if (same_size || (kernel_height == 1 && kernel_width == 1 && pads[0] == 0 && pads[1] == 0 && strides[0] == 1 && strides[1] == 1)) { + Tensor transposed_kernel; + TensorShape input_reshape; + TensorShape kernel_reshape; + TensorShape matmul_output_shape; + std::vector matmul_inputs; + std::vector matmul_input_reshapes; + if (is_channels_last) { + // Transpose weights + + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + inputs[1] = &transposed_kernel; + if (same_size) { + const auto shared_dim = input_height * input_width * input_channels; + input_reshape = TensorShape({1, batch, shared_dim}); + kernel_reshape = TensorShape({1, shared_dim, output_channels}); + matmul_output_shape = TensorShape({1, batch, output_channels}); + } else { + input_reshape = TensorShape({batch, input_height * input_width, input_channels}); + kernel_reshape = TensorShape({1, input_channels, output_channels}); + matmul_output_shape = TensorShape({batch, output_height * output_width, output_channels}); + } + matmul_inputs.push_back(input); + matmul_inputs.push_back(&transposed_kernel); + matmul_input_reshapes.push_back(input_reshape); + matmul_input_reshapes.push_back(kernel_reshape); + } else { + input_reshape = TensorShape({batch, input_channels, input_height * input_width}); + kernel_reshape = TensorShape({1, output_channels, input_channels}); + matmul_output_shape = TensorShape({batch, output_channels, output_height * output_width}); + matmul_inputs.push_back(kernel); + matmul_inputs.push_back(input); + matmul_input_reshapes.push_back(kernel_reshape); + matmul_input_reshapes.push_back(input_reshape); + } + if (has_bias) { + matmul_inputs.push_back(bias); + } + auto N = matmul_output_shape[2]; + auto matmul_first_input_numdims = matmul_input_reshapes[0].NumDimensions(); + auto K = matmul_input_reshapes[0].GetDims()[matmul_first_input_numdims - 1]; + if (N < 8 && K < 8) { + const auto components = GetMaxComponents(N); + const auto a_components = GetMaxComponents(K); + const auto output_number = GetMaxComponents(output_shape[1]); + uint32_t output_size = static_cast(output_shape.Size() / components / output_number); + const size_t output_rank = matmul_output_shape.NumDimensions(); + TensorShape outer_dims = output_rank > 2 ? 
matmul_output_shape.Slice(0, output_rank - 2) : TensorShape({}); + MatMulNaiveProgram program(activation_, output_rank, output_number, has_bias); + program + .CacheHint(std::to_string(components), std::to_string(a_components), std::to_string(output_number)) + .AddInputs({{matmul_inputs[0], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[0], a_components), int(a_components)}, + {matmul_inputs[1], ProgramTensorMetadataDependency::TypeAndRank, ReduceShapeByComponents(matmul_input_reshapes[1], components), int(components)}}); + if (has_bias) { + program.AddInput({bias, ProgramTensorMetadataDependency::Rank, bias->Shape(), components}); + } + program + .AddOutputs({{output, ProgramTensorMetadataDependency::None, ReduceShapeByComponents(matmul_output_shape, components), int(components)}}) + .SetDispatchGroupSize(static_cast((output_size + 63) / 64)) + .AddIndices(outer_dims) + .AddUniformVariables({{output_size}, {static_cast(matmul_output_shape[1])}, {static_cast(matmul_output_shape[2])}, {static_cast(K)}}); + return context.RunProgram(program); + } else { + MatMulProgram program = CreateMatMulProgram(activation_, matmul_inputs, output, is_channels_last, matmul_input_reshapes[0], matmul_input_reshapes[1]); + return context.RunProgram(program); + } + } + const bool sequentially_access_by_threads = true; + // Transpose weights + Tensor transposed_kernel; + ORT_RETURN_IF_ERROR(TransposeKernel(context, kernel, kernel_shape, &transposed_kernel, perm)); + auto dim_a_outer = static_cast(is_channels_last ? output_height * output_width : output_channels); + auto dim_b_outer = static_cast(is_channels_last ? output_channels : output_height * output_width); + auto dim_inner = static_cast(kernel_height * kernel_width * input_channels); + inputs[1] = &transposed_kernel; + TensorShape transposed_kernel_shape = transposed_kernel.Shape(); + modified_input_output_shapes[1] = transposed_kernel.Shape(); + Conv2dMMProgram conv2d_mm_program = CreateConv2dMMProgram(activation_, inputs, pads, strides, dilations, output, dim_a_outer, dim_b_outer, dim_inner, is_channels_last, sequentially_access_by_threads, modified_input_output_shapes); + return context.RunProgram(conv2d_mm_program); +} + +// Explicit template instantiation for FusedConv +template class Conv; +template class Conv; +template class Conv; +template class Conv; + +#define WEBGPU_ONNX_CONV_OPERATOR_KERNEL(VERSION_FROM) \ + ONNX_OPERATOR_KERNEL_EX( \ + Conv, \ + kMSInternalNHWCDomain, \ + VERSION_FROM, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); \ + \ + ONNX_OPERATOR_KERNEL_EX( \ + Conv, \ + kOnnxDomain, \ + VERSION_FROM, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); + +#define WEBGPU_ONNX_CONV_OPERATOR_VERSIONED_KERNEL(VERSION_FROM, VERSION_TO) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Conv, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); \ + \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Conv, \ + kMSInternalNHWCDomain, \ + VERSION_FROM, VERSION_TO, \ + kWebGpuExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), \ + Conv); + +WEBGPU_ONNX_CONV_OPERATOR_VERSIONED_KERNEL(1, 10) +WEBGPU_ONNX_CONV_OPERATOR_VERSIONED_KERNEL(11, 21) +WEBGPU_ONNX_CONV_OPERATOR_KERNEL(22) + +} // namespace webgpu 
+} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv.h b/onnxruntime/core/providers/webgpu/nn/conv.h new file mode 100644 index 0000000000000..cafaa272c0613 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/optional.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +template +class Conv : public WebGpuKernel { + public: + Conv(const OpKernelInfo& info) : WebGpuKernel(info), conv_attrs_(info) { + if (is_fused) { + ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); + } + } + Status ComputeInternal(ComputeContext& context) const override; + + protected: + ConvAttributes conv_attrs_; + Activation activation_; +}; + +Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const TensorShape& kernel_shape, Tensor* transposed_kernel, const InlinedVector& perm); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc new file mode 100644 index 0000000000000..24e49304cf532 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc @@ -0,0 +1,232 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include +#include +#include +#include +#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/nn/activation_util.h" +#include "core/providers/webgpu/math/matmul_packed.h" +#include "core/providers/webgpu/nn/conv_utils.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/webgpu_utils.h" + +namespace onnxruntime { +namespace webgpu { +std::string Conv2dMMProgram::Conv2dCommonSnippet(const ShaderVariableHelper& x, const ShaderVariableHelper& w, const Activation& activation, std::string data_type, uint32_t inner_element_size_x, uint32_t inner_element_size_w, uint32_t inner_element_size) const { + auto get_x_snippet = [&](int32_t inner_element_size) -> std::string { + switch (inner_element_size) { + case 1: + return "resData = " + x.GetByOffset("xIndex") + ";"; + case 3: + return "resData = vec3(" + x.GetByOffset("xIndex") + ", " + x.GetByOffset("xIndex + 1") + ", " + x.GetByOffset("xIndex + 2") + ");"; + case 4: + return "resData = " + x.GetByOffset("xIndex") + ";\n "; + default: + ORT_THROW("inner_element_size", inner_element_size, " is not supported."); + } + }; + auto get_w_snippet = [&](int32_t inner_element_size) -> std::string { + switch (inner_element_size) { + case 1: + return "return " + w.GetByOffset("row * i32(uniforms.w_shape[3]) + colIn") + ";\n"; + case 4: + return "return " + w.GetByOffset("row * i32(uniforms.w_shape[3]) + colIn") + ";\n"; + default: + ORT_THROW("inner_element_size ", inner_element_size, " is not supported."); + } + }; + const std::string coord_a_snippet = is_channels_last_ ? "let coord = vec4(batch, xRow, xCol, xCh / " + std::to_string(inner_element_size_x == 3 ? 
4 : inner_element_size_x) + ");" : "let coord = vec4(batch, xCh, xRow, xCol);"; + const std::string coord_res_snippet = is_channels_last_ ? "let coords = vec4(batch, row / outWidth, row % outWidth, col / " + std::to_string(inner_element_size) + ");" : "let coords = vec4(batch, row, col / outWidth, col % outWidth);"; + + const std::string xHeight = is_channels_last_ ? "i32(uniforms.x_shape[1])" : "i32(uniforms.x_shape[2])"; + const std::string xWidth = is_channels_last_ ? "i32(uniforms.x_shape[2])" : "i32(uniforms.x_shape[3])"; + const std::string row = is_channels_last_ ? "row" : "col"; + const std::string col = is_channels_last_ ? "col" : "row"; + std::stringstream read_x_snippet; + read_x_snippet + << "let inChannels = i32(uniforms.w_shape[2]);\n" + << "let outWidth = " << (is_channels_last_ ? "i32(uniforms.result_shape[2])" : "i32(uniforms.result_shape[3])") << ";\n" + << "let outRow = " << row << " / outWidth;\n " + << "let outCol = " << row << " % outWidth;\n" + << "let WRow = " << col << " / (i32(uniforms.w_shape[1]) * inChannels);\n" + << "let WCol = " << col << " / inChannels % i32(uniforms.w_shape[1]);\n" + << "let xRow = outRow * i32(uniforms.strides[0]) + i32(uniforms.dilations[0]) * WRow - i32(uniforms.pads[0]);\n" + << "let xCol = outCol * i32(uniforms.strides[1]) + i32(uniforms.dilations[1]) * WCol - i32(uniforms.pads[1]);\n" + << "let xCh = " << col << " % inChannels;\n" + << "var resData = " << TypeSnippet(inner_element_size_x, data_type) << "(0.0);\n " + << "// The bounds checking is always needed since we use it to pad zero for\n" + << "// the \" same \" padding type.\n" + << "if (xRow >= 0 && xRow < " << xHeight << " && xCol >= 0 && xCol < " << xWidth << ") {\n" + << " " << coord_a_snippet << "\n" + << " let xIndex = getIndexFromCoords4D(coord, vec4(uniforms.x_shape));\n" + << " " << get_x_snippet(inner_element_size_x) + << "}\n" + << "return resData;"; + std::stringstream sample_x; + if (is_channels_last_) { + if (fit_a_outer_ && fit_inner_) { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << read_x_snippet.str(); + } else { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << "if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n" + << " " << read_x_snippet.str() << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_x, data_type) << "(0.0);\n"; + } + } else { + if (fit_inner_ && fit_b_outer_) { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << read_x_snippet.str(); + } else { + sample_x << "let col = colIn * " << inner_element_size_x << ";\n" + << "if (row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" + << " " << read_x_snippet.str() << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_x, data_type) << "(0.0);\n"; + } + } + std::stringstream sample_w; + if (is_channels_last_) { + if (fit_inner_ && fit_b_outer_) { + sample_w << get_w_snippet(inner_element_size_w); + } else { + sample_w << "let col = colIn * " << inner_element_size_w << ";\n" + << "if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" + << " " << get_w_snippet(inner_element_size_w) << "\n" + << "}\n" + << "return " << TypeSnippet(inner_element_size_w, data_type) << "(0.0);\n"; + } + } else { + sample_w << "let col = colIn * " << inner_element_size_w << ";\n" + << "if (row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n" + << " " << get_w_snippet(inner_element_size_w) << "\n" + << "}\n" + << "return " << 
TypeSnippet(inner_element_size_w, data_type) << "(0.0);\n"; + } + const std::string res_type = TypeSnippet(inner_element_size, data_type); + const std::string a_type = is_channels_last_ ? TypeSnippet(inner_element_size_x, data_type) : TypeSnippet(inner_element_size_w, data_type); + const std::string b_type = is_channels_last_ ? TypeSnippet(inner_element_size_w, data_type) : TypeSnippet(inner_element_size_x, data_type); + const std::string apply_activation = GetActivationSnippet(activation, res_type, data_type); + std::stringstream user_code; + user_code << "fn mm_readA(batch : i32, row : i32, colIn : i32) -> " << a_type << " {\n" + << (is_channels_last_ ? sample_x.str() : sample_w.str()) + << "}\n" + << "\n" + << "fn mm_readB(batch : i32, row : i32, colIn : i32) -> " << b_type << " {\n" + << (is_channels_last_ ? sample_w.str() : sample_x.str()) + << "}\n" + << "\n" + << "fn mm_write(batch : i32, row : i32, colIn : i32, valueIn : " << res_type << ") {\n" + << " let col = colIn * " << inner_element_size << ";\n" + << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n" + << " var value = valueIn;\n" + << " let outWidth = " << (is_channels_last_ ? " i32(uniforms.result_shape[2]) " : " i32(uniforms.result_shape[3]) ") << ";\n" + << " " << coord_res_snippet << "\n" + << " " << BiasSnippet(has_bias_) << "\n" + << " " << apply_activation << "\n" + << " setOutputAtCoords(coords[0], coords[1], coords[2], coords[3], value);\n" + << " }\n" + << "}\n"; + return user_code.str(); +} + +Status Conv2dMMProgram::GenerateShaderCode(ShaderHelper& shader) const { + std::stringstream declaration_functions; + declaration_functions << "fn setOutputAtIndex(flatIndex : i32, value : " << (is_vec4_ ? "vec4" : "x_element_t") << ") {\n" + << " result[flatIndex] = " << (is_vec4_ ? "vec4" : "x_element_t") << "(value);\n" + << "}\n" + << "fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : " << (is_vec4_ ? "vec4" : "x_element_t") << "){\n" + << " let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3));\n" + << " setOutputAtIndex(flatIndex, value);\n" + << "}\n"; + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + std::vector inputs = {&x, &w}; + ORT_IGNORE_RETURN_VALUE(shader.AddOutput("result", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseIndicesTypeAlias)); + if (has_bias_) { + const auto& bias = shader.AddInput("bias", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + inputs.push_back(&bias); + declaration_functions << "fn getBiasByOutputCoords(coords : vec4) -> bias_value_t {" << "\n" + << " return bias[" << (is_channels_last_ ? "coords.w" : "coords.y") << "];\n" + << "}"; + } + shader.AdditionalImplementation() + << UtilFunctions("uniforms.result_stride") + << declaration_functions.str() + << Conv2dCommonSnippet(x, w, activation_, "x_element_t", element_size_[0], element_size_[1], element_size_[2]); + std::string data_type = "x_element_t"; + return is_vec4_ ? 
MatMulProgram::MakeMatMulPackedVec4Source(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, /* transpose_a = */ !is_channels_last_, tile_inner_) : MatMulProgram::MakeMatMulPackedSource(shader, elements_per_thread_, WorkgroupSizeX(), WorkgroupSizeY(), data_type, /* batch_dims = */ nullptr, false, tile_inner_, false, 0, sequentially_access_by_threads_); +} + +Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last, bool sequentially_access_by_threads, const std::vector& input_output_shapes) { + const auto* input = inputs[0]; + const auto* weight = inputs[1]; + bool has_bias = inputs.size() > 2; + const auto* bias = has_bias ? inputs[2] : nullptr; + const auto& input_shape = input_output_shapes[0]; + auto in_channels = is_channels_last ? input_shape[3] : input_shape[1]; + const auto& output_shape = has_bias ? input_output_shapes[3] : input_output_shapes[2]; + auto batch_size = output_shape[0]; + const auto output_width = is_channels_last ? output_shape[2] : output_shape[3]; + const auto output_height = is_channels_last ? output_shape[1] : output_shape[2]; + const auto output_channels = is_channels_last ? output_shape[3] : output_shape[1]; + // TODO: enable vec4 for NCHW + const bool is_vec4 = is_channels_last && (in_channels % 4 == 0 || in_channels % 3 == 0) && output_channels % 4 == 0; + + // TODO: fine tune size + const auto dispatch_x = is_channels_last ? output_channels : output_width * output_height; + const auto dispatch_y = is_channels_last ? output_width * output_height : output_channels; + std::vector workgroup_size = {8, 8, 1}; + InlinedVector elements_per_thread = {4, static_cast(dim_a_outer <= 8 ? 1 : 4), 1}; + auto integer_ceil = [](int64_t a, int64_t b) -> int64_t { return (a + b - 1) / b; }; + + const std::vector dispatch = { + static_cast(integer_ceil(integer_ceil(dispatch_x, workgroup_size[0]), elements_per_thread[0])), + static_cast(integer_ceil(integer_ceil(dispatch_y, workgroup_size[1]), elements_per_thread[1])), + static_cast(integer_ceil(integer_ceil(batch_size, workgroup_size[2]), elements_per_thread[2])), + }; + + uint32_t inner_element_size = is_vec4 ? (is_channels_last && in_channels % 4 != 0 ? 3 : 4) : 1; + auto tile_a_outer = static_cast(workgroup_size[1] * elements_per_thread[1]); + auto tile_b_outer = static_cast(workgroup_size[0] * elements_per_thread[0]); + auto tile_inner = std::max(workgroup_size[0] * inner_element_size, workgroup_size[1]); + bool fit_a_outer = dim_a_outer % tile_a_outer == 0; + bool fit_b_outer = dim_b_outer % tile_b_outer == 0; + bool fit_inner = dim_inner % tile_inner == 0; + std::vector element_size = {is_vec4 ? inner_element_size : 1, static_cast(is_vec4 ? 4 : 1), static_cast(is_vec4 ? 4 : 1)}; + const auto components = is_vec4 ? 4 : 1; + const auto input_components = static_cast(inner_element_size == 3 ? 
1 : inner_element_size); + Conv2dMMProgram program(activation, tile_inner, fit_a_outer, fit_b_outer, fit_inner, is_channels_last, is_vec4, has_bias, std::move(element_size), std::move(elements_per_thread), sequentially_access_by_threads); + TensorShape reduced_input_shape = ReduceShapeByComponents(input_output_shapes[0], input_components); + TensorShape reduced_weight_shape = ReduceShapeByComponents(input_output_shapes[1], components); + TensorShape reduced_output_shape = ReduceShapeByComponents(input_output_shapes[has_bias ? 3 : 2], components); + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, reduced_input_shape, input_components}, {weight, ProgramTensorMetadataDependency::TypeAndRank, reduced_weight_shape, components}}); + if (has_bias) { + TensorShape reduced_bias_shape = ReduceShapeByComponents(input_output_shapes[2], components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components}); + } + const auto stringify = [](const std::vector& vec) -> std::string { + std::ostringstream oss; + std::transform(vec.begin(), vec.end(), std::ostream_iterator(oss, ","), [](uint32_t i) { return std::to_string(i); }); + return oss.str(); + }; + program.CacheHint(activation.ToString(), stringify({inner_element_size, static_cast(is_vec4 ? 1 : 0), fit_a_outer, fit_b_outer, fit_inner, tile_a_outer, tile_a_outer, tile_inner, static_cast(components)})) + .AddOutput({output, ProgramTensorMetadataDependency::TypeAndRank, reduced_output_shape, components}) + .SetDispatchGroupSize(dispatch[0], dispatch[1], dispatch[2]) + .SetWorkgroupSize(workgroup_size[0], workgroup_size[1], workgroup_size[2]) + .AddUniformVariables({{static_cast(dim_a_outer)}, + {static_cast(dim_b_outer)}, + {static_cast(dim_inner)}, + {pads}, + {strides}, + {dilations}}); + + return program; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h new file mode 100644 index 0000000000000..0087d11db179d --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" +#include "core/providers/webgpu/nn/fuse_utils.h" +#include "core/providers/webgpu/shader_helper.h" + +namespace onnxruntime { +namespace webgpu { +class Conv2dMMProgram final : public Program { + public: + Conv2dMMProgram(const Activation& activation, uint32_t tile_inner, bool fit_a_outer, bool fit_b_outer, bool fit_inner, bool is_channels_last, bool is_vec4, bool has_bias, std::vector&& element_size, InlinedVector&& elements_per_thread, bool sequentially_access_by_threads) : Program("Conv2dMM"), + activation_(activation), + tile_inner_(tile_inner), + fit_a_outer_(fit_a_outer), + fit_b_outer_(fit_b_outer), + fit_inner_(fit_inner), + is_channels_last_(is_channels_last), + is_vec4_(is_vec4), + has_bias_(has_bias), + element_size_(std::move(element_size)), + elements_per_thread_(std::move(elements_per_thread)), + sequentially_access_by_threads_(sequentially_access_by_threads) {} + + std::string Conv2dCommonSnippet(const ShaderVariableHelper& x, const ShaderVariableHelper& w, const Activation& activation, std::string data_type, uint32_t inner_element_size_x = 4, uint32_t inner_element_size_w = 4, uint32_t inner_element_size = 4) const; + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"dim_a_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_b_outer", ProgramUniformVariableDataType::Uint32}, + {"dim_inner", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation& activation_; + uint32_t tile_inner_; + bool fit_a_outer_; + bool fit_b_outer_; + bool fit_inner_; + bool is_channels_last_; + bool is_vec4_; + bool has_bias_; + std::vector element_size_; + InlinedVector elements_per_thread_; + bool sequentially_access_by_threads_; +}; + +Conv2dMMProgram CreateConv2dMMProgram(const Activation& activation, const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, uint32_t dim_a_outer, uint32_t dim_b_outer, uint32_t dim_inner, bool is_channels_last, bool sequentially_access_by_threads, const std::vector& modified_input_output_shapes); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc new file mode 100644 index 0000000000000..aa3ef5b96db54 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include +#include +#include +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/nn/conv_backprop_webgpu.h" +#include "core/providers/webgpu/webgpu_utils.h" +namespace onnxruntime { +namespace webgpu { + +Status ConvTranspose2DProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& dy = shader.AddInput("dy", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + if (has_bias_) { + shader.AddInput("bias"); + } + auto row_dim = is_channels_last_ ? 1 : 2; + auto col_dim = is_channels_last_ ? 2 : 3; + auto channel_dim = is_channels_last_ ? 3 : 1; + auto calculate_result = [&]() -> std::string { + std::stringstream ss; + if (pack_input_as4_) { + if (a_components_ == 4) { + ss << "let xValue = " << dy.GetByOffset("x_offset") << ";\n" + << "let wValue = " << w.GetByOffset("w_offset") << ";\n" + << "dotProd = dotProd + dot(xValue, wValue);\n" + << "x_offset += 1;\n" + << "w_offset += 1;\n"; + } else if (a_components_ == 2) { + ss << "let xValue = vec4(" << dy.GetByOffset("x_offset") << ", " << dy.GetByOffset("x_offset + 1") << ");\n" + << "let wValue = vec4(" << w.GetByOffset("w_offset") << ", " << w.GetByOffset("w_offset + 1u") << ");\n" + << "dotProd = dotProd + dot(xValue, wValue);\n" + << "x_offset += 2;\n" + << "w_offset += 2;\n"; + } else if (a_components_ == 1) { + ss << "let xValue = vec4(" << dy.GetByOffset("x_offset") << ", " << dy.GetByOffset("x_offset + 1u") << ", " << dy.GetByOffset("x_offset + 2u") << ", " << dy.GetByOffset("x_offset + 3u") << ");\n" + << "let wValue = vec4(" << w.GetByOffset("x_offset") << ", " << w.GetByOffset("x_offset + 1u") << ", " << w.GetByOffset("x_offset + 2u") << ", " << w.GetByOffset("x_offset + 3u") << ");\n" + << "dotProd = dotProd + dot(xValue, wValue);\n" + << "x_offset += 4;\n" + << "w_offset += 4;\n"; + } + } else { + if (is_channels_last_) { + ss << "let xValue = " << dy.GetByIndices("dy_indices_t(batch, idyR, idyC, inputChannel / " + std::to_string(a_components_)) << ");\n"; + } else { + ss << "let xValue = " << dy.GetByIndices("dy_indices_t(batch, inputChannel, idyR, idyC)") << ";\n"; + } + if (a_components_ == 1) { + ss << "let wValue = " << w.GetByIndices("w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)") << ";\n" + << "dotProd = dotProd + xValue * wValue;\n"; + } else if (a_components_ == b_components_ && components_ == 1) { + ss << "let wValue = " << w.GetByIndices("w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel)") << ";\n" + << "dotProd = dotProd + dot(xValue, wValue);\n"; + } else { + for (uint32_t i = 0; i < a_components_; ++i) { + ss << "let w_indices" << i << " = w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel + d2 + " << i << ", wOutChannel);\n " + << "let w_offset" << i << " = " << w.IndicesToOffset("w_indices" + std::to_string(i)) << ";\n" + << "let wValue" << i << " = " << w.GetByIndices("w_indices" + std::to_string(i)) << ";\n" + << "dotProd = dotProd + xValue[" << i << "] * wValue" << i << ";\n"; + } + } + } + return ss.str(); + }; + auto calculate_remainder = [&]() -> std::string { + std::stringstream ss; + if (input_channels_remainder_ > 0) { + ORT_ENFORCE(pack_input_as4_, "Invalid 
input_channels_remainder: ", input_channels_remainder_); + if (a_components_ == 1) { + for (uint32_t i = 0; i < input_channels_remainder_; ++i) { + ss << "dotProd = dotProd + " << dy.GetByOffset("x_offset + " + std::to_string(i)) << ";\n"; + } + } else if (a_components_ == 2) { + if (input_channels_remainder_ != 2) { + ORT_THROW("Invalid input_channels_remainder: ", input_channels_remainder_); + } + ss << "let xValue = " << dy.GetByOffset("x_offset") << ";\n" + << "let wValue = " << w.GetByOffset("w_offset") << ";\n" + << "dotProd = dotProd + dot(xValue, wValue);\n"; + } + } + return ss.str(); + }; + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let outputIndices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch = " << output.IndicesGet("outputIndices", 0) << ";\n" + << "let d1 = " << output.IndicesGet("outputIndices", channel_dim) << ";\n" + << "let r = " << output.IndicesGet("outputIndices", row_dim) << ";\n" + << "let c = " << output.IndicesGet("outputIndices", col_dim) << ";\n" + << "let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads);\n" + << "let dyRCorner = dyCorner.x;\n" + << "let dyCCorner = dyCorner.y;\n" + << "let groupId = d1 / (uniforms.output_channels_per_group / " << components_ << ");\n" + << "let wOutChannel = d1 - groupId * (uniforms.output_channels_per_group / " << components_ << ");\n" + << "// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n" + << "// ? = to be determined. : = across all values in that axis.\n" + << "var dotProd = output_value_t(0.0);\n" + << "var wR: u32 = 0;\n" + << "if (uniforms.dilations.x == 1) {\n" + << " // Minimum wR >= 0 that satisfies (dyRCorner + wR) % (uniforms.strides.x) == 0\n" + << " wR = u32(((dyRCorner + i32(uniforms.strides.x) - 1) / i32(uniforms.strides.x)) * i32(uniforms.strides.x) - dyRCorner);\n" + << "}\n" + << "for (; wR < uniforms.effective_filter_dims.x; wR = wR + 1) {\n" + << " if (wR % uniforms.dilations.x != 0) {\n" + << " continue;\n" + << " }\n" + << " let dyR = (dy_element_t(dyRCorner) + dy_element_t(wR)) / dy_element_t(uniforms.strides[0]);\n" + << " let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x;\n" + << " if (dyR < 0.0 || dyR >= dy_element_t(uniforms.dy_shape[" << row_dim << "]) || fract(dyR) > 0.0 || wRPerm < 0) {\n" + << " continue;\n" + << " }\n" + << " let idyR: u32 = u32(dyR);\n" + << " var wC: u32 = 0;\n" + << " if (uniforms.dilations.y == 1) {\n" + << " // Minimum wC >= 0 that satisfies (dyCCorner + wC) % (uniforms.strides.y) == 0\n" + << " wC = u32(((dyCCorner + i32(uniforms.strides.y) - 1) / i32(uniforms.strides.y)) * i32(uniforms.strides.y) - dyCCorner);\n" + << " }\n" + << " for (; wC < uniforms.effective_filter_dims.y; wC = wC + 1) {\n" + << " if (wC % uniforms.dilations.y != 0) {" + << " continue;\n" + << " }\n" + << " let dyC = (dy_element_t(dyCCorner) + dy_element_t(wC)) / dy_element_t(uniforms.strides.y);\n" + << " let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y;\n" + << " if (dyC < 0.0 || dyC >= dy_element_t(uniforms.dy_shape[" << col_dim << "]) ||\n" + << " fract(dyC) > 0.0 || wCPerm < 0) {\n" + << " continue;\n" + << " }\n" + << " let idyC: u32 = u32(dyC);\n" + << " var inputChannel = groupId * uniforms.input_channels_per_group;\n"; + if (pack_input_as4_) { + shader.MainFunctionBody() << " let dy_indices = dy_indices_t(batch, idyR, idyC, inputChannel);\n" + << " let w_indices = w_indices_t(u32(wRPerm), u32(wCPerm), inputChannel, wOutChannel);\n" + << " 
var x_offset = " << dy.IndicesToOffset("dy_indices") << ";\n" + << " var w_offset = " << w.IndicesToOffset("w_indices") << ";\n"; + } + + shader.MainFunctionBody() << " for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group_int; d2 = d2 + " << (pack_input_as4_ ? 4 : a_components_) << ") {\n" + << " " << calculate_result() << "\n" + << " inputChannel = inputChannel + " << (pack_input_as4_ ? 4 : 1) << ";\n" + << " }\n" + << " " << calculate_remainder() << "\n" + << " wC = wC + uniforms.strides.y - 1;\n" + << " }\n" + << " wR = wR + uniforms.strides.x - 1;\n" + << "}\n" + << "let value = dotProd" << (has_bias_ ? " + bias[d1]" : "") << ";\n" + << output.SetByOffset("global_idx", "value") << "\n"; + return Status::OK(); +} + +ConvTranspose2DProgram CreateConvTranspose2DProgram(const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, bool is_channels_last, const std::vector& modified_input_output_shapes, uint32_t groups) { + bool has_bias = inputs.size() > 2; + const auto* input = inputs[0]; + const auto* weight = inputs[1]; + const auto& input_shape = modified_input_output_shapes[0]; + const auto& weight_shape = modified_input_output_shapes[1]; + const auto& output_shape = modified_input_output_shapes[has_bias ? 3 : 2]; + auto input_channels_per_group = weight_shape[2] / groups; + auto output_channels_per_group = weight_shape[3]; + auto a_components = is_channels_last ? GetMaxComponents(input_channels_per_group) : 1; + bool pack_input_as4 = is_channels_last && output_channels_per_group == 1 && input_channels_per_group >= 4; + auto input_channels_per_group_int = pack_input_as4 ? ((input_channels_per_group + 3) / 4) * 4 : (input_channels_per_group / a_components) * a_components; + auto input_channels_remainder = input_channels_per_group - input_channels_per_group_int; + auto components = is_channels_last ? GetMaxComponents(output_channels_per_group) : 1; + auto b_components = is_channels_last ? (output_channels_per_group == 1 ? a_components : components) : 1; + TensorShape reduced_input_shape = ReduceShapeByComponents(input_shape, a_components); + TensorShape reduced_weight_shape = ReduceShapeByComponents(weight_shape, b_components); + TensorShape reduced_output_shape = ReduceShapeByComponents(output_shape, components); + auto output_size = reduced_output_shape.Size(); + std::vector kernel_dims = {static_cast(weight_shape[0]), static_cast(weight_shape[1])}; + std::vector effective_kernel_dims = {kernel_dims[0] + ((dilations[0] <= 1) ? 0 : ((kernel_dims[0] - 1) * (dilations[0] - 1))), kernel_dims[1] + ((dilations[1] <= 1) ? 
0 : ((kernel_dims[1] - 1) * (dilations[1] - 1)))}; + std::vector local_pads = {effective_kernel_dims[0] - 1 - pads[0], effective_kernel_dims[1] - 1 - pads[1]}; + ConvTranspose2DProgram program(is_channels_last, has_bias, components, a_components, b_components, uint32_t(input_channels_remainder), pack_input_as4); + program.AddInputs({{input, ProgramTensorMetadataDependency::TypeAndRank, reduced_input_shape, a_components}, {weight, ProgramTensorMetadataDependency::TypeAndRank, reduced_weight_shape, b_components}}); + if (has_bias) { + const auto* bias = inputs[2]; + const auto& bias_shape = modified_input_output_shapes[2]; + TensorShape reduced_bias_shape = ReduceShapeByComponents(bias_shape, components); + program.AddInput({bias, ProgramTensorMetadataDependency::TypeAndRank, reduced_bias_shape, components}); + } + program.AddOutput({output, ProgramTensorMetadataDependency::Rank, reduced_output_shape, components}) + .AddUniformVariables({{static_cast(output_size)}, {strides}, {kernel_dims}, {dilations}, {effective_kernel_dims}, {local_pads}, {static_cast(input_channels_per_group_int)}, {static_cast(input_channels_per_group)}, {static_cast(output_channels_per_group)}}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + + return program; +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h new file mode 100644 index 0000000000000..6c784e4825a65 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webgpu { + +class ConvTranspose2DProgram : public Program { + public: + ConvTranspose2DProgram(bool is_channels_last, bool has_bias, uint32_t components, uint32_t a_components, uint32_t b_components, uint32_t input_channels_remainder, bool pack_input_as4) : Program("ConvTranspose2D"), is_channels_last_(is_channels_last), has_bias_(has_bias), components_(components), a_components_(a_components), b_components_(b_components), input_channels_remainder_(input_channels_remainder), pack_input_as4_(pack_input_as4) { + } + + Status GenerateShaderCode(ShaderHelper& sh) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"filter_dims", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"effective_filter_dims", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"input_channels_per_group_int", ProgramUniformVariableDataType::Uint32}, + {"input_channels_per_group", ProgramUniformVariableDataType::Uint32}, + {"output_channels_per_group", ProgramUniformVariableDataType::Uint32}); + + private: + bool is_channels_last_; + bool has_bias_; + uint32_t components_; + uint32_t a_components_; + uint32_t b_components_; + uint32_t input_channels_remainder_; + bool pack_input_as4_; +}; + +ConvTranspose2DProgram 
CreateConvTranspose2DProgram(const std::vector& inputs, const std::vector& pads, const std::vector& strides, const std::vector& dilations, Tensor* output, bool is_channels_last, const std::vector& modified_input_output_shapes, uint32_t groups); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc new file mode 100644 index 0000000000000..9cd290ef56013 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "conv.h" +#include "conv_transpose.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "core/providers/webgpu/tensor/transpose.h" +#include "core/providers/webgpu/nn/conv_backprop_webgpu.h" + +namespace onnxruntime { +namespace webgpu { +// kernel shape is the spacial dims of the filter. +// ie. filter shape with batch and channel. kernel shape dimension is 2 less than the filter dimension + +template +Status ConvTranspose::ComputeInternal(ComputeContext& context) const { + const auto* input = context.Input(0); + const auto* filter = context.Input(1); + TensorShape input_shape = input->Shape(); + TensorShape filter_shape = filter->Shape(); + const InlinedVector perm = {2, 3, 0, 1}; + TensorShapeVector local_output_padding(conv_transpose_attrs_.output_padding.begin(), conv_transpose_attrs_.output_padding.end()); + ConvAttributes::ConvPadVector local_pads(conv_transpose_attrs_.pads.begin(), conv_transpose_attrs_.pads.end()); + TensorShapeVector local_dilations(conv_transpose_attrs_.dilations.begin(), conv_transpose_attrs_.dilations.end()); + TensorShapeVector local_strides(conv_transpose_attrs_.strides.begin(), conv_transpose_attrs_.strides.end()); + TensorShapeVector kernel_shape_vector; + auto rank = input_shape.NumDimensions(); + TensorShape input_spacial_shape = input_shape.Slice(is_channels_last ? 1 : 2, is_channels_last ? 
rank - 1 : rank); + local_pads.reserve(2 * (input_spacial_shape.NumDimensions())); + ORT_RETURN_IF_ERROR(conv_transpose_attrs_.ComputeKernelShape(filter_shape, kernel_shape_vector, false)); + if (local_output_padding.empty()) { + local_output_padding.resize(kernel_shape_vector.size(), 0); + } + if (local_pads.empty()) { + local_pads.resize(kernel_shape_vector.size() * 2, 0); + } + if (local_dilations.empty()) { + local_dilations.resize(kernel_shape_vector.size(), 1); + } + if (local_strides.empty()) { + local_strides.resize(kernel_shape_vector.size(), 1); + } + auto group = conv_transpose_attrs_.group; + auto num_output_channels = group * filter_shape[1]; + auto batch_size = input_shape[0]; + TensorShapeVector output_shape_vector; + conv_transpose_attrs_.ComputePadsAndOutputShape(input_spacial_shape, num_output_channels, kernel_shape_vector, local_strides, local_dilations, local_output_padding, batch_size, &local_pads, &output_shape_vector, is_channels_last); + TensorShape computed_output_shape(output_shape_vector); + std::vector strides; + std::vector pads; + std::vector dilations; + auto transform_dim = [](int64_t dim) { return static_cast(dim); }; + std::transform(local_pads.begin(), local_pads.end(), std::back_inserter(pads), transform_dim); + std::transform(local_strides.begin(), local_strides.end(), std::back_inserter(strides), transform_dim); + std::transform(local_dilations.begin(), local_dilations.end(), std::back_inserter(dilations), transform_dim); + + bool has_bias = context.InputCount() > 2; + const auto* bias = has_bias ? context.Input(2) : nullptr; + if (input_shape.NumDimensions() == 3 && filter_shape.NumDimensions() == 3) { + // ConvTranspose1D + TensorShapeVector input_shape_vector = input_shape.AsShapeVector(); + TensorShapeVector filter_shape_vector = filter_shape.AsShapeVector(); + input_shape_vector.insert(input_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1); + output_shape_vector.insert(output_shape_vector.begin() + (is_channels_last ? 1 : 2), 1, 1); + filter_shape_vector.insert(filter_shape_vector.begin() + 2, 1); + input_shape = TensorShape(input_shape_vector); + filter_shape = TensorShape(filter_shape_vector); + pads.insert(pads.begin(), 0); + pads.insert(pads.begin() + 2, 0); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } + if (input_shape.NumDimensions() > 4 || filter_shape.NumDimensions() > 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Only Conv2d or Conv1d are supported."); + } else if (input_shape.NumDimensions() < 2 || filter_shape.NumDimensions() < 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input and kernel tensors must have at least 3 dimensions"); + } + // Transpose weights + Tensor transposed_filter; + ORT_RETURN_IF_ERROR(TransposeKernel(context, filter, filter_shape, &transposed_filter, perm)); + TensorShape output_shape(output_shape_vector); + TensorShape transposed_filter_shape = transposed_filter.Shape(); + std::vector inputs = {input, &transposed_filter}; + std::vector input_output_shapes = {input_shape, transposed_filter_shape}; + if (has_bias) { + inputs.push_back(bias); + input_output_shapes.push_back(bias->Shape()); + } + uint32_t auto_pad_adjust = conv_transpose_attrs_.auto_pad == AutoPadType::SAME_LOWER ? 1 : 0; + auto pad0 = conv_transpose_attrs_.auto_pad == AutoPadType::NOTSET ? pads[0] : (pads[0] + pads[2] + auto_pad_adjust) / 2; + auto pad1 = conv_transpose_attrs_.auto_pad == AutoPadType::NOTSET ? 
pads[1] : (pads[1] + pads[3] + auto_pad_adjust) / 2; + Tensor* output = context.Output(0, computed_output_shape); + input_output_shapes.push_back(output_shape); + auto program = CreateConvTranspose2DProgram(inputs, {pad0, pad1}, strides, dilations, output, is_channels_last, input_output_shapes, static_cast(conv_transpose_attrs_.group)); + return context.RunProgram(program); +} + +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 11, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 11, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 1, 10, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + ConvTranspose); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_transpose.h b/onnxruntime/core/providers/webgpu/nn/conv_transpose.h new file mode 100644 index 0000000000000..a97b3f5947303 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_transpose.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/common.h" + +#include "core/providers/cpu/nn/conv_transpose_attributes.h" +#include "core/providers/webgpu/webgpu_kernel.h" +namespace onnxruntime { +namespace webgpu { + +template +class ConvTranspose final : public WebGpuKernel { + public: + ConvTranspose(const OpKernelInfo& info) : WebGpuKernel(info), conv_transpose_attrs_(info) { + } + Status ComputeInternal(ComputeContext& context) const override; + + protected: + ConvTransposeAttributes conv_transpose_attrs_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_utils.cc b/onnxruntime/core/providers/webgpu/nn/conv_utils.cc new file mode 100644 index 0000000000000..233662c10bfb8 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_utils.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/webgpu/nn/conv_utils.h" +namespace onnxruntime { +namespace webgpu { +std::string UtilFunctions(std::string stride_string) { + std::stringstream ss; + ss << "fn getIndexFromCoords3D(coords : vec3, shape : vec3) -> i32 {\n" + << " return dot(coords, vec3(shape.y * shape.z, shape.z, 1));\n" + << "}\n" + << "fn getIndexFromCoords4D(coords : vec4, shape : vec4) -> i32 {\n" + << " return dot(coords, vec4(shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));\n" + << "}\n" + << "fn getOutputIndexFromCoords(coords : vec4) -> i32 {\n" + << " return dot(coords, vec4(i32(" << stride_string << ".x), i32(" << stride_string << ".y), i32(" << stride_string << ".z), 1));\n" + << "}\n"; + return ss.str(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv_utils.h b/onnxruntime/core/providers/webgpu/nn/conv_utils.h new file mode 100644 index 0000000000000..ad8aa868ff7f0 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/conv_utils.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime { +namespace webgpu { + +std::string UtilFunctions(std::string stride_string); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc new file mode 100644 index 0000000000000..38db604695a54 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.cc @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/nn/fuse_utils.h" +#include +namespace onnxruntime { +namespace webgpu { + +Status GetFusedActivationAttr(const OpKernelInfo& info, Activation& activation) { + activation.activation_kind_ = ActivationKind::None; + + std::string activation_type; + if (info.GetAttr("activation", &activation_type).IsOK()) { + if (activation_type == "Relu") { + activation.activation_kind_ = ActivationKind::Relu; + } else if (activation_type == "Tanh") { + activation.activation_kind_ = ActivationKind::Tanh; + } else if (activation_type == "Sigmoid") { + activation.activation_kind_ = ActivationKind::Sigmoid; + } else { + // The remaining activation types have additional parameters to be pulled out. 
+ size_t activation_params_count; + if (activation_type == "LeakyRelu") { + activation.activation_kind_ = ActivationKind::LeakyRelu; + activation_params_count = 1; + } else if (activation_type == "Clip") { + activation.activation_kind_ = ActivationKind::Clip; + activation_params_count = 2; + } else if (activation_type == "HardSigmoid") { + activation.activation_kind_ = ActivationKind::HardSigmoid; + activation_params_count = 2; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "unimplemented activation: " + activation_type); + } + + std::vector activation_params; + common::Status status = info.GetAttrs("activation_params", activation_params); + if (!status.IsOK()) { + return status; + } else if (activation_params_count != activation_params.size()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "activation_params count mismatch"); + } + for (size_t i = 0; i < activation_params_count; i++) { + activation.activation_params_.values_[i] = activation_params[i]; + } + } + } + + return Status::OK(); +} + +std::string GetActivationSnippet(const Activation& activation, std::string value_type, std::string base_type) { + std::string snippet; + auto base_type_cast = [base_type](float value) -> std::string { + return base_type + "(" + std::to_string(value) + ")"; + }; + auto value_type_cast = [base_type_cast, value_type](float f) -> std::string { + return value_type + "(" + base_type_cast(f) + ")"; + }; + switch (activation.activation_kind_) { + case ActivationKind::Relu: + return "value = max(value, " + value_type_cast(0.0) + ");"; + case ActivationKind::Sigmoid: + return "value = " + value_type_cast(1.0) + " / (" + value_type_cast(1.0) + " + exp(-value));"; + case ActivationKind::Clip: + return "value = clamp(value, " + value_type_cast(activation.activation_params_.Clip.minimum_) + ", " + value_type_cast(activation.activation_params_.Clip.maximum_) + ");"; + case ActivationKind::HardSigmoid: + return "value = clamp(" + value_type_cast(activation.activation_params_.HardSigmoid.alpha_) + " * value + " + value_type_cast(activation.activation_params_.HardSigmoid.beta_) + ", 0.0" + ", 1.0" + ");"; + case ActivationKind::LeakyRelu: + return "value = select(" + base_type_cast(activation.activation_params_.LeakyRelu.alpha_) + " * value, value, value >= " + value_type_cast(0.0) + ");"; + case ActivationKind::Tanh: + return "value = tanh(value);"; + default: + return ""; + } +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/fuse_utils.h b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h new file mode 100644 index 0000000000000..f5d2585bb9b45 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/fuse_utils.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include +#include "core/providers/webgpu/webgpu_kernel.h" + +#pragma once +namespace onnxruntime { +namespace webgpu { +enum class ActivationKind { + None, + Relu, + Sigmoid, + Clip, + HardSigmoid, + LeakyRelu, + Tanh +}; + +using Activation = struct Activation { + std::string ToString() const { + std::stringstream oss; + oss << "ActivationKind: " << static_cast(activation_kind_) << ";"; + oss << "ActivationParams: " << activation_params_.values_[0] << ";"; + oss << "ActivationParams: " << activation_params_.values_[1] << ";"; + return oss.str(); + } + using ActivationParameters = union ActivationParameters { + struct { + float alpha_; + } LeakyRelu; + struct { + float minimum_; + float maximum_; + } Clip; + struct { + float alpha_; + float beta_; + } HardSigmoid; + float values_[2]; + }; + ActivationParameters activation_params_ = {}; + ActivationKind activation_kind_ = ActivationKind::None; +}; + +Status GetFusedActivationAttr(const OpKernelInfo& info, Activation& activation); +std::string GetActivationSnippet(const Activation& activation, std::string value_type, std::string base_type); +// Status AppendActivationUniformsData(const Activation& activation, std::vector& variables); +// Status AppendActivationUniforms(const Activation& activation, std::vector& data); + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/grouped_conv.cc b/onnxruntime/core/providers/webgpu/nn/grouped_conv.cc new file mode 100644 index 0000000000000..4dc0b82cdd7eb --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/grouped_conv.cc @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/webgpu/nn/grouped_conv.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/shader_variable.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +std::string CanculateResult(const ShaderVariableHelper& x, const ShaderVariableHelper& w, bool is_channels_last) { + std::stringstream ss; + if (is_channels_last) { + ss << "for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[0]; wHeight++) {\n" + << " let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];\n" + << " if (xHeight < 0u || xHeight >= uniforms.x_shape[1]) {\n" + << " continue;\n" + << " }\n" + << "" + << " for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[1]; wWidth++) {\n" + << " let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];\n" + << " if (xWidth < 0u || xWidth >= uniforms.x_shape[2]) {\n" + << " continue;\n" + << " }\n" + << "" + << " for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[2]; wInChannel++) {\n" + << " let input_channel = in_channel_offset + wInChannel;\n" + << " let x_indices = x_indices_t(batch, xHeight, xWidth, input_channel);\n" + << " let w_indices = w_indices_t(wHeight, wWidth, wInChannel, output_channel);\n" + << " let xVal = " << x.GetByIndices("x_indices") << ";\n" + << " let wVal = " << w.GetByIndices("w_indices") << ";\n" + << " value += xVal * wVal;\n" + << " }\n" + << " }\n" + << "}\n"; + } else { + ss << "for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) {\n" + << " let input_channel = in_channel_offset + wInChannel;\n" + << " for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) {\n" + << " let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0];\n" + << "" + << " if (xHeight < 0u || xHeight >= uniforms.x_shape[2]) {\n" + 
<< " continue;\n" + << " }\n" + << "" + << " for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) {\n" + << " let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1];\n" + << " if (xWidth < 0u || xWidth >= uniforms.x_shape[3]) {\n" + << " continue;\n" + << " }\n" + << "" + << " let x_indices = x_indices_t(batch, input_channel, xHeight, xWidth);\n" + << " let w_indices = w_indices_t(output_channel, wInChannel, wHeight, wWidth);\n" + << " let xVal = " << x.GetByIndices("x_indices") << ";\n" + << " let wVal = " << w.GetByIndices("w_indices") << ";\n" + << " value += xVal * wVal;\n" + << " }\n" + << " }\n" + << "}\n"; + } + return ss.str(); +} + +Status GroupedConvProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("x", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& w = shader.AddInput("w", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseIndicesTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); + std::string apply_activation = GetActivationSnippet(activation_, "output_value_t", "output_element_t"); + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "let batch: u32 = output_indices[0];\n" + << "let output_channel: u32 = " << output.IndicesGet("output_indices", is_channels_last_ ? "3" : "1") << ";\n" + << "let xRCCorner_x: u32 = " << output.IndicesGet("output_indices", is_channels_last_ ? "1" : "2") << ";\n" + << "let xRCCorner_y: u32 = " << output.IndicesGet("output_indices", is_channels_last_ ? "2" : "3") << ";\n" + << "let xRCCorner: vec2 = vec2(xRCCorner_x, xRCCorner_y) * uniforms.strides - uniforms.pads;\n" + << "let group_id = output_channel * uniforms.components / uniforms.output_channels_per_group;\n" + << "let in_channel_offset = group_id * " << w.IndicesGet("uniforms.w_shape", is_channels_last_ ? 2 : 1) << ";\n" + << "var value: output_value_t = output_value_t(0);\n" + << CanculateResult(x, w, is_channels_last_); + if (has_bias_) { + const auto& b = shader.AddInput("b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); + shader.MainFunctionBody() << "value += " + b.GetByIndices("output_channel") + ";\n"; + } + shader.MainFunctionBody() << apply_activation << "\n"; + shader.MainFunctionBody() << output.SetByOffset("global_idx", "value"); + return Status::OK(); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/grouped_conv.h b/onnxruntime/core/providers/webgpu/nn/grouped_conv.h new file mode 100644 index 0000000000000..d09f9679eecf5 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/nn/grouped_conv.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/common/optional.h" +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/cpu/nn/conv_attributes.h" +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/nn/fuse_utils.h" + +namespace onnxruntime { +namespace webgpu { + +class GroupedConvProgram final : public Program { + public: + GroupedConvProgram(const Activation& activation, bool has_bias, bool is_channels_last) : Program("GroupedConv"), activation_(activation), has_bias_(has_bias), is_channels_last_(is_channels_last) { + } + Status GenerateShaderCode(ShaderHelper& shader) const override; + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES( + {"output_size", ProgramUniformVariableDataType::Uint32}, + {"dilations", ProgramUniformVariableDataType::Uint32}, + {"strides", ProgramUniformVariableDataType::Uint32}, + {"pads", ProgramUniformVariableDataType::Uint32}, + {"output_channels_per_group", ProgramUniformVariableDataType::Uint32}, + {"components", ProgramUniformVariableDataType::Uint32}); + + private: + const Activation& activation_; + bool has_bias_; + bool is_channels_last_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 2427bf62cc658..eb65e998c81c5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -250,6 +250,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Transpose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Transpose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, Transpose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace); @@ -257,9 +259,11 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInt class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Conv); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 21, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 22, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 1, 10, Conv); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 11, 21, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 22, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, ConvTranspose); @@ -578,21 +582,25 @@ std::unique_ptr RegisterKernels() { 
BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.cc b/onnxruntime/core/providers/webgpu/webgpu_utils.cc new file mode 100644 index 0000000000000..9b16767475c0c --- /dev/null +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.cc @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "core/providers/webgpu/webgpu_utils.h" +namespace onnxruntime { +namespace webgpu { + +TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components) { + // Reduce the last dimensions by components creating a new tensor shape. + TensorShapeVector shape_vector = shape.AsShapeVector(); + auto reduce_index = shape_vector.size() - 1; + // Find the last dimension that is divisible by components. + while (shape_vector[reduce_index] % components != 0 && reduce_index > 0) { + ORT_ENFORCE(components % shape_vector[reduce_index] == 0, "The components must divide dims"); + components /= shape_vector[reduce_index]; + shape_vector[reduce_index] = 1; + reduce_index--; + } + ORT_ENFORCE(reduce_index >= 0 && shape_vector[reduce_index] % components == 0, "The last non-unit dimension of the input shape must be divisible by the number of components."); + shape_vector[reduce_index] /= components; + return TensorShape(shape_vector); +} + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_utils.h b/onnxruntime/core/providers/webgpu/webgpu_utils.h index 5f6f18f34b7f5..e02d9266e8a0e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_utils.h +++ b/onnxruntime/core/providers/webgpu/webgpu_utils.h @@ -1,8 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#pragma once #include +#include "core/common/common.h" +#include "core/framework/tensor_shape.h" namespace onnxruntime { namespace webgpu { @@ -44,5 +47,7 @@ inline std::string MakeScalarOrVectorType(int components, std::string_view data_ } } +TensorShape ReduceShapeByComponents(const TensorShape& shape, int64_t components); + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5ea562ce3535..92eaa68667f0e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -153,7 +153,7 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEp(const Graph& graph) { +static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp(const Graph& graph) { bool nodes_on_cpu_and_cuda_and_js_and_dml_eps_only = true; for (const auto& node : graph.Nodes()) { @@ -164,6 +164,7 @@ static bool AreAllComputeNodesAssignedToCudaOrJsOrDmlEp(const Graph& graph) { !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider || node_provider == kJsExecutionProvider || + node_provider == kWebGpuExecutionProvider || node_provider == kDmlExecutionProvider) && node_provider != kCpuExecutionProvider) { nodes_on_cpu_and_cuda_and_js_and_dml_eps_only = false; @@ -2041,6 +2042,7 @@ common::Status InferenceSession::Initialize() { onnxruntime::kCudaExecutionProvider, onnxruntime::kRocmExecutionProvider, onnxruntime::kJsExecutionProvider, + onnxruntime::kWebGpuExecutionProvider, onnxruntime::kDmlExecutionProvider}; for (auto& it : graph_support_ep_list) { @@ -2063,12 +2065,13 @@ common::Status InferenceSession::Initialize() { if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kWebGpuExecutionProvider) == 0 || strcmp(target_ep->Type().c_str(), onnxruntime::kDmlExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. - if (!AreAllComputeNodesAssignedToCudaOrJsOrDmlEp(graph)) { + if (!AreAllComputeNodesAssignedToCudaOrJsOrDmlEpWebGpuEp(graph)) { LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " << " as all compute graph nodes have not been partitioned to the " << target_ep->Type(); diff --git a/onnxruntime/test/contrib_ops/fused_conv_test.cc b/onnxruntime/test/contrib_ops/fused_conv_test.cc index e6fe0ec0e45a3..0dd69a49972e8 100644 --- a/onnxruntime/test/contrib_ops/fused_conv_test.cc +++ b/onnxruntime/test/contrib_ops/fused_conv_test.cc @@ -33,14 +33,16 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, bool disable_cpu = false, bool disable_cuda = false, bool disable_rocm = false, + bool disable_webgpu = false, bool use_float16 = false, bool weight_is_initializer = false) { bool enable_cuda = HasCudaEnvironment(0) && !use_float16 && !disable_cuda; // Only ROCm EP supports float16. 
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !disable_rocm; + bool enable_webgpu = (nullptr != DefaultWebGpuExecutionProvider().get()) && !disable_webgpu; bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !disable_cpu; - if (enable_cuda || enable_rocm || enable_cpu) { + if (enable_cuda || enable_rocm || enable_cpu || enable_webgpu) { OpTester test("FusedConv", 1, onnxruntime::kMSDomain); test.AddAttribute("group", attributes.group); test.AddAttribute("kernel_shape", attributes.kernel_shape); @@ -96,6 +98,10 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, execution_providers.push_back(DefaultRocmExecutionProvider()); } + if (enable_webgpu) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } + if (enable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } @@ -110,15 +116,16 @@ void RunConvOp(const ConvOpAndTestAttributes& attributes, const vector& expected_output_shape, bool disable_cpu = false, bool disable_cuda = false, - bool disable_rocm = false) { + bool disable_rocm = false, + bool disable_webgpu = false) { bool weight_is_initializer = false; bool use_float16 = false; TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, - disable_cpu, disable_cuda, disable_rocm, use_float16, weight_is_initializer); + disable_cpu, disable_cuda, disable_rocm, disable_webgpu, use_float16, weight_is_initializer); use_float16 = true; TestConvOp(attributes, inputs, input_shapes, expected_output, expected_output_shape, - disable_cpu, disable_cuda, disable_rocm, use_float16, weight_is_initializer); + disable_cpu, disable_cuda, disable_rocm, disable_webgpu, use_float16, weight_is_initializer); } TEST(FusedConvTest, Conv2D_HardSigmoid) { @@ -139,7 +146,7 @@ TEST(FusedConvTest, Conv2D_HardSigmoid) { vector W_shape = {2, 1, 2, 2}; vector Y_shape = {1, 2, 2, 2}; auto expected_vals = {0.8f, 0.9f, 1.0f, 1.0f, 0.2f, 0.1f, 0.0f, 0.0f}; - RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, true, true); + RunConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, false, true, true, true); } TEST(FusedConvTest, Conv2D_Relu) { @@ -233,7 +240,7 @@ TEST(FusedConvTest, Cpu_Conv2D_Bias_Z_Relu) { vector Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f}; vector Z_shape = {1, 2, 2, 2}; auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f}; - RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true); + RunConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape, false, true, true, true); } #endif diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index a66964de17c72..a33b3148014f1 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -2204,6 +2204,10 @@ TEST_F(GraphTransformationTests, FuseCudaConvAddReluIdentity) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kJsExecutionProvider); } +#elif defined(USE_WEBGPU) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); + } #else for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCpuExecutionProvider); @@ -2232,6 +2236,10 @@ TEST_F(GraphTransformationTests, FuseCudaConvAdd) { for (auto& node : p_model->MainGraph().Nodes()) { 
node.SetExecutionProviderType(kJsExecutionProvider); } +#elif defined(USE_WEBGPU) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); + } #else for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCpuExecutionProvider); @@ -2330,6 +2338,10 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kJsExecutionProvider); } +#elif defined(USE_WEBGPU) + for (auto& node : p_model->MainGraph().Nodes()) { + node.SetExecutionProviderType(kWebGpuExecutionProvider); + } #else for (auto& node : p_model->MainGraph().Nodes()) { node.SetExecutionProviderType(kCpuExecutionProvider); @@ -2351,6 +2363,13 @@ TEST_F(GraphTransformationTests, FuseConvActivation) { } else { ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); } +#elif defined(USE_WEBGPU) + std::set webgpu_supported = {"Relu", "Clip", "Sigmoid", "Tanh", "LeakyRelu", "HardSigmoid"}; + if (webgpu_supported.find(model.second) == webgpu_supported.end()) { + ASSERT_EQ(op_to_count_before_fusion[model.second], op_to_count_after_fusion[model.second]); + } else { + ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); + } #else ASSERT_TRUE(op_to_count_after_fusion[model.second] == 0); #endif diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index a3a3dd939cbf0..06434d5b59ec6 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -489,7 +489,7 @@ TEST(ConvTest, Conv3D_1) { vector{1, 1, 1}, // kernel_shape vector{0, 0, 0, 0, 0, 0}, // pads vector{1, 1, 1}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = {-0.43337246775627136f, -0.48385289311408997f, -0.30954962968826294f, @@ -526,7 +526,7 @@ TEST(ConvTest, Conv3D_2) { vector{1, 1, 1}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = {0.010772407054901123f, -0.43806642293930054f, 0.455391526222229f, -0.28657248616218567f, @@ -569,7 +569,7 @@ TEST(ConvTest, Conv3D_Bias) { vector{2, 2, 2}, // kernel_shape vector{2, 2, 2, 2, 2, 2}, // pads vector{2, 2, 2}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = {0.46796226501464844f, -0.4613912105560303f, 0.33512794971466064f, -0.4010460674762726f, @@ -916,7 +916,7 @@ TEST(ConvTest, ConvDimWithZero) { vector{1, 1}, // kernel_shape vector{0, 0, 0, 0}, // pads vector{1, 1}, // strides - {} // excluded EPs + {kWebGpuExecutionProvider} // excluded EPs }; vector X = vector(); diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 83b27f10fe04f..198fa07ae4ed0 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -933,7 +933,7 @@ TEST(ConvTransposeTest, DimWithZero) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, - kAclExecutionProvider, kQnnExecutionProvider}); + kAclExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_3D) { @@ -1068,7 +1068,7 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { TestConvTransposeOp(attrs, {X, 
W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kCudaExecutionProvider, - kCudaNHWCExecutionProvider, kQnnExecutionProvider}); + kCudaNHWCExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { From a7e62d6390ae06265326dc44e74753ffa633cd69 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 4 Apr 2025 16:59:58 -0700 Subject: [PATCH 10/18] [webgpu][dawn API optimization] reduce number of calls to wgpuDeviceGetQueue (#24313) ### Description This PR is one of a series of changes for optimization of Dawn API usage. See #24281 Optimizes the usage of wgpuDeviceGetQueue. --- onnxruntime/core/providers/webgpu/webgpu_context.cc | 6 ++++-- onnxruntime/core/providers/webgpu/webgpu_context.h | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 955b54e873261..d9fe967f3047c 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -136,6 +136,8 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi LOGS_DEFAULT(VERBOSE) << "WebGPU EP Context is created for: Instance=" << instance_.Get() << ", Device=" << device_.Get() << "."; + // cache device queue + device_queue_ = device_.GetQueue(); // cache adapter info ORT_ENFORCE(Device().GetAdapterInfo(&adapter_info_)); // cache device limits @@ -404,7 +406,7 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { } uniform_buffer = buffer_mgr_->Create(uniform_buffer_total_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform); - device_.GetQueue().WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); + device_queue_.WriteBuffer(uniform_buffer, 0, uniform_data_buffer.data(), uniform_buffer_total_size); } const auto& compute_pass_encoder = GetComputePassEncoder(); @@ -696,7 +698,7 @@ void WebGpuContext::Flush() { } auto command_buffer = current_command_encoder_.Finish(); - Device().GetQueue().Submit(1, &command_buffer); + device_queue_.Submit(1, &command_buffer); BufferManager().RefreshPendingBuffers(); current_command_encoder_ = nullptr; num_pending_dispatches_ = 0; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 2f044400afee2..0a54b13e31bf7 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -207,6 +207,7 @@ class WebGpuContext final { webgpu::ValidationMode validation_mode_; + wgpu::Queue device_queue_; wgpu::AdapterInfo adapter_info_; wgpu::Limits device_limits_; std::unordered_set device_features_; From 55c1a3b073b50435a15441ffd856a7056be7c336 Mon Sep 17 00:00:00 2001 From: virajwad <84867530+virajwad@users.noreply.github.com> Date: Fri, 4 Apr 2025 18:14:55 -0700 Subject: [PATCH 11/18] Fix 'minimal_power' to 'minimum_power' for DirectML performance selection (perf test) (#24303) In Perf Test for DirectML EP, for the "performance_preference" runtime key we could not select the "minimum_power" value option due to a small bug. This PR fixes it so that "minimum_power" can be used and ran. 
I will also link the respective issue to this PR. I made the change, built onnxruntime, and tested the perf_test.exe + DLLs on a system with Intel Integrated Graphics + Nvidia dGPU. Switching between 'minimum_power' and 'high_performance', I can see the options choose Intel Integrated and Nvidia dGPU as device runtimes, respectively (I checked task manager utilization for both devices). Both inferences complete with no problems. I am attaching a reproducer here with the built perf_test and the commands I tried to test it: [DLL_Build_DML_Reproducer.zip](https://github.com/user-attachments/files/19596463/DLL_Build_DML_Reproducer.zip) Issue #24182 @fdwr Hi, I fixed the issue. If you could please review, thank you --- onnxruntime/test/perftest/ort_test_session.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 86d04fb4bbc2b..1cc17ea03fa32 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -421,12 +421,12 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "Select from 'gpu', or 'npu' \n"); } } else if (key == "performance_preference") { - std::set ov_supported_values = {"default", "high_performance", "minimal_power"}; + std::set ov_supported_values = {"default", "high_performance", "minimum_power"}; if (ov_supported_values.find(value) != ov_supported_values.end()) { } else { ORT_THROW( "[ERROR] [DML] You have selected a wrong configuration value for the key 'performance_preference'. " - "Select from 'default', 'high_performance' or 'minimal_power' \n"); + "Select from 'default', 'high_performance' or 'minimum_power' \n"); } } else if (key == "disable_metacommands") { std::set ov_supported_values = {"true", "True", "false", "False"}; From d6df4f29d32ab0632b7b1522aa4f117c36993d7d Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Fri, 4 Apr 2025 23:58:31 -0700 Subject: [PATCH 12/18] Add ConvTranspose cache key (#24317) ### Description ### Motivation and Context --- onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc index aa3ef5b96db54..74f3e0dcc85f5 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc @@ -181,6 +181,8 @@ ConvTranspose2DProgram CreateConvTranspose2DProgram(const std::vector(output_size)}, {strides}, {kernel_dims}, {dilations}, {effective_kernel_dims}, {local_pads}, {static_cast(input_channels_per_group_int)}, {static_cast(input_channels_per_group)}, {static_cast(output_channels_per_group)}}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); From a1186f634188ba17ab22572c031ee64d9178a675 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Mon, 7 Apr 2025 21:36:18 +0800 Subject: [PATCH 13/18] [webgpu] Use 1D dispatch groups for attention (#24228) This PR uses 1D dispatch group size and uses workgroup_idx instead of workgroup.x|workgroup.y in case they are normalized.
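For reference, the 3D grid that was previously expressed through workgroup_id.x/y/z is now flattened into a single 1D dispatch of batch_size * num_heads * num_seq_length_tile * num_total_seq_length_tile workgroups, and each workgroup recovers its logical coordinates from workgroup_idx. A minimal C++ sketch of the same index arithmetic used in the shader changes below (the helper and struct names are illustrative only, not part of the patch):

```cpp
#include <cstdint>

// Logical coordinates recovered from a flat 1D workgroup index.
struct TileCoords {
  uint32_t n_tile;          // tile along the total-sequence-length axis
  uint32_t m_tile;          // tile along the sequence-length axis
  uint32_t batch_head_idx;  // flattened batch * num_heads index
};

// Mirrors the WGSL in the diff: n comes from the fastest-varying dimension,
// m from the middle one, and the batch/head index from the slowest one.
inline TileCoords UnflattenWorkgroupIndex(uint32_t workgroup_idx,
                                          uint32_t num_total_seq_length_tile,
                                          uint32_t num_seq_length_tile) {
  TileCoords c;
  c.n_tile = workgroup_idx % num_total_seq_length_tile;
  c.m_tile = (workgroup_idx / num_total_seq_length_tile) % num_seq_length_tile;
  c.batch_head_idx = workgroup_idx / (num_total_seq_length_tile * num_seq_length_tile);
  return c;
}
```

In the shader, m_tile and n_tile are additionally scaled by TILE_SIZE to obtain per-element offsets.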
--- .../contrib_ops/webgpu/bert/attention.cc | 70 ++++++++++--------- .../contrib_ops/webgpu/bert/attention.h | 8 ++- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index abea94d2e0b50..6e7919f281fb6 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -99,23 +99,24 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; shader.MainFunctionBody() << "// x holds the N and y holds the M\n" - << "let m = workgroup_id.y * TILE_SIZE;\n" - << "let n = workgroup_id.x * TILE_SIZE;\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let qOffset = workgroup_id.z * uniforms.M * uniforms.K + m * uniforms.K;\n" + << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n" + << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n" + << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));\n" + << "let batch_idx = batch_head_idx / uniforms.num_heads;\n" + << "let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.N;\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let kOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;\n"; if (has_present_key_) { - shader.MainFunctionBody() << "let presentKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n"; + shader.MainFunctionBody() << "let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;\n"; } shader.MainFunctionBody() << "var value = f32_val_t(0);\n" "for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {\n" - " if (global_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" + " if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {\n" " tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];\n" " }\n" " if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {\n" @@ -123,7 +124,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_key_ && has_present_key_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (n + local_id.y < past_sequence_length) {\n" - << " let pastKeyOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n" + << " let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? 
"present_key" : "past_key") << "[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];\n" << " } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];\n" @@ -152,9 +153,9 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { << " workgroupBarrier();\n" << "}\n"; - shader.MainFunctionBody() << "if (global_id.y < uniforms.M && global_id.x < total_sequence_length) {\n" - << " let headOffset = workgroup_id.z * uniforms.M * uniforms.N;\n" - << " let outputIdx = headOffset + global_id.y * uniforms.N + global_id.x;\n" + shader.MainFunctionBody() << "if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {\n" + << " let headOffset = batch_head_idx * uniforms.M * uniforms.N;\n" + << " let outputIdx = headOffset + m + local_id.y * uniforms.N + n + local_id.x;\n" << " var sum: f32 = " << (components_ == 4 ? "value.x + value.y + value.z + value.w" : (components_ == 2 ? "value.x + value.y" : "value")) << ";\n"; shader.MainFunctionBody() << " output[outputIdx] = output_value_t(sum * uniforms.alpha)"; @@ -199,9 +200,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } const uint32_t vectorized_head_size = (parameters.head_size_ + components - 1) / components; - program.SetDispatchGroupSize((total_sequence_length + tile_size - 1) / tile_size, - (parameters.sequence_length_ + tile_size - 1) / tile_size, - parameters.batch_size_ * parameters.num_heads_) + const uint32_t num_total_seq_length_tile = (total_sequence_length + tile_size - 1) / tile_size; + const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_seq_length_tile * num_total_seq_length_tile) .SetWorkgroupSize(tile_size, tile_size) .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_key, has_present_key, has_attention_bias, seqlen_k != nullptr, components, parameters.is_first_prompt_) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, @@ -214,7 +215,9 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? total_sequence_length : parameters.seqlen_present_kv_cache_)}, {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_ ? 1 : 0)}}) + {static_cast(parameters.is_first_prompt_ ? 1 : 0)}, + {num_total_seq_length_tile}, + {num_seq_length_tile}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); @@ -228,15 +231,15 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var thread_max: array;\n" << "var thread_sum: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? 
"vec2" : "f32")) << ";\n"; - shader.MainFunctionBody() << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let sequence_length = uniforms.sequence_length;\n" + shader.MainFunctionBody() << "let sequence_length = uniforms.sequence_length;\n" + << "let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;\n" << "var total_sequence_length = uniforms.total_sequence_length_comp * " << components_ << ";\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str() << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = (global_idx / " << work_group_size_ << ") * uniforms.total_sequence_length_comp + local_offset;\n" - << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_id.y + 1" : "uniforms.total_sequence_length_comp") << ";\n" + << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" + << "let seq_causal_length = " << (seqlen_k_ ? "past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" @@ -292,7 +295,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) .CacheHint(work_group_size) - .SetDispatchGroupSize(1, sequence_length, batch_size * num_heads) + .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, {static_cast(num_heads)}, @@ -321,19 +324,20 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var tileQ: array;\n" << "var tileK: array;\n"; - shader.MainFunctionBody() << "let head_idx = workgroup_id.z % uniforms.num_heads;\n" - << "let batch_idx = workgroup_id.z / uniforms.num_heads;\n" - << "let m = global_id.y;\n" - << "let n = global_id.x;\n" - << "let offsetA = workgroup_id.z * (uniforms.M * uniforms.K) + m * uniforms.K;\n" + shader.MainFunctionBody() << "let batch_head_idx = u32(workgroup_idx / (uniforms.num_head_size_tile * uniforms.num_seq_length_tile));\n" + << "let head_idx = batch_head_idx % uniforms.num_heads;\n" + << "let batch_idx = batch_head_idx / uniforms.num_heads;\n" + << "let m = (u32(workgroup_idx / uniforms.num_head_size_tile) % uniforms.num_seq_length_tile) * TILE_SIZE + local_id.y;\n" + << "let n = (workgroup_idx % uniforms.num_head_size_tile) * TILE_SIZE + local_id.x;\n" + << "let offsetA = batch_head_idx * (uniforms.M * uniforms.K) + m * uniforms.K;\n" << "let sequence_length = uniforms.M;\n" << "var total_sequence_length = uniforms.K;\n"; std::ostringstream oss; InitVarStub(oss, seqlen_k_); shader.MainFunctionBody() << oss.str(); - shader.MainFunctionBody() << "let vOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n"; + shader.MainFunctionBody() << "let vOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.kv_sequence_length + n;\n"; if (has_present_value_) { - shader.MainFunctionBody() << "let presentValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n"; + shader.MainFunctionBody() << "let presentValueOffset = (batch_head_idx / 
uniforms.n_reps) * uniforms.N * uniforms.present_sequence_length + n;\n"; } shader.MainFunctionBody() << "var value = output_value_t(0);\n" @@ -346,7 +350,7 @@ Status VxAttentionScoreProgram::GenerateShaderCode(ShaderHelper& shader) const { if ((feed_past_value_ && has_present_value_) || (past_present_share_buffer_ && !is_first_prompt_)) { shader.MainFunctionBody() << " if (w + local_id.y < past_sequence_length) {\n" - << " let pastValueOffset = (workgroup_id.z / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n" + << " let pastValueOffset = (batch_head_idx / uniforms.n_reps) * uniforms.N * uniforms.past_sequence_length + n;\n" << " tileK[idx] = " << (past_present_share_buffer_ ? "present_value" : "past_value") << "[pastValueOffset + (w + local_id.y) * uniforms.N];\n" << " } else if (w + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {\n" << " tileK[idx] = v[vOffset + (w + local_id.y - past_sequence_length) * uniforms.N];\n" @@ -414,9 +418,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int program.AddOutput({present_value, ProgramTensorMetadataDependency::TypeAndRank, components}); } - program.SetDispatchGroupSize((parameters.v_head_size_ + tile_n_size - 1) / tile_n_size, - (parameters.sequence_length_ + tile_size - 1) / tile_size, - parameters.batch_size_ * parameters.num_heads_) + const uint32_t num_head_size_tile = (parameters.v_head_size_ + tile_n_size - 1) / tile_n_size; + const uint32_t num_seq_length_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; + program.SetDispatchGroupSize(parameters.batch_size_ * parameters.num_heads_ * num_head_size_tile * num_seq_length_tile) .CacheHint(std::to_string(tile_size), parameters.past_present_share_buffer_, feed_past_value, has_present_value, seqlen_k != nullptr, parameters.is_first_prompt_) .SetWorkgroupSize(tile_size, tile_size) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, @@ -429,7 +433,9 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int {static_cast(parameters.kv_sequence_length_)}, {static_cast(seqlen_k == nullptr ? 
total_sequence_length : parameters.seqlen_present_kv_cache_)}, {static_cast(parameters.n_reps)}, - {static_cast(parameters.is_first_prompt_)}}) + {static_cast(parameters.is_first_prompt_)}, + {num_head_size_tile}, + {num_seq_length_tile}}) .SetOverridableConstants({{static_cast(tile_size)}}); return context.RunProgram(program); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index 6123d2c47add1..7c0cb40cc7f93 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -50,7 +50,9 @@ class AttentionProbsProgram final : public Program { {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"num_total_seq_length_tile", ProgramUniformVariableDataType::Uint32}, + {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); @@ -105,7 +107,9 @@ class VxAttentionScoreProgram final : public Program { {"kv_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"num_head_size_tile", ProgramUniformVariableDataType::Uint32}, + {"num_seq_length_tile", ProgramUniformVariableDataType::Uint32}); WEBGPU_PROGRAM_DEFINE_OVERRIDABLE_CONSTANTS({"TILE_SIZE", ProgramConstantDataType::Uint32}); From 73676fc5eb1c5ddeec0c5be321936d34aeb7452a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 7 Apr 2025 06:37:00 -0700 Subject: [PATCH 14/18] [webgpu][dawn API optimization] reduce number of calls to buffer APIs (#24315) ### Description This PR is one of a series of changes for optimization of Dawn API usage. See https://github.com/microsoft/onnxruntime/pull/24281 Reduce the calls to wgpuBufferAddRef and wgpuBufferRelease (part 1). 
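For readers following this patch, the underlying pattern is that the buffer cache managers now hold raw `WGPUBuffer` handles and release them explicitly, rather than storing `wgpu::Buffer` RAII wrappers whose copy construction and destruction each imply `wgpuBufferAddRef`/`wgpuBufferRelease` calls. A minimal sketch of that ownership model, assuming only the standard `webgpu.h` C API (the `PendingBufferList` class and its method names are illustrative, not part of this change):

```cpp
#include <vector>
#include <webgpu/webgpu.h>

// Illustrative container that owns raw WGPUBuffer handles and releases each one
// exactly once, instead of holding wgpu::Buffer wrappers whose copy/destroy
// operations generate additional AddRef/Release traffic.
class PendingBufferList {
 public:
  // Takes over the caller's reference; no wgpuBufferAddRef is required.
  void Add(WGPUBuffer buffer) { pending_.push_back(buffer); }

  // A single wgpuBufferRelease per owned handle.
  void ReleaseAll() {
    for (WGPUBuffer buffer : pending_) {
      wgpuBufferRelease(buffer);
    }
    pending_.clear();
  }

  ~PendingBufferList() { ReleaseAll(); }

 private:
  std::vector<WGPUBuffer> pending_;  // raw handles; lifetime managed explicitly
};
```

Each buffer's reference count is then touched once when the handle is finally released, which is where the reduction in `wgpuBufferAddRef`/`wgpuBufferRelease` calls comes from.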
--- .../core/providers/webgpu/allocator.cc | 14 +-- .../core/providers/webgpu/buffer_manager.cc | 102 ++++++++++++------ .../core/providers/webgpu/buffer_manager.h | 2 +- .../core/providers/webgpu/webgpu_context.cc | 4 - .../core/providers/webgpu/webgpu_context.h | 3 - 5 files changed, 80 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 91cae111a708a..315d0cd75e946 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -13,15 +13,15 @@ void* GpuBufferAllocator::Alloc(size_t size) { return nullptr; } - WGPUBuffer buffer; - if (!session_initialized_ && context_.SupportsBufferMapExtendedUsages()) { - buffer = context_.BufferManager().CreateUMA(size); - } else { - buffer = context_.BufferManager().Create(size); + stats_.num_allocs++; + +#if !defined(__wasm__) + if (!session_initialized_ && context_.DeviceHasFeature(wgpu::FeatureName::BufferMapExtendedUsages)) { + return context_.BufferManager().CreateUMA(size); } +#endif // !defined(__wasm__) - stats_.num_allocs++; - return buffer; + return context_.BufferManager().Create(size); } void GpuBufferAllocator::Free(void* p) { diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index adb37f54f2e8f..1d8c689cbd909 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -56,14 +56,27 @@ class LazyReleaseCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer)); + pending_buffers_.emplace_back(buffer); } void OnRefresh() override { + Release(); pending_buffers_.clear(); } - std::vector pending_buffers_; + public: + ~LazyReleaseCacheManager() { + Release(); + } + + protected: + void Release() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + } + + std::vector pending_buffers_; }; class SimpleCacheManager : public IBufferCacheManager { @@ -74,7 +87,7 @@ class SimpleCacheManager : public IBufferCacheManager { WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { auto it = buffers_.find(buffer_size); if (it != buffers_.end() && !it->second.empty()) { - auto buffer = it->second.back().MoveToCHandle(); + auto buffer = it->second.back(); it->second.pop_back(); return buffer; } @@ -87,18 +100,31 @@ class SimpleCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer)); + pending_buffers_.emplace_back(buffer); } void OnRefresh() override { for (auto& buffer : pending_buffers_) { - buffers_[static_cast(buffer.GetSize())].emplace_back(std::move(buffer)); + buffers_[static_cast(wgpuBufferGetSize(buffer))].emplace_back(buffer); } pending_buffers_.clear(); } - std::map> buffers_; - std::vector pending_buffers_; + public: + ~SimpleCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buffers_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + } + + protected: + std::map> buffers_; + std::vector pending_buffers_; }; // TODO: maybe use different bucket size for storage and uniform buffers? 
@@ -155,7 +181,7 @@ class BucketCacheManager : public IBufferCacheManager { WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size) override { auto it = buckets_.find(buffer_size); if (it != buckets_.end() && !it->second.empty()) { - auto buffer = it->second.back().MoveToCHandle(); + auto buffer = it->second.back(); it->second.pop_back(); return buffer; } @@ -167,31 +193,44 @@ class BucketCacheManager : public IBufferCacheManager { } void ReleaseBuffer(WGPUBuffer buffer) override { - pending_buffers_.emplace_back(wgpu::Buffer::Acquire(buffer)); + pending_buffers_.emplace_back(buffer); } void OnRefresh() override { // TODO: consider graph capture. currently not supported for (auto& buffer : pending_buffers_) { - auto buffer_size = static_cast(buffer.GetSize()); + auto buffer_size = static_cast(wgpuBufferGetSize(buffer)); auto it = buckets_.find(buffer_size); if (it != buckets_.end() && it->second.size() < buckets_limit_[buffer_size]) { - it->second.emplace_back(std::move(buffer)); + it->second.emplace_back(buffer); + } else { + wgpuBufferRelease(buffer); } } pending_buffers_.clear(); } + ~BucketCacheManager() { + for (auto& buffer : pending_buffers_) { + wgpuBufferRelease(buffer); + } + for (auto& pair : buckets_) { + for (auto& buffer : pair.second) { + wgpuBufferRelease(buffer); + } + } + } + protected: void Initialize() { buckets_keys_.reserve(buckets_limit_.size()); buckets_.reserve(buckets_limit_.size()); for (const auto& pair : buckets_limit_) { buckets_keys_.push_back(pair.first); - buckets_.emplace(pair.first, std::vector()); + buckets_.emplace(pair.first, std::vector()); } std::sort(buckets_keys_.begin(), buckets_keys_.end()); @@ -205,8 +244,8 @@ class BucketCacheManager : public IBufferCacheManager { #endif } std::unordered_map buckets_limit_; - std::unordered_map> buckets_; - std::vector pending_buffers_; + std::unordered_map> buckets_; + std::vector pending_buffers_; std::vector buckets_keys_; }; @@ -255,11 +294,10 @@ BufferManager::BufferManager(WebGpuContext& context, BufferCacheMode storage_buf void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) { // If the buffer is mapped, we can directly write to it. - wgpu::Buffer dst_buffer = dst; - auto mapped_data = dst_buffer.GetMappedRange(); + void* mapped_data = wgpuBufferGetMappedRange(dst, 0, WGPU_WHOLE_MAP_SIZE); // ensure the buffer is mapped if (mapped_data) { memcpy(mapped_data, src, size); - dst_buffer.Unmap(); + wgpuBufferUnmap(dst); return; } @@ -288,9 +326,11 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { EnforceBufferUnmapped(context_, dst); auto buffer_size = NormalizeBufferSize(size); - ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + auto src_size = static_cast(wgpuBufferGetSize(src)); + auto dst_size = static_cast(wgpuBufferGetSize(dst)); + ORT_ENFORCE(buffer_size <= src_size && buffer_size <= dst_size, "Source and destination buffers must have enough space for the copy operation. 
src_size=", - wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + src_size, ", dst_size=", dst_size, ", copy_size=", buffer_size, "."); auto& command_encoder = context_.GetCommandEncoder(); context_.EndComputePass(); @@ -298,7 +338,7 @@ void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) { } WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { - auto& cache = GetCacheManager(static_cast(usage)); + auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); auto buffer = cache.TryAcquireCachedBuffer(buffer_size); @@ -310,7 +350,6 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { wgpu::BufferDescriptor desc{}; desc.size = buffer_size; desc.usage = usage; - // desc.label = std::to_string(xx++).c_str(); buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); @@ -320,14 +359,16 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) { } WGPUBuffer BufferManager::CreateUMA(size_t size, wgpu::BufferUsage usage) { - ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must have storage usage."); - auto& cache = GetCacheManager(static_cast(usage)); + ORT_ENFORCE(usage & wgpu::BufferUsage::Storage, "UMA buffer must be a storage buffer."); + auto& cache = GetCacheManager(usage); auto buffer_size = cache.CalculateBufferSize(size); + // Ensure the buffer is mapped for writing at creation. + usage |= wgpu::BufferUsage::MapWrite; + wgpu::BufferDescriptor desc{}; desc.size = buffer_size; - // Ensure the buffer is mapped for writing at creation. - desc.usage = usage | wgpu::BufferUsage::MapWrite; + desc.usage = usage; desc.mappedAtCreation = true; auto buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); @@ -373,12 +414,12 @@ void BufferManager::RefreshPendingBuffers() { default_cache_->OnRefresh(); } -IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const { - if (usage & WGPUBufferUsage_Storage) { +IBufferCacheManager& BufferManager::GetCacheManager(wgpu::BufferUsage usage) const { + if (usage & wgpu::BufferUsage::Storage) { return *storage_cache_; - } else if (usage & WGPUBufferUsage_Uniform) { + } else if (usage & wgpu::BufferUsage::Uniform) { return *uniform_cache_; - } else if (usage & WGPUBufferUsage_QueryResolve) { + } else if (usage & wgpu::BufferUsage::QueryResolve) { return *query_resolve_cache_; } else { return *default_cache_; @@ -386,7 +427,8 @@ IBufferCacheManager& BufferManager::GetCacheManager(WGPUBufferUsage usage) const } IBufferCacheManager& BufferManager::GetCacheManager(WGPUBuffer buffer) const { - return GetCacheManager(wgpuBufferGetUsage(buffer)); + auto usage = static_cast(wgpuBufferGetUsage(buffer)); + return GetCacheManager(usage); } std::unique_ptr BufferManagerFactory::Create(WebGpuContext& context, BufferCacheMode storage_buffer_cache_mode, BufferCacheMode uniform_buffer_cache_mode, BufferCacheMode query_resolve_buffer_cache_mode) { diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index 6a8ebdd60a1ec..b9028ad5de858 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -70,7 +70,7 @@ class BufferManager { void RefreshPendingBuffers(); private: - IBufferCacheManager& GetCacheManager(WGPUBufferUsage usage) const; + 
IBufferCacheManager& GetCacheManager(wgpu::BufferUsage usage) const; IBufferCacheManager& GetCacheManager(WGPUBuffer buffer) const; WebGpuContext& context_; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index d9fe967f3047c..2987d3905fe54 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -149,10 +149,6 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi device_features_.insert(supported_features.features[i]); } -#if !defined(__wasm__) - supports_buffer_map_extended_usages_ = device_.HasFeature(wgpu::FeatureName::BufferMapExtendedUsages); -#endif - // create buffer manager buffer_mgr_ = BufferManagerFactory::Create(*this, buffer_cache_config.storage.mode, diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 0a54b13e31bf7..8ebb122103177 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -145,8 +145,6 @@ class WebGpuContext final { Status Run(ComputeContext& context, const ProgramBase& program); void OnRunEnd(); - bool SupportsBufferMapExtendedUsages() const { return supports_buffer_map_extended_usages_; } - private: enum class TimestampQueryType { None = 0, @@ -238,7 +236,6 @@ class WebGpuContext final { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) std::unique_ptr pix_frame_generator_ = nullptr; #endif // ENABLE_PIX_FOR_WEBGPU_EP - bool supports_buffer_map_extended_usages_ = false; }; } // namespace webgpu From 350d140074d9c964b42d5718d47132c9889af885 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 7 Apr 2025 10:22:37 -0700 Subject: [PATCH 15/18] Implement load cancellation ability (#24257) ### Description SessionOptions now have a new property - load_cancelation_flag. This flag if set to true causes the model to abort load and initialization for huge models. ### Motivation and Context Some users request an ability to abandon model loading and initialization if that exceeds certain time limits. 
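From the caller's perspective, the flag is set on the same `OrtSessionOptions`/`Ort::SessionOptions` instance that the pending session creation is using, typically from another thread. A minimal usage sketch against the C++ API added in this patch (the model path, 30-second timeout, and watchdog-thread arrangement are illustrative assumptions, not part of the change):

```cpp
#include <chrono>
#include <thread>

#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "load-cancel-demo"};
  Ort::SessionOptions options;

  // Watchdog: request cancellation if the load is still running after 30 seconds.
  std::thread watchdog([&options]() {
    std::this_thread::sleep_for(std::chrono::seconds(30));
    options.SetLoadCancellationFlag(true);  // best effort; the load may already have finished
  });

  try {
    // Load and initialization observe the shared flag and abort if it is set.
    Ort::Session session{env, ORT_TSTR("large_model.onnx"), options};
    // ... run inference with the session ...
  } catch (const Ort::Exception& ex) {
    if (ex.GetOrtErrorCode() == ORT_MODEL_LOAD_CANCELED) {
      // Session creation was abandoned at the user's request; no valid session exists.
    }
  }

  watchdog.join();
  return 0;
}
```

Because the flag lives in a `std::shared_ptr<std::atomic<bool>>` inside `SessionOptions`, internal copies of the options observe the same value, which is what makes setting it after session creation has started effective.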
--- .../NativeMethods.shared.cs | 18 ++++++ .../SessionOptions.shared.cs | 10 ++++ include/onnxruntime/core/common/common.h | 32 +++++++++- include/onnxruntime/core/common/exceptions.h | 47 ++++++++++++++- include/onnxruntime/core/common/status.h | 7 ++- .../core/session/onnxruntime_c_api.h | 19 ++++++ .../core/session/onnxruntime_cxx_api.h | 2 + .../core/session/onnxruntime_cxx_inline.h | 6 ++ .../core/framework/error_code_helper.h | 23 ++++---- .../core/framework/graph_partitioner.cc | 59 ++++++++++++++++--- .../core/framework/graph_partitioner.h | 15 +++++ onnxruntime/core/framework/session_options.h | 15 +++++ onnxruntime/core/framework/session_state.cc | 9 +++ .../core/framework/session_state_utils.cc | 6 ++ onnxruntime/core/graph/graph.cc | 8 +++ onnxruntime/core/graph/model.cc | 19 +++++- onnxruntime/core/graph/model.h | 16 +++++ .../core/optimizer/graph_transformer_mgr.cc | 3 + .../core/optimizer/graph_transformer_mgr.h | 11 ++++ .../core/session/abi_session_options.cc | 8 +++ onnxruntime/core/session/inference_session.cc | 35 ++++++++--- onnxruntime/core/session/inference_session.h | 4 ++ onnxruntime/core/session/onnxruntime_c_api.cc | 40 +++---------- onnxruntime/core/session/ort_apis.h | 3 + .../python/onnxruntime_pybind_state.cc | 6 ++ .../test/framework/inference_session_test.cc | 25 ++++++++ .../providers/cpu/controlflow/loop_test.cc | 5 +- onnxruntime/test/shared_lib/test_inference.cc | 29 +++++++++ 28 files changed, 415 insertions(+), 65 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index b64a5c3e5a4a2..77c35aac65b92 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -3,6 +3,7 @@ using System; using System.Runtime.InteropServices; +using static Microsoft.ML.OnnxRuntime.NativeMethods; namespace Microsoft.ML.OnnxRuntime { @@ -325,6 +326,16 @@ public struct OrtApi public IntPtr CreateLoraAdapterFromArray; public IntPtr ReleaseLoraAdapter; public IntPtr RunOptionsAddActiveLoraAdapter; + public IntPtr SetEpDynamicOptions; + public IntPtr ReleaseValueInfo; + public IntPtr ReleaseNode; + public IntPtr ReleaseGraph; + public IntPtr ReleaseModel; + public IntPtr GetValueInfoName; + public IntPtr GetValueInfoTypeInfo; + public IntPtr GetModelEditorApi; + public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; + public IntPtr SessionOptionsSetLoadCancellationFlag; } internal static class NativeMethods @@ -404,6 +415,7 @@ static NativeMethods() OrtReleaseSessionOptions = (DOrtReleaseSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSessionOptions, typeof(DOrtReleaseSessionOptions)); OrtCloneSessionOptions = (DOrtCloneSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.CloneSessionOptions, typeof(DOrtCloneSessionOptions)); OrtSetSessionExecutionMode = (DOrtSetSessionExecutionMode)Marshal.GetDelegateForFunctionPointer(api_.SetSessionExecutionMode, typeof(DOrtSetSessionExecutionMode)); + OrtSessionOptionsSetLoadCancellationFlag = (DOrtSessionOptionsSetLoadCancellationFlag)Marshal.GetDelegateForFunctionPointer(api_.SessionOptionsSetLoadCancellationFlag, typeof(DOrtSessionOptionsSetLoadCancellationFlag)); OrtSetOptimizedModelFilePath = (DOrtSetOptimizedModelFilePath)Marshal.GetDelegateForFunctionPointer(api_.SetOptimizedModelFilePath, typeof(DOrtSetOptimizedModelFilePath)); OrtEnableProfiling = (DOrtEnableProfiling)Marshal.GetDelegateForFunctionPointer(api_.EnableProfiling, 
typeof(DOrtEnableProfiling)); OrtDisableProfiling = (DOrtDisableProfiling)Marshal.GetDelegateForFunctionPointer(api_.DisableProfiling, typeof(DOrtDisableProfiling)); @@ -1025,6 +1037,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca ExecutionMode execution_mode); public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsSetLoadCancellationFlag(IntPtr /*(OrtSessionOptions*)*/ options, + bool value); + public static DOrtSessionOptionsSetLoadCancellationFlag OrtSessionOptionsSetLoadCancellationFlag; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath); public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index bd450451a1265..9b0f183f03681 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -802,6 +802,16 @@ public ExecutionMode ExecutionMode } private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL; + /// + /// Sets the load cancellation flag for the session. Default is set to false. + /// Provides an opportunity for the user to cancel model loading. + /// + /// true to request cancellation, false to proceed + public void SetLoadCancellationFlag(bool value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value)); + } + #endregion #region Private Methods diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 0822eba950f50..10f658f52e0d9 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch abort(); \ } while (false) +#define ORT_THROW_FROM_STATUS(status) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException( \ + ORT_WHERE_WITH_STACK, status.ToString()) \ + .what()); \ + abort(); \ + } while (false) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) \ + .what()); \ + abort(); \ + } while (false) + #else #define ORT_TRY try @@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch #define ORT_THROW_EX(ex, ...) \ throw ex(__VA_ARGS__) +#define ORT_THROW_FROM_STATUS(status) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \ + static_cast<::onnxruntime::common::StatusCode>(status.Code())) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) + #endif #define ORT_MAKE_STATUS(category, code, ...) 
\ @@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch auto _status = (expr); \ if ((!_status.IsOK())) { \ ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast(__FUNCTION__), __LINE__); \ - ORT_THROW(_status); \ + ORT_THROW_FROM_STATUS(_status); \ } \ } while (0) diff --git a/include/onnxruntime/core/common/exceptions.h b/include/onnxruntime/core/common/exceptions.h index 494a770b8db98..6d0f6edd6e7c4 100644 --- a/include/onnxruntime/core/common/exceptions.h +++ b/include/onnxruntime/core/common/exceptions.h @@ -11,6 +11,7 @@ #include #include "core/common/common.h" +#include "core/common/status.h" #include "core/common/code_location.h" namespace onnxruntime { @@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception { /** Create a new exception that captures the location it was thrown from. @param location Location in the source code the exception is being thrown from + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + + OnnxRuntimeException(const CodeLocation& location, + const std::string& message, + common::StatusCategory category, + common::StatusCode code) noexcept + : OnnxRuntimeException(location, nullptr, message, category, code) { + } + + /** + Create a new exception that captures the location it was thrown from. + The instance will be created with ONNXRUNTIME category and FAIL code. + @param location Location in the source code the exception is being thrown from @param failed_condition Optional string containing the condition that failed. e.g. "tensor.Size() == input.Size()". May be nullptr. @param msg Message containing additional information about the exception cause. */ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept + : OnnxRuntimeException(location, failed_condition, msg, + common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. 
+ @param category Error category + @param code Error code + */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg, + common::StatusCategory category, + common::StatusCode code) + : location_{location}, category_(category), code_(code) { std::ostringstream ss; ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous @@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception { what_ = ss.str(); } + common::StatusCategory Category() const noexcept { + return category_; + } + + common::StatusCode Code() const noexcept { + return code_; + } + const char* what() const noexcept override { return what_.c_str(); } @@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception { const CodeLocation location_; const std::vector stacktrace_; std::string what_; + common::StatusCategory category_; + common::StatusCode code_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index 8f171daabbb1e..b222e411d7804 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -43,7 +43,8 @@ enum StatusCode { MODEL_LOADED = 8, NOT_IMPLEMENTED = 9, INVALID_GRAPH = 10, - EP_FAIL = 11 + EP_FAIL = 11, + MODEL_LOAD_CANCELED = 12, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -72,6 +73,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "INVALID_GRAPH"; case StatusCode::EP_FAIL: return "EP_FAIL"; + case StatusCode::MODEL_LOAD_CANCELED: + return "MODEL_LOAD_CANCELED"; default: return "GENERAL ERROR"; } @@ -104,6 +107,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); case StatusCode::EP_FAIL: return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::MODEL_LOAD_CANCELED: + return HRESULT_FROM_WIN32(ERROR_CANCELLED); default: return E_FAIL; } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6d4cc8a1f2fa9..3bf0d5e19c525 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -255,6 +255,7 @@ typedef enum OrtErrorCode { ORT_NOT_IMPLEMENTED, ORT_INVALID_GRAPH, ORT_EP_FAIL, + ORT_MODEL_LOAD_CANCELED, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -4898,6 +4899,24 @@ struct OrtApi { _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** \brief sets load cancellation flag to abort session loading process. + * + * \param[in] options instance that was passed to the session at creation time. + * \param[in] cancel setting this to true after model loading process was initiated will + * attempt to cancel the loading process. If cancellation is successful, CreateSession() + * CreateSessionFromArray() or any other session creation API that take session options as an + * argument will return an OrtStatus indicating that session loading was canceled at user request, + * error code ORT_MODEL_LOAD_CANCELED. + * The APIs above would not return any valid Session instance. This is the best case effort and the result + * is not guaranteed. The session may have already been created and initialized + * before the cancellation request was issued. 
+ * + * \snippet{doc} snippets.dox OrtStatus + * + */ + ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool cancel); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 979b478e2fbb4..ce7dc1c45b05e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode + SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag + SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48c5e52e33c53..524e3ecc92936 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -747,6 +747,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionM return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLoadCancellationFlag(bool value) { + ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index 703d183ea5c87..b42c6a9ba3e10 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -17,16 +17,19 @@ Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = c #ifndef ORT_NO_EXCEPTIONS #define API_IMPL_BEGIN try { -#define API_IMPL_END \ - } \ - catch (const onnxruntime::NotImplementedException& ex) { \ - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ - } \ - catch (const std::exception& ex) { \ - return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ - } \ - catch (...) { \ - return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \ +#define API_IMPL_END \ + } \ + catch (const onnxruntime::OnnxRuntimeException& ex) { \ + return OrtApis::CreateStatus(static_cast(ex.Code()), ex.what()); \ + } \ + catch (const onnxruntime::NotImplementedException& ex) { \ + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ + } \ + catch (const std::exception& ex) { \ + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ + } \ + catch (...) 
{ \ + return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \ } #else diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index ff4d300f665b1..50f14104cfd7a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -56,6 +56,7 @@ namespace { // contains some common parameters used by the partitioning helper functions struct PartitionParams { std::reference_wrapper graph; + std::reference_wrapper check_load_cancellation_fn; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) std::reference_wrapper func_mgr; std::reference_wrapper fused_kernel_registry; @@ -143,6 +144,7 @@ struct GetCapabilityForEPParams { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) IResourceAccountant* resource_accountant; std::reference_wrapper graph_optimizer_registry; + std::reference_wrapper check_load_cancellation_fn; }; auto get_capabilities = [](const IExecutionProvider& ep, @@ -188,7 +190,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, + graph_optimizer_registry); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Graph partitioning was canceled by user request"); + } if (capabilities.empty()) { return Status::OK(); @@ -209,6 +216,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l // Perform layout transformation on the specific EP assigned graph bool modified = false; ORT_RETURN_IF_ERROR(params.transform_layout(graph, modified, current_ep, params.debug_graph_fn)); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "GetCapabilities was canceled by user request"); + } // It is possible some new nodes are introduced during transformation. These nodes can be either existing nodes // which are reconstructed to update domain or completely new nodes which are necessary for layout transformation. 
@@ -226,7 +237,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, + graph_optimizer_registry); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "GetCapabilities was canceled by user request"); + } // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -405,6 +421,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, + const CheckLoadCancellationFn& check_load_cancellation_fn, const logging::Logger& logger, IResourceAccountant* resource_accountant, const GraphOptimizerRegistry& graph_optimizer_registry) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. @@ -420,7 +437,10 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry)); + transform_layout_fn, debug_graph_fn, + check_load_cancellation_fn, + logger, resource_accountant, + graph_optimizer_registry)); } } @@ -445,7 +465,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn), resource_accountant, - std::ref(graph_optimizer_registry)}; + std::ref(graph_optimizer_registry), + std::cref(check_load_cancellation_fn)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -532,6 +553,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, } ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs)); + ORT_RETURN_IF(check_load_cancellation_fn(), + "Graph partitioning is canceled due to user request."); if (node_compute_funcs.size() != nodes_to_compile.size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); @@ -633,6 +656,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide Graph& graph, const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, + const CheckLoadCancellationFn& check_load_cancellation_fn, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. 
@@ -650,6 +674,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide *subgraph, graph_optimizer_registry, logger, + check_load_cancellation_fn, not_inlined, inlined_count)); } @@ -673,8 +698,13 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, + graph_optimizer_registry, logger, capabilities)); + if (check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request."); + } + for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -707,6 +737,9 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id))); } } + if (check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request."); + } } return Status::OK(); @@ -846,6 +879,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); const auto& transform_layout_function = partition_params.transform_layout_function; + const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn; do { // process full graph with each EP @@ -861,6 +895,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, + check_load_cancellation_fn, logger, resource_accountant, graph_optimizer_registry)); } @@ -915,7 +950,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::cref(partition_params.debug_graph_fn), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) nullptr, - std::ref(graph_optimizer_registry) + std::ref(graph_optimizer_registry), + partition_params.check_load_cancellation_fn }; // clang-format on @@ -972,6 +1008,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::vector single_node_compute_func; ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}}, single_node_compute_func)); + if (partition_params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph partitioning is canceled due to user request."); + } ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element."); auto& func_mgr = partition_params.func_mgr.get(); @@ -1032,6 +1071,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, return Status::OK(); } + auto check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); }; + auto& graph = model.MainGraph(); InlinedHashSet not_inlined; do { @@ -1041,13 +1082,13 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, graph, *graph_optimizer_registry_, logger, + check_load_cancellation_fn, not_inlined, inlined_count)); if (inlined_count == 0) { break; } 
- ORT_RETURN_IF_ERROR(graph.Resolve()); } while (true); @@ -1082,6 +1123,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); } + CheckLoadCancellationFn check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); }; + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // fused_kernel_registry is preparing the kernels created on the fly for fused sub graph. // It is only visible for current session. @@ -1092,6 +1135,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, PartitionParams partition_params{ std::ref(graph), + std::cref(check_load_cancellation_fn), std::ref(func_mgr), std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), @@ -1105,6 +1149,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, ORT_UNUSED_PARAMETER(debug_graph_fn); PartitionParams partition_params{ std::ref(graph), + std::cref(check_load_cancellation_fn), }; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index b9d4022cb5a14..87edc7a64c6b5 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -33,6 +33,16 @@ class GraphPartitioner { graph_optimizer_registry_(std::move(graph_optimizer_registry)) { } + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, + const ExecutionProviders& providers, + std::unique_ptr graph_optimizer_registry, + CheckLoadCancellationFn check_load_cancellation_fn) + : kernel_registry_mgr_(kernel_registry_mgr), + providers_(providers), + graph_optimizer_registry_(std::move(graph_optimizer_registry)), + check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) { + } + // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, @@ -41,6 +51,10 @@ class GraphPartitioner { Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; + bool IsLoadCancellationFlagSet() const { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + #ifndef ORT_MINIMAL_BUILD /// // Ahead of Time Function inlining. The main purpose of the function is to inline as many @@ -69,6 +83,7 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; std::unique_ptr graph_optimizer_registry_; + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8d4db36106f28..ef323b99b006c 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "core/common/inlined_containers.h" #include "core/framework/config_options.h" @@ -66,6 +67,8 @@ struct FreeDimensionOverride { int64_t dim_value; }; +using CheckLoadCancellationFn = std::function; + /** * Configuration information for a session. 
*/ @@ -184,6 +187,18 @@ struct SessionOptions { // User specified logging func and param OrtLoggingFunction user_logging_function = nullptr; void* user_logging_param = nullptr; + + void SetLoadCancellationFlag(bool value) noexcept { + *load_cancellation_flag = value; + } + + bool IsLoadCancellationFlagSet() const noexcept { + return *load_cancellation_flag; + } + + // Load cancellation flag is necessary to be within shared memory as session_options are + // copied internally and the flag needs to be accessible across all copies. + std::shared_ptr load_cancellation_flag = std::make_shared(false); }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index d174d6cc72ead..6362a3169f3a3 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -422,6 +422,10 @@ Status SessionState::PrepackConstantInitializedTensors( auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { for (auto& node : GetGraphViewer().Nodes()) { + if (sess_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Weight pre-packing was canceled due to user request."); + } auto kernel = GetMutableKernel(node.Index()); int input_idx = 0; for (auto& input_def : node.InputDefs()) { @@ -1541,6 +1545,11 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringname(); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 39ffc6a5b0cee..334ecb3887d14 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1268,6 +1268,10 @@ Graph::Graph(const Model& owning_model, #endif } + if (owning_model_.IsLoadCancellationFlagSet()) { + ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request."); + } + // Remove constant nodes as they're replaced with initializers above. 
const gsl::not_null*> graph_mutable_nodes{graph_proto_->mutable_node()}; graph_mutable_nodes->erase( @@ -1365,6 +1369,10 @@ Graph::Graph(const Model& owning_model, } } + if (owning_model_.IsLoadCancellationFlagSet()) { + ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request."); + } + for (auto& graph_output : graph_proto_->output()) { if (utils::HasName(graph_output) && utils::HasType(graph_output)) { auto& name = graph_output.name(); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 7629e40c1b5fe..436af7115eb1a 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -82,7 +82,7 @@ Model::Model(const std::string& graph_name, const std::vector& model_local_functions, const logging::Logger& logger, const ModelOptions& options) - : model_path_(model_path) { + : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) { model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); model_proto_.mutable_graph()->set_name(graph_name); model_metadata_ = model_metadata; @@ -161,7 +161,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path, Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger, const ModelOptions& options) - : model_path_(model_path) { + : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) { if (!utils::HasGraph(model_proto)) { ORT_THROW("ModelProto does not have a graph."); } @@ -435,6 +435,11 @@ Status Model::Load(const ModelProto& model_proto, ORT_TRY { model = std::make_unique(model_proto, model_path, local_registries, logger, options); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); @@ -474,6 +479,11 @@ Status Model::Load(ModelProto&& model_proto, ORT_TRY { model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); @@ -509,6 +519,11 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { ORT_TRY { status = loader(fd); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, FAIL, ex.what()); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6fd94c60d6b99..70f82bcfb160b 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -11,6 +11,7 @@ #include "core/common/flatbuffers.h" +#include "core/framework/session_options.h" #include "core/graph/graph_viewer.h" #include "core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" @@ -38,6 +39,14 @@ struct ModelOptions { // be returned. 
bool strict_shape_type_inference; + CheckLoadCancellationFn check_load_cancellation_fn; + + ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference, + CheckLoadCancellationFn check_load_cancellation_fn) + : allow_released_opsets_only(allow_released_opsets_only), + strict_shape_type_inference(strict_shape_type_inference), + check_load_cancellation_fn(std::move(check_load_cancellation_fn)) {} + ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference) : allow_released_opsets_only(allow_released_opsets_only), strict_shape_type_inference(strict_shape_type_inference) {} @@ -102,6 +111,11 @@ class Model { #endif // !defined(ORT_MINIMAL_BUILD) + // Check for load cancellation. + bool IsLoadCancellationFlagSet() const noexcept { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + #if !defined(ORT_MINIMAL_BUILD) // Get model's IR version. // Return if not specified. @@ -343,5 +357,7 @@ class Model { // Main graph of the model. std::unique_ptr graph_; + + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index 039283bb2d4e1..83c3f70799987 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -27,6 +27,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor } for (unsigned step = 0; step < steps_; ++step) { + if (IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph transformation canceled due to user request."); + } bool graph_changed = false; for (const auto& transformer : transformers->second) { if (step > 0 && transformer->ShouldOnlyApplyOnce()) diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index ed66302434ab2..eab57f12bfcbb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -24,6 +24,16 @@ class GraphTransformerManager { // Get the maximum number of graph transformation steps common::Status GetSteps(unsigned& steps) const; + // Set the cancellation flag ptr from session_options + void SetLoadCancellationFn(CheckLoadCancellationFn check_load_cancellation_fn) { + check_load_cancellation_fn_ = std::move(check_load_cancellation_fn); + } + + // Get the cancellation flag ptr + bool IsLoadCancellationFlagSet() const noexcept { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + // Register a transformer with a level. 
common::Status Register(std::unique_ptr transformer, TransformerLevel level); @@ -38,5 +48,6 @@ class GraphTransformerManager { InlinedHashMap>> level_to_transformer_map_; InlinedHashMap transformers_info_; + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 2e733f67a888c..e50ee5738c30e 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -340,3 +340,11 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* return nullptr; API_IMPL_END } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool is_cancel) { + API_IMPL_BEGIN + options->value.SetLoadCancellationFlag(is_cancel); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 92eaa68667f0e..0cb361bae563b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -384,6 +384,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options ORT_THROW_IF_ERROR(graph_transformer_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps)); + graph_transformer_mgr_.SetLoadCancellationFn(this->check_load_cancellation_fn_); #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1005,11 +1006,13 @@ common::Status InferenceSession::LoadOnnxModel(const PathString& model_uri) { std::copy(std::begin(interop_domains_), std::end(interop_domains_), std::back_inserter(domain_ptrs)); ORT_RETURN_IF_ERROR(AddCustomOpDomains(domain_ptrs)); #endif + const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; common::Status st = LoadWithLoader(loader, "model_loading_uri"); @@ -1102,7 +1105,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_array"); @@ -1140,7 +1144,8 @@ common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) { // This call will move model_proto to the constructed model instance return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? 
&custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_proto"); @@ -1173,7 +1178,8 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; ModelOptions model_opts(allow_released_opsets_only, - strict_shape_type_inference); + strict_shape_type_inference, + check_load_cancellation_fn_); std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsModelExternalInitializersFileFolderPath, ""); @@ -1212,7 +1218,8 @@ common::Status InferenceSession::Load() { // Pass on ownership of the parsed ModelProto to the Model instance (its job here is done by this stage) return Model::Load(std::move(this->model_proto_), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(allow_released_opsets_only, strict_shape_type_inference)); + ModelOptions(allow_released_opsets_only, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_from_saved_proto"); @@ -1240,7 +1247,8 @@ common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) { std::unique_ptr tmp_model; ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, - ModelOptions(true, strict_shape_type_inference), + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_), *session_logger_, tmp_model)); model_ = std::move(tmp_model); @@ -1284,7 +1292,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool auto graph_optimizer_registry = std::make_unique(&session_options_, execution_providers_.Get(onnxruntime::kCpuExecutionProvider), session_logger_); - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry), + check_load_cancellation_fn_); // Run Ahead Of time function inlining if (const bool disable_aot_function_inlining = @@ -1712,7 +1721,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, providers.Get(onnxruntime::kCpuExecutionProvider), &logger); - GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry), + [&sess_options]() -> bool { return sess_options.IsLoadCancellationFlagSet(); }); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, @@ -1785,6 +1795,11 @@ common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() con #pragma warning(disable : 26117) #endif common::Status InferenceSession::Initialize() { + if (session_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Session initialization canceled due to user request."); + } + Status status = Status::OK(); TimePoint tp; if (session_profiler_.IsEnabled()) { @@ -2010,6 +2025,10 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on 
+    if (session_options_.IsLoadCancellationFlagSet()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+                             "Session initialization canceled due to user request.");
+    }
 
     // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP.
     //
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 5b484103c9ecf..7b5d98c38a0fa 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -781,6 +781,10 @@ class InferenceSession {
   // the session options are released after the individual operators are destroyed.
   SessionOptions session_options_;
 
+  CheckLoadCancellationFn check_load_cancellation_fn_ = [this]() {
+    return session_options_.IsLoadCancellationFlagSet();
+  };
+
   /// Logging manager if provided.
   logging::LoggingManager* logging_manager_;
 
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 0e23d7a791bec..ac67a3ce5c1a2 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -720,22 +720,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O
                     _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
   API_IMPL_BEGIN
   std::unique_ptr sess;
-  OrtStatus* status = nullptr;
   *out = nullptr;
-
-  ORT_TRY {
-    ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
-    ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
-
-    *out = reinterpret_cast(sess.release());
-  }
-  ORT_CATCH(const std::exception& e) {
-    ORT_HANDLE_EXCEPTION([&]() {
-      status = OrtApis::CreateStatus(ORT_FAIL, e.what());
-    });
-  }
-
-  return status;
+  ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
+  ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
+  *out = reinterpret_cast(sess.release());
+  return nullptr;
   API_IMPL_END
 }
 
@@ -743,22 +732,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
                     size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
   API_IMPL_BEGIN
   std::unique_ptr sess;
-  OrtStatus* status = nullptr;
-  *out = nullptr;
-
-  ORT_TRY {
-    ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess));
-    ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
-
-    *out = reinterpret_cast(sess.release());
-  }
-  ORT_CATCH(const std::exception& e) {
-    ORT_HANDLE_EXCEPTION([&]() {
-      status = OrtApis::CreateStatus(ORT_FAIL, e.what());
-    });
-  }
-
-  return status;
+  ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess));
+  ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
+  *out = reinterpret_cast(sess.release());
+  return nullptr;
   API_IMPL_END
 }
 
@@ -2810,6 +2787,7 @@ static constexpr OrtApi ort_api_1_to_22 = {
     &OrtApis::GetModelEditorApi,
     &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue,
+    &OrtApis::SessionOptionsSetLoadCancellationFlag,
 };
 
 // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
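As an aside (not part of the patch), the cancellation flag registered above is intended to be flipped from a second thread while another thread is blocked in session creation. A minimal C++ sketch of that usage, mirroring the `CApiTest.RequestLoadCancellation` test added later in this patch; the model path, function name, and watchdog policy here are illustrative only:

```cpp
// Illustrative sketch only: cancel a long-running session load from another thread.
// "big_model.onnx" and LoadWithWatchdog are placeholders, not names from the patch.
#include <thread>
#include <onnxruntime_cxx_api.h>

void LoadWithWatchdog() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
  Ort::SessionOptions session_options;

  std::thread watchdog([&session_options]() {
    // Decide to abort (e.g. after a timeout) while the main thread is inside the Session constructor.
    session_options.SetLoadCancellationFlag(true);
  });

  try {
    Ort::Session session(env, ORT_TSTR("big_model.onnx"), session_options);
  } catch (const Ort::Exception& ex) {
    // Expected outcome when the flag was raised in time.
    const bool canceled = ex.GetOrtErrorCode() == ORT_MODEL_LOAD_CANCELED;
    (void)canceled;
  }
  watchdog.join();
}
```

The same effect is available through the C API entry `SessionOptionsSetLoadCancellationFlag` added in this commit, and through the Python binding `set_load_cancellation_flag` shown below.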
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 9d8aeb18a782f..0a87036a0dd1d 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -549,4 +549,7 @@ ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator*
                     ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
 
+ORT_API_STATUS_IMPL(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
+                    _In_ bool is_cancel);
+
 }  // namespace OrtApis
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 975502063ac2a..a069cfa0b4713 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1753,6 +1753,12 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
             options->value.execution_mode = execution_mode;
           },
           R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
+      .def(
+          "set_load_cancellation_flag",
+          [](PySessionOptions* options, bool value) -> void {
+            options->value.SetLoadCancellationFlag(value);
+          },
+          R"pbdoc(Request inference session load cancellation)pbdoc")
       .def_property(
           "execution_order",
           [](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; },
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index 95101c8075fc2..dc776f74d8758 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -498,6 +499,30 @@ TEST(InferenceSessionTests, TestModelSerialization) {
   ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK());
 }
 
+TEST(InferenceSessionTests, RequestLoadCancellation) {
+  {
+    // Explicit cancel during load, small model is fine
+    SessionOptions so;
+    so.session_logid = "InferenceSessionTests.TestLoadCancellation";
+
+    const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx");
+    InferenceSession session_object{so, GetEnvironment()};
+    so.SetLoadCancellationFlag(true);
+    ASSERT_FALSE(session_object.Load(model_uri).IsOK());
+  }
+  {
+    // Explicit cancel during initialize, small model is fine
+    const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx");
+    SessionOptions so;
+    so.session_logid = "InferenceSessionTests.TestLoadCancellation";
+    so.SetLoadCancellationFlag(false);
+    InferenceSession session_object{so, GetEnvironment()};
+    ASSERT_STATUS_OK(session_object.Load(model_uri));
+    so.SetLoadCancellationFlag(true);
+    ASSERT_FALSE(session_object.Initialize().IsOK());
+  }
+}
+
 #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS
 static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) {
   if (f_arg.size() != s_arg.size()) {
diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
index 9c0b779870c70..a5fd37361a255 100644
--- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
+++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
@@ -576,11 +576,10 @@ TEST(Loop, InfiniteLoopTermination) {
   test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true",
            {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}, &session_run_options);  // Disable TensorRT on unsupported data type BOOL
 
-  // call get to propagate any exception
-  terminator_result.get();
-  // done with the thread
   terminator_thread.join();
+  // call get to propagate any exception
+  terminator_result.get();
 }
 
 // Add basic test to trigger types override logic in Graph::InferAndVerifySubgraphTypes as well as
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index b517ba7032886..e00606af1c086 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -4669,6 +4670,34 @@ TEST(CApiTest, RunBaseLoraModel) {
   }
 }
 
+TEST(CApiTest, RequestLoadCancellation) {
+  constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx");
+  Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+  Ort::SessionOptions session_options;
+
+  auto terminator = [&session_options]() {
+    session_options.SetLoadCancellationFlag(true);
+    return;
+  };
+
+  std::packaged_task task{terminator};
+  std::future terminator_result = task.get_future();
+  std::thread terminator_thread{std::move(task)};
+  bool terminated = false;
+  try {
+    Ort::Session session(env, model_path, session_options);
+  } catch (const Ort::Exception& ex) {
+    terminated = OrtErrorCode::ORT_MODEL_LOAD_CANCELED == ex.GetOrtErrorCode();
+  }
+  // done with the thread
+  terminator_thread.join();
+
+  // call get to propagate any exception
+  terminator_result.get();
+
+  ASSERT_TRUE(terminated);
+}
+
 struct MockGQA : public OrtCustomOp {
   MockGQA() {
     OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {

From ca1b32df6559854f5d3b6636988e470d1c5fd383 Mon Sep 17 00:00:00 2001
From: xhcao
Date: Tue, 8 Apr 2025 04:13:21 +0800
Subject: [PATCH 16/18] [webgpu] Fix ROUND_PREFER_CEIL issue of Resize operator (#24229)

### Description

### Motivation and Context

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
---
 onnxruntime/core/providers/webgpu/tensor/resize_impl.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc
index f68ace3c1d8a1..75a7f859c965f 100644
--- a/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc
+++ b/onnxruntime/core/providers/webgpu/tensor/resize_impl.cc
@@ -122,7 +122,7 @@ void CalcNearestPixel(std::ostream& os, ResizeNearestMode mode) {
       body = "select(i32(round(x_original)), i32(floor(x_original)), x_original == f32(i32(x_original)) + 0.5)";
       break;
     case ResizeNearestMode::ROUND_PREFER_CEIL:
-      body = "i32(round(x_original))";
+      body = "select(i32(round(x_original)), i32(ceil(x_original)), x_original == f32(i32(x_original)) + 0.5)";
       break;
     case ResizeNearestMode::FLOOR:
       body = "i32(floor(x_original))";

From b803429a50762e39d774751303b870eed72ac358 Mon Sep 17 00:00:00 2001
From: Satya Kumar Jandhyala
Date: Mon, 7 Apr 2025 14:56:56 -0700
Subject: [PATCH 17/18] [Native WebGPU] Exclude WebGPU EP from ConvFp16 3D tests. (#24327)

### Description
Exclude WebGPU from Conv3D tests

### Motivation and Context
Fix failing tests in packaging pipelines.
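A side note on the ROUND_PREFER_CEIL change in PATCH 16 above: the following is an illustrative C++ re-statement of the intended nearest-pixel tie-breaking, not code from the patch (the shipped code emits the equivalent WGSL `select()` expression instead).

```cpp
// Reference-only sketch of the two tie-breaking modes: on an exact .5 tie,
// ROUND_PREFER_FLOOR takes floor() and ROUND_PREFER_CEIL takes ceil();
// away from ties both fall back to plain rounding.
#include <cmath>
#include <cstdint>

int32_t NearestRoundPreferFloor(float x) {
  const bool tie = (x == static_cast<float>(static_cast<int32_t>(x)) + 0.5f);
  return tie ? static_cast<int32_t>(std::floor(x)) : static_cast<int32_t>(std::round(x));
}

int32_t NearestRoundPreferCeil(float x) {
  const bool tie = (x == static_cast<float>(static_cast<int32_t>(x)) + 0.5f);
  return tie ? static_cast<int32_t>(std::ceil(x)) : static_cast<int32_t>(std::round(x));
}
```

Before the fix, ROUND_PREFER_CEIL used a bare `round()`, so a tie such as `x_original == 2.5` could resolve to pixel 2 rather than the expected 3; with the explicit `ceil()` branch it resolves to 3.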
---
 onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
index d1350db8ec12e..1404071928e09 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc
@@ -518,7 +518,7 @@ TEST(ConvFp16Test, Conv3D_1) {
       vector{1, 1, 1},            // kernel_shape
       vector{0, 0, 0, 0, 0, 0},   // pads
       vector{1, 1, 1},            // strides
-      {}                          // excluded EPs
+      {kWebGpuExecutionProvider}  // excluded EPs
   };
 
   vector X = {
@@ -557,7 +557,7 @@ TEST(ConvFp16Test, Conv3D_2) {
       vector{1, 1, 1},            // kernel_shape
       vector{2, 2, 2, 2, 2, 2},   // pads
       vector{2, 2, 2},            // strides
-      {}                          // excluded EPs
+      {kWebGpuExecutionProvider}  // excluded EPs
   };
 
   vector X = {
@@ -601,7 +601,7 @@ TEST(ConvFp16Test, Conv3D_Bias) {
       vector{2, 2, 2},            // kernel_shape
       vector{2, 2, 2, 2, 2, 2},   // pads
       vector{2, 2, 2},            // strides
-      {}                          // excluded EPs
+      {kWebGpuExecutionProvider}  // excluded EPs
   };
 
   vector X = {
@@ -1082,7 +1082,7 @@ TEST(ConvFp16Test, Pointwise_3D) {
      vector{1, 1, 1},            // kernel_shape
      vector{0, 0, 0, 0, 0, 0},   // pads
      vector{1, 1, 1},            // strides
-      {}                          // excluded EPs
+      {kWebGpuExecutionProvider}  // excluded EPs
   };
 
   vector X = {

From 554fb4ad1fcf808304d4758d73d93a8ecc362bf6 Mon Sep 17 00:00:00 2001
From: zz002
Date: Tue, 8 Apr 2025 08:15:40 +0800
Subject: [PATCH 18/18] [VitisAI EP] export InferShapes to VitisAIEP (#23881)

### Description
[VitisAI EP] export InferShapes to VitisAIEP

---------

Co-authored-by: Wang Chunye
Co-authored-by: Zhenze
---
 .../providers/shared_library/provider_interfaces.h     |  3 +++
 .../providers/shared_library/provider_wrappedtypes.h   |  1 +
 onnxruntime/core/providers/vitisai/imp/global_api.cc   |  9 +++++++++
 .../core/providers/vitisai/include/vaip/my_ort.h       |  2 ++
 .../core/providers/vitisai/include/vaip/vaip_ort_api.h | 10 ++++++++--
 onnxruntime/core/session/provider_bridge_ort.cc        |  8 ++++++++
 6 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index 9d5e16caa361d..bc8905c225822 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -611,6 +611,8 @@ struct ProviderHost {
   virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
   virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
+  virtual void InferShapes(const std::string& m, const std::string& save_path) = 0;
+  virtual void InferShapes(ONNX_NAMESPACE::ModelProto& m) = 0;
   virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
   virtual void DeregisterSchema(const std::string& domain, const std::string& op_type, int version) = 0;
   virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
@@ -1010,6 +1012,7 @@ struct ProviderHost {
   virtual const Graph* Graph__ParentGraph(const Graph* p) const = 0;
   virtual Graph* Graph__MutableParentGraph(Graph* p) = 0;
   virtual const std::string& Graph__Name(const Graph* p) const noexcept = 0;
+  virtual void Graph__SetName(Graph* p, const std::string& name) const noexcept = 0;
   virtual const std::filesystem::path& Graph__ModelPath(const Graph* p) const = 0;
   virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0;
   virtual bool Graph__IsSubgraph(const Graph* p) = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index e2af144f455e4..5f0f9ca4c8584 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -1050,6 +1050,7 @@ struct Graph final {
   const Graph* ParentGraph() const { return g_host->Graph__ParentGraph(this); }
   Graph* MutableParentGraph() { return g_host->Graph__MutableParentGraph(this); }
   const std::string& Name() const noexcept { return g_host->Graph__Name(this); }
+  void SetName(const std::string& name) noexcept { return g_host->Graph__SetName(this, name); }
   const std::filesystem::path& ModelPath() const { return g_host->Graph__ModelPath(this); }
   const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); }
   bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); }
diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc
index 6547f00cd47c7..33aa8fa2b31b8 100644
--- a/onnxruntime/core/providers/vitisai/imp/global_api.cc
+++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc
@@ -360,10 +360,19 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
   };
   the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); };
   the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); };
+  the_global_api.graph_set_name = [](Graph& graph, const char* name) -> void { return graph.SetName(std::string(name)); };
   the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from,
                                              const auto& enter, const auto& leave, const auto& stop) {
     graph.ReverseDFSFrom(from, enter, leave, nullptr, stop);
   };
+
+  the_global_api.graph_infer_shapes_from_filepath = [](const std::string& m, const std::string& save_path) -> auto { return Provider_GetHost()->InferShapes(m, save_path); };
+  the_global_api.graph_to_graph_proto = [](const Graph& graph) -> ONNX_NAMESPACE::GraphProto* {
+    return graph.ToGraphProto().release();
+  };
+  the_global_api.graph_proto_delete = [](ONNX_NAMESPACE::GraphProto* p) { delete p; };
+  the_global_api.graph_infer_shapes = [](ONNX_NAMESPACE::ModelProto& m) -> auto { return Provider_GetHost()->InferShapes(m); };
+
   // node
   the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs;
   the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args;
diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
index 85a1262d8489b..6c9c728d8ffad 100644
--- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
+++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h
@@ -20,6 +20,7 @@ struct NodeAttributes;
 namespace ONNX_NAMESPACE {
 struct AttributeProto;
 struct TensorProto;
+struct GraphProto;
 struct ModelProto;
 #ifndef USE_VITISAI
 enum TensorProto_DataType : int {
@@ -71,6 +72,7 @@ enum AttributeProto_AttributeType : int {
 namespace vaip_core {
 class GraphHolder;
 using ONNX_NAMESPACE::AttributeProto;
+using ONNX_NAMESPACE::GraphProto;
 using ONNX_NAMESPACE::ModelProto;
 using ONNX_NAMESPACE::TensorProto;
 using onnxruntime::Graph;
diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
index 0becc41d861f7..d40da70726b43 100644
--- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
+++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
@@ -13,7 +13,7 @@ struct OrtApi;
 namespace vaip_core {
-#define VAIP_ORT_API_MAJOR (14u)
+#define VAIP_ORT_API_MAJOR (16u)
 #define VAIP_ORT_API_MINOR (0u)
 #define VAIP_ORT_API_PATCH (0u)
 struct OrtApiForVaip {
@@ -249,7 +249,13 @@ struct OrtApiForVaip {
       const std::function& leave,
       const std::function& comp,
       const std::function&
-          stop);  // [103]
+          stop);                                             // [103]
+  void (*graph_set_name)(Graph& graph, const char* name);   // [104]
+  void (*graph_infer_shapes_from_filepath)(
+      const std::string& m, const std::string& save_path);  // [105]
+  GraphProto* (*graph_to_graph_proto)(const Graph& graph);  // [106]
+  void (*graph_proto_delete)(GraphProto* p);                 // [107]
+  void (*graph_infer_shapes)(ModelProto& m);                 // [108]
 };
 
 #ifndef USE_VITISAI
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index ba90d55a6b8e8..042598535e987 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -43,6 +43,7 @@
 #include "core/session/onnxruntime_c_api.h"
 #include "core/common/string_helper.h"
 #include
+#include "onnx/shape_inference/implementation.h"
 
 #ifdef ENABLE_TRAINING
 #ifdef ENABLE_TRAINING_TORCH_INTEROP
@@ -771,6 +772,12 @@ struct ProviderHostImpl : ProviderHost {
   int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); }
   ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); }
 
+  void InferShapes(const std::string& m, const std::string& save_path) override {
+    return ONNX_NAMESPACE::shape_inference::InferShapes(m, save_path);
+  }
+  void InferShapes(ONNX_NAMESPACE::ModelProto& m) override {
+    return ONNX_NAMESPACE::shape_inference::InferShapes(m);
+  }
   void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override {
     auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
     const auto& domain_to_version_map = domain_instance.Map();
@@ -1268,6 +1275,7 @@ struct ProviderHostImpl : ProviderHost {
   const Graph* Graph__ParentGraph(const Graph* p) const override { return p->ParentGraph(); }
   Graph* Graph__MutableParentGraph(Graph* p) override { return p->MutableParentGraph(); }
   const std::string& Graph__Name(const Graph* p) const noexcept override { return p->Name(); }
+  void Graph__SetName(Graph* p, const std::string& name) const noexcept override { return p->SetName(name); }
   const std::filesystem::path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); }
   const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); }
   bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); }
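For illustration, a hypothetical plugin-side sketch of how the new API-table entries [104]-[108] exported above might be called. It assumes the plugin already includes its usual VAIP headers and holds the `OrtApiForVaip*` passed in by onnxruntime; the function name and string literals are illustrative, not from the patch.

```cpp
// Hedged sketch only: exercise the entries added in this commit.
// `api` is the OrtApiForVaip* table handed to the VitisAI plugin.
void RenameAndReinferSketch(vaip_core::OrtApiForVaip* api,
                            onnxruntime::Graph& graph,
                            ONNX_NAMESPACE::ModelProto& model) {
  api->graph_set_name(graph, "vitisai_partition_0");  // [104] rename before export
  api->graph_infer_shapes(model);                     // [108] in-memory ONNX shape inference
  ONNX_NAMESPACE::GraphProto* proto = api->graph_to_graph_proto(graph);  // [106]
  // ... inspect or serialize `proto` here ...
  api->graph_proto_delete(proto);                     // [107] free with the matching deleter
}
```

Bumping `VAIP_ORT_API_MAJOR` from 14 to 16 signals that the table layout grew, so a plugin should check the version before dereferencing the new slots.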