diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index f4ee8a7c27cd0..b92c823ae808a 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -63,7 +63,7 @@ jobs: # --- Download Build Artifact to Runner Temp Directory --- - name: Download Build Artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: build-output-x64-Release # Must match the upload name path: ${{ runner.temp }}/Release # Download contents into temp dir structure diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index 7532d363b19eb..7bb7241d5eed9 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -582,7 +582,7 @@ jobs: with: node-version: 20 - name: Download Test Data Artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: test_data path: ${{ runner.temp }}/.test_data/ diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index a7d3f5ec0f5fd..2344f7803ac36 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -65,7 +65,7 @@ jobs: # --- Download Build Artifact to Runner Temp Directory --- - name: Download Build Artifact - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: build-output-x64-Release # Must match the upload name path: ${{ runner.temp }}/Release # Download contents into temp dir structure diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index a6e92326c9c6f..b1ccfbfbf591d 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -63,7 +63,7 @@ jobs: runs-on: macos-15 env: - xcode_version: 16 + xcode_version: 16.4 strategy: matrix: diff --git a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index fcbef760d4626..fcaaeada39e57 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -67,7 +67,7 @@ jobs: node-version: "20.x" - name: Download WebAssembly artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: ${{ inputs.build_config }}_wasm path: ${{ github.workspace }}/artifacts_wasm diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 18ff55506d401..1553a8b54537b 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -158,7 +158,7 @@ jobs: submodules: 'none' - name: Download build artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: build-artifacts path: ${{ runner.temp }}\build diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index dbc138e57a3ec..f5fe930854f09 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -163,7 +163,7 @@ jobs: submodules: 'none' - name: Download build artifacts - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: name: build-artifacts path: ${{ runner.temp }}\build diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 9f8b40be1c145..13a8239195b36 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1791,3 +1791,12 @@ else() include("${CMAKE_CURRENT_SOURCE_DIR}/arm64x.cmake") endif() endif() + +if(onnxruntime_BUILD_UNIT_TESTS) + include("${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_test_pch.cmake") 
+endif() + +# Include precompiled header configuration for providers +if(TARGET onnxruntime_providers) + include("${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_providers_pch.cmake") +endif() diff --git a/cmake/onnxruntime_providers_pch.cmake b/cmake/onnxruntime_providers_pch.cmake new file mode 100644 index 0000000000000..dc05853bb1cf4 --- /dev/null +++ b/cmake/onnxruntime_providers_pch.cmake @@ -0,0 +1,8 @@ +# Precompiled header configuration for onnxruntime_providers + +if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + # Visual Studio PCH + target_precompile_headers(onnxruntime_providers PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/providers_pch.h" + ) +endif() diff --git a/cmake/onnxruntime_test_pch.cmake b/cmake/onnxruntime_test_pch.cmake new file mode 100644 index 0000000000000..7c58c2d787596 --- /dev/null +++ b/cmake/onnxruntime_test_pch.cmake @@ -0,0 +1,20 @@ +# Precompiled header configuration for onnxruntime_test_all + +if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + # Visual Studio PCH + target_precompile_headers(onnxruntime_test_all PRIVATE + "${CMAKE_CURRENT_SOURCE_DIR}/test_pch.h" + ) + endif() + +# Exclude certain files that might conflict with PCH +set(PCH_EXCLUDE_FILES + # Add any problematic source files here + "${TEST_SRC_DIR}/framework/tensor_shape_test.cc" +) + +foreach(file ${PCH_EXCLUDE_FILES}) + set_source_files_properties(${file} PROPERTIES + SKIP_PRECOMPILE_HEADERS ON + ) +endforeach() diff --git a/cmake/providers_pch.h b/cmake/providers_pch.h new file mode 100644 index 0000000000000..6b21f9ad838c9 --- /dev/null +++ b/cmake/providers_pch.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +// Core framework headers (highest compilation time impact) +#include "core/framework/op_kernel.h" +#include "core/framework/op_kernel_info.h" +#include "core/framework/execution_provider.h" +#include "core/framework/op_node_proto_helper.h" +#include "core/framework/data_types.h" +#include "core/framework/tensor.h" + +// Graph-related headers +#include "core/graph/graph_viewer.h" +#include "core/graph/graph.h" +#include "core/graph/onnx_protobuf.h" + +// ONNX schema definitions +#include "onnx/defs/schema.h" + +// Windows-specific headers (if applicable) +#ifdef _WIN32 +#include +#endif diff --git a/cmake/test_pch.h b/cmake/test_pch.h new file mode 100644 index 0000000000000..d538976367791 --- /dev/null +++ b/cmake/test_pch.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +// Test framework headers (highest compilation time impact) +#include "gtest/gtest.h" +#include "gtest/gtest-assertion-result.h" +#include "gtest/gtest-message.h" +#include "gtest/internal/gtest-port.h" + +// Core test utilities (most frequently used in tests) +#include "test/providers/provider_test_utils.h" +#include "test/providers/checkers.h" + +// ONNX and Protocol Buffer headers +#include "core/graph/onnx_protobuf.h" +#include "onnx/defs/schema.h" + +// Data types and framework headers +#include "core/framework/data_types.h" + +// Windows-specific headers (if applicable) +#ifdef _WIN32 +#include +#endif diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 89467f5238fa9..59ca1a1df762e 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -199,23 +199,6 @@ class Environment { using OrtAllocatorUniquePtr = std::unique_ptr>; - // if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with - // OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator. - // we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in - // shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an - // OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently. - // we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_. - // - // TODO: we could split out the BFCArena implementation so it can be plugged into either an IAllocator - // or an OrtAllocator instance to reduce the indirection a little. - // with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the - // IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_. - // - // Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena - // implementation directly. They're free to copy BFCArena as it came from TF originally. Or we could provide a - // cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source. 
- std::unordered_map> arena_ort_allocators_; - #if !defined(ORT_MINIMAL_BUILD) // register EPs that are built into the ORT binary so they can take part in AutoEP selection // added to ep_libraries diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index dbe2614099be1..9d78d52ef6f4b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -134,432 +134,13 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.AddOutput("output", ShaderUsage::UseUniform); - shader.AdditionalImplementation() << "const qkv_head_size: u32 = " << qkv_head_size_ << ";\n" - << "const num_heads: u32 =" << qkv_num_heads_ << ";\n"; - - if (is_fp16_) { - shader.AdditionalImplementation() << "const min_value : q_element_t = q_element_t(-65504.0);\n"; - } else { - shader.AdditionalImplementation() << "const min_value = f32(-3.402823e+38f);\n"; - } - - shader.AdditionalImplementation() << R"HELPER_FN( - // For max performance max_k_step should be the same as sg_size, however we might run out of registers - // for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). - const max_k_step: u32 = 16u; - const vec_factor: u32 = 4u; - const qkv_head_size_vec: u32 = qkv_head_size / vec_factor; - - // Default SHM usage limit is 16KB in Dawn. - // vec4 * qkv_head_size_vec * max_k_step = 8 * (128/4) * 16 = 4KB. 128 is head_size for phi4. - var k_tile : array, max_k_step>; - var v_tile : array, max_k_step>; - - // Private memory per lane. - var q_tile : array; - fn loadq(q_idx_global : u32, head_idx: u32, alpha: q_element_t) - { - // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA - // This is the layout if TransferBSDToBNSH has not been run. - let offset = q_idx_global * (qkv_head_size_vec) * num_heads + qkv_head_size_vec * head_idx; - // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. 
- //let offset = head_idx * uniforms.new_sequence_length * qkv_head_size_vec + q_idx_global * qkv_head_size_vec; - for (var idx:u32 = 0; idx < qkv_head_size_vec; idx++) - { - q_tile[idx] = q[idx+offset] * alpha; - } - } - fn loadk(k_start : u32, head_idx: u32, local_idx: u32, k_step: u32) - { - // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + k_start * qkv_head_size_vec; - for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) - { - let slot = u32(idx/qkv_head_size_vec); - let val = select(q_value_t(0), present_key[offset+idx], k_start + slot < uniforms.total_sequence_length); - k_tile[slot][idx%qkv_head_size_vec] = val; - } - } - fn loadv(v_start : u32, head_idx: u32, local_idx: u32, k_step: u32) - { - // Stored as float16[batch_size,num_heads,present_sequence_length,96] - let offset = head_idx * uniforms.present_sequence_length * qkv_head_size_vec + v_start * qkv_head_size_vec; - for (var idx:u32 = local_idx; idx < qkv_head_size_vec*k_step; idx+=workgroup_size_x) - { - let slot = u32(idx/qkv_head_size_vec); - let val = select(q_value_t(0), present_value[offset+idx], v_start + slot < uniforms.total_sequence_length); - v_tile[slot][idx%qkv_head_size_vec] = val; - } - } -)HELPER_FN"; - - if (is_qualcomm_) { - shader.AdditionalImplementation() << R"HELPER_FN( - const half_qkv_head_size_vec = qkv_head_size_vec / 2u; - - // Move half of o_tile from private memory into workgroup memory to reduce register pressure. - // Note that register spill was observed on Qualcomm if whole o_tile is on private memory. - // vec4 * half_qkv_head_size_vec * workgroup_size_x = 8 * (128/4/2) * 64 = 8KB. - var o_tile_r : array, workgroup_size_x>; - - // Private memory per lane. - var o_tile : array; - fn writeo(o_idx_global: u32, head_idx: u32, local_idx: u32) - { - // Stored as float16[batch_size,sequence_length,3072] - let offset = o_idx_global * num_heads * qkv_head_size_vec + head_idx * qkv_head_size_vec; - for (var idx:u32 = 0; idx < half_qkv_head_size_vec; idx ++) - { - output[offset+idx] = o_tile[idx]; - output[offset+idx+half_qkv_head_size_vec] = o_tile_r[local_idx][idx]; - } - } - )HELPER_FN"; - } else { - shader.AdditionalImplementation() << R"HELPER_FN( - // Private memory per lane. 
- var o_tile : array; - fn writeo(o_idx_global: u32, head_idx: u32) - { - // Stored as float16[batch_size,sequence_length,3072] - let offset = o_idx_global * num_heads * qkv_head_size_vec + head_idx * qkv_head_size_vec; - for (var idx:u32 = 0; idx < qkv_head_size_vec; idx ++) - { - output[offset+idx] = o_tile[idx]; - } - } - )HELPER_FN"; - } - - if (has_attention_bias_) { - shader.AdditionalImplementation() << R"HELPER_FN( - fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 - { - // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] - if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.total_sequence_length) { - return vec4(0); - } - let offset_base = head_idx * uniforms.new_sequence_length * uniforms.total_sequence_length + q_idx_global * uniforms.total_sequence_length; - let offset = offset_base + k_idx_global; - let offset_max = offset_base + uniforms.total_sequence_length; - let c1 = q_element_t(attention_bias[min(offset, offset_max)]); - let c2 = q_element_t(attention_bias[min(offset+1, offset_max)]); - let c3 = q_element_t(attention_bias[min(offset+2, offset_max)]); - let c4 = q_element_t(attention_bias[min(offset+3, offset_max)]); - return vec4(c1,c2,c3,c4); - } - )HELPER_FN"; - } else { - shader.AdditionalImplementation() << R"HELPER_FN( - fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx: u32) -> vec4 - { - return vec4(0); - } - )HELPER_FN"; - } - - // Shader is designed to be dispatched as Dispatch(num_heads, new_sequence_length / workgroup_size_x, 1) - // Each lane/thread is responsible for a single q. - shader.MainFunctionBody() << R"MAIN_FN( - let head_idx = u32(workgroup_idx / uniforms.num_seq_tile); - let capped_sg_id = min(sg_id, max_k_step - 1u); - let capped_sg_size = min(sg_size, max_k_step); - - // Load Q - let q_idx_global = (workgroup_idx % uniforms.num_seq_tile) * workgroup_size_x + local_idx; - let valid_q = q_idx_global < uniforms.new_sequence_length; - if (valid_q) - { - loadq(q_idx_global, head_idx, q_element_t(uniforms.alpha)); - } - - var previous_max : q_element_t = min_value; - var previous_denom : q_element_t = 0; -)MAIN_FN"; - - if (is_unidirectional_) { - // If attention is unidirectional, set the loop bound to enforce causal masking. 
- shader.MainFunctionBody() << R"MAIN_FN( - let max_causal_len_for_workgroup = uniforms.past_sequence_length + - (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; - let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup); -)MAIN_FN"; - } else { - shader.MainFunctionBody() << R"MAIN_FN( - let loop_bound = uniforms.total_sequence_length; -)MAIN_FN"; - } - - shader.MainFunctionBody() << R"MAIN_FN( - for(var k_start = 0u; k_start < loop_bound; k_start+=capped_sg_size) - { - workgroupBarrier(); - loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); - loadv(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); - workgroupBarrier(); - - // Compute QKt - var qk_1:vec4; - var qk_2:vec4; - var qk_3:vec4; - var qk_4:vec4; - if (sg_size > 8) - { - for (var i:u32 = 0u; i < qkv_head_size_vec; i++) - { - var k_local = k_tile[capped_sg_id][i]; - var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); - qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); - qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); - qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); - qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); - qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); - qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); - qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14)); - qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); - } - } - else - { - for (var i:u32 = 0u; i < qkv_head_size_vec; i++) - { - var k_local = k_tile[capped_sg_id][i]; - var q_own = q_tile[i]; - qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); - qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); - qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); - qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); - qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); - qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); - qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); - qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); - } - } - - qk_1 = qk_1 + loadAttentionBias(q_idx_global, k_start, head_idx); - qk_2 = qk_2 + loadAttentionBias(q_idx_global, k_start+4, head_idx); - if (sg_size > 8) - { - qk_3 = qk_3 + loadAttentionBias(q_idx_global, k_start+8, head_idx); - qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx); - } - - let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0); - // Neuter qk values where K is out of bounds. 
- qk_1[0] = select(min_value, qk_1[0], k_start+0 < seq_causal_length); - qk_1[1] = select(min_value, qk_1[1], k_start+1 < seq_causal_length); - qk_1[2] = select(min_value, qk_1[2], k_start+2 < seq_causal_length); - qk_1[3] = select(min_value, qk_1[3], k_start+3 < seq_causal_length); - qk_2[0] = select(min_value, qk_2[0], k_start+4 < seq_causal_length); - qk_2[1] = select(min_value, qk_2[1], k_start+5 < seq_causal_length); - qk_2[2] = select(min_value, qk_2[2], k_start+6 < seq_causal_length); - qk_2[3] = select(min_value, qk_2[3], k_start+7 < seq_causal_length); - if (sg_size > 8) - { - qk_3[0] = select(min_value, qk_3[0], k_start+8 < seq_causal_length); - qk_3[1] = select(min_value, qk_3[1], k_start+9 < seq_causal_length); - qk_3[2] = select(min_value, qk_3[2], k_start+10 < seq_causal_length); - qk_3[3] = select(min_value, qk_3[3], k_start+11 < seq_causal_length); - qk_4[0] = select(min_value, qk_4[0], k_start+12 < seq_causal_length); - qk_4[1] = select(min_value, qk_4[1], k_start+13 < seq_causal_length); - qk_4[2] = select(min_value, qk_4[2], k_start+14 < seq_causal_length); - qk_4[3] = select(min_value, qk_4[3], k_start+15 < seq_causal_length); - } -)MAIN_FN"; - // - // Compute SoftMax as per Flash Attention technique. - // - // Crux of Flash Attention is here, that allows for partial softmax computation, - // direct update of output and merging with previous results. - // https://courses.cs.washington.edu/courses/cse599m/23sp/notes/flashattn.pdf - // Where b is the block size of the tile. Xi is storing QKtranspose for the ith tile. - // mi_local is the max of Xi. Note: _ in this notation means what follows is a - // subscript. max_j=1:b (Xi[j]) is the max of Xi[j] for j=1 to b. - // - // for i = 1, #tiles do - // Xi = Q[k,:] Kt[:, (i-1) b : i b] - // mi_local= max_j=1:b (Xi[j]) - // Mi = max(M_(i-1), mi_local) - // d'_i = d'_(i-1) * e^(M_(i-1)-M_i) + Σ_j=1:b e^(Xi[j]-Mi) - // o'_i = o'_(i-1) * d'_(i-1) * e^(M_(i-1)-M_i) / d'_i + Σ_j=1:b (e^(Xi[j]-Mi) / d'_i) V[j + (i - 1)b,:] - // end - // - // In the code below: - // dleft is the first term of d'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i). - // sum is the second term of the same expression : Σ_j=1:b e^(Xi[j]-Mi) - // o_ratio is the part of the first term of o'_i expression above : d'_(i-1) * e^(M_(i-1)-M_i) / d'_i - // - - // TODO: support smooth softmax and head_sink - shader.MainFunctionBody() << R"MAIN_FN( - var local_max_temp = max(qk_1, qk_2); - if (sg_size > 8) - { - local_max_temp = max(local_max_temp, qk_3); - local_max_temp = max(local_max_temp, qk_4); - } - let local_max = max(max(local_max_temp.x, local_max_temp.y),max(local_max_temp.z, local_max_temp.w)); - let new_max = max(previous_max, local_max); - qk_1 = q_value_t(exp(vec4(qk_1) - f32(new_max))); - qk_2 = q_value_t(exp(vec4(qk_2) - f32(new_max))); - if (sg_size > 8) { - qk_3 = q_value_t(exp(vec4(qk_3) - f32(new_max))); - qk_4 = q_value_t(exp(vec4(qk_4) - f32(new_max))); - } - let sum_vec = qk_1 + qk_2 + qk_3 + qk_4; - let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w; - - // Compute lhs term of update di prime and the compute di prime. 
- let dleft = previous_denom * exp(previous_max-new_max); - var d = dleft + sum; - d = select(d,q_element_t(0.0000001),d==0); - qk_1 = qk_1 / d; - qk_2 = qk_2 / d; - if (sg_size > 8) { - qk_3 = qk_3 / d; - qk_4 = qk_4 / d; - } - previous_max = new_max; - previous_denom = d; - let o_ratio = dleft / d; - -)MAIN_FN"; - - if (is_qualcomm_) { - shader.MainFunctionBody() << R"MAIN_FN( - if (sg_size > 8) { - for (var i:u32 = 0; i < half_qkv_head_size_vec; i++) - { - var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - sum += subgroupShuffle(val, 8) * qk_3[0]; - sum += subgroupShuffle(val, 9) * qk_3[1]; - sum += subgroupShuffle(val, 10) * qk_3[2]; - sum += subgroupShuffle(val, 11) * qk_3[3]; - sum += subgroupShuffle(val, 12) * qk_4[0]; - sum += subgroupShuffle(val, 13) * qk_4[1]; - sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 15) * qk_4[3]; - o_tile[i] = o_tile[i] * o_ratio + sum; - - val = v_tile[capped_sg_id][half_qkv_head_size_vec + i]; - sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - sum += subgroupShuffle(val, 8) * qk_3[0]; - sum += subgroupShuffle(val, 9) * qk_3[1]; - sum += subgroupShuffle(val, 10) * qk_3[2]; - sum += subgroupShuffle(val, 11) * qk_3[3]; - sum += subgroupShuffle(val, 12) * qk_4[0]; - sum += subgroupShuffle(val, 13) * qk_4[1]; - sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 15) * qk_4[3]; - o_tile_r[local_idx][i] = o_tile_r[local_idx][i] * o_ratio + sum; - } - } - else - { - for (var i:u32 = 0; i < half_qkv_head_size_vec; i++) - { - var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - o_tile[i] = o_tile[i] * o_ratio + sum; - - val = v_tile[capped_sg_id][half_qkv_head_size_vec + i]; - sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - o_tile_r[local_idx][i] = o_tile_r[local_idx][i] * o_ratio + sum; - } - } - } - - if (valid_q) { - writeo(q_idx_global, head_idx, local_idx); - } -)MAIN_FN"; - } else { - shader.MainFunctionBody() << R"MAIN_FN( - if (sg_size > 8) { - for (var i:u32 = 0; i < qkv_head_size_vec; i++) - { - var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * 
qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - sum += subgroupShuffle(val, 8) * qk_3[0]; - sum += subgroupShuffle(val, 9) * qk_3[1]; - sum += subgroupShuffle(val, 10) * qk_3[2]; - sum += subgroupShuffle(val, 11) * qk_3[3]; - sum += subgroupShuffle(val, 12) * qk_4[0]; - sum += subgroupShuffle(val, 13) * qk_4[1]; - sum += subgroupShuffle(val, 14) * qk_4[2]; - sum += subgroupShuffle(val, 15) * qk_4[3]; - o_tile[i] = o_tile[i] * o_ratio + sum; - } - } - else - { - for (var i:u32 = 0; i < qkv_head_size_vec; i++) - { - var val = v_tile[capped_sg_id][i]; - var sum = subgroupShuffle(val, 0) * qk_1[0]; - sum += subgroupShuffle(val, 1) * qk_1[1]; - sum += subgroupShuffle(val, 2) * qk_1[2]; - sum += subgroupShuffle(val, 3) * qk_1[3]; - sum += subgroupShuffle(val, 4) * qk_2[0]; - sum += subgroupShuffle(val, 5) * qk_2[1]; - sum += subgroupShuffle(val, 6) * qk_2[2]; - sum += subgroupShuffle(val, 7) * qk_2[3]; - o_tile[i] = o_tile[i] * o_ratio + sum; - } - } - } - - if (valid_q) { - writeo(q_idx_global, head_idx); - } -)MAIN_FN"; - } - - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "bert/flash_attention.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_attention_bias, has_attention_bias_), + WGSL_TEMPLATE_PARAMETER(is_fp16, is_fp16_), + WGSL_TEMPLATE_PARAMETER(is_qualcomm, is_qualcomm_), + WGSL_TEMPLATE_PARAMETER(is_unidirectional, is_unidirectional_), + WGSL_TEMPLATE_PARAMETER(qkv_head_size, qkv_head_size_), + WGSL_TEMPLATE_PARAMETER(qkv_num_heads, qkv_num_heads_)); } Status FlashAttentionDecodeQKTProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -942,7 +523,6 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, {static_cast(parameters.total_sequence_length_ - parameters.kv_sequence_length_)}, - {static_cast(parameters.is_unidirectional_)}, {static_cast(parameters.n_reps)}, {alpha}, {num_seq_tile}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 9908b33a38372..9839b43ee8a69 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -56,7 +56,6 @@ class FlashAttentionProgram final : public Program { {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"is_unidirectional", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template new file mode 100644 index 0000000000000..a8c28388292f9 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.wgsl.template @@ -0,0 +1,368 @@ + +#param has_attention_bias +#param is_fp16 +#param is_qualcomm +#param is_unidirectional +#param qkv_head_size +#param qkv_num_heads + +const head_size : u32 = qkv_head_size; +const num_heads : u32 = qkv_num_heads; + +#if is_fp16 +const min_value = q_element_t(-65504.0); +#else +const min_value = q_element_t(-3.402823e+38f); +#endif + +// For max performance max_k_step should be the same as sg_size, 
however we might run out of registers +// for qk_1, qk_2 .. qk_(sg_size). So we cap it at max_k_step (16). +const max_k_step : u32 = 16u; +const vec_factor : u32 = 4u; +const head_size_vec : u32 = head_size / vec_factor; + +// Default SHM usage limit is 16KB in Dawn. +// vec4 * head_size_vec * max_k_step = 8 * (128/4) * 16 = 4KB. 128 is head_size for phi4. +var k_tile : array, max_k_step>; +var v_tile : array, max_k_step>; + +// Private memory per lane. +var q_tile : array; +fn loadq(q_idx_global : u32, head_idx : u32, alpha : q_element_t) { + // Stored as float16[batch_size,sequence_length,3072] the inputs as per onnx MHA + // This is the layout if TransferBSDToBNSH has not been run. + let offset = q_idx_global * (head_size_vec)*num_heads + head_size_vec * head_idx; + // Stored as BNSH - which is what webgpu uses after TransferBSDToBNSH has been run. + // let offset = head_idx * uniforms.new_sequence_length * head_size_vec + q_idx_global * head_size_vec; + for (var idx : u32 = 0; idx < head_size_vec; idx++) { + q_tile[idx] = q[idx + offset] * alpha; + } +} + +fn loadk(k_start : u32, head_idx : u32, local_idx : u32, k_step : u32) { + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * head_size_vec + k_start * head_size_vec; + for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { + let slot = u32(idx / head_size_vec); + let val = select(q_value_t(0), present_key[offset + idx], k_start + slot < uniforms.total_sequence_length); + k_tile[slot][idx % head_size_vec] = val; + } +} + +fn loadv(v_start : u32, head_idx : u32, local_idx : u32, k_step : u32) { + // Stored as float16[batch_size,num_heads,present_sequence_length,96] + let offset = head_idx * uniforms.present_sequence_length * head_size_vec + v_start * head_size_vec; + for (var idx : u32 = local_idx; idx < head_size_vec * k_step; idx += workgroup_size_x) { + let slot = u32(idx / head_size_vec); + let val = select(q_value_t(0), present_value[offset + idx], v_start + slot < uniforms.total_sequence_length); + v_tile[slot][idx % head_size_vec] = val; + } +} + +#if is_qualcomm +const half_head_size_vec = head_size_vec / 2u; + +// Move half of o_tile from private memory into workgroup memory to reduce register pressure. +// Note that register spill was observed on Qualcomm if whole o_tile is on private memory. +// vec4 * half_head_size_vec * workgroup_size_x = 8 * (128/4/2) * 64 = 8KB. +var o_tile_r : array, workgroup_size_x>; + +// Private memory per lane. +var o_tile : array; +fn writeo(o_idx_global : u32, head_idx : u32, local_idx : u32) { + // Stored as float16[batch_size,sequence_length,3072] + let offset = o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec; + for (var idx : u32 = 0; idx < half_head_size_vec; idx++) { + output[offset + idx] = o_tile[idx]; + output[offset + idx + half_head_size_vec] = o_tile_r[local_idx][idx]; + } +} +#else +// Private memory per lane. 
+var o_tile : array; +fn writeo(o_idx_global : u32, head_idx : u32) { + // Stored as float16[batch_size,sequence_length,3072] + let offset = o_idx_global * num_heads * head_size_vec + head_idx * head_size_vec; + for (var idx : u32 = 0; idx < head_size_vec; idx++) { + output[offset + idx] = o_tile[idx]; + } +} +#endif + +#if has_attention_bias +fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { + // Stored as float16[batch_size,num_heads,new_seq_length,total_sequence_length] + if (q_idx_global >= uniforms.new_sequence_length || k_idx_global >= uniforms.total_sequence_length) { + return vec4(0); + } + let offset_base = head_idx * uniforms.new_sequence_length * uniforms.total_sequence_length + q_idx_global * uniforms.total_sequence_length; + let offset = offset_base + k_idx_global; + let offset_max = offset_base + uniforms.total_sequence_length; + let c1 = q_element_t(attention_bias[min(offset, offset_max)]); + let c2 = q_element_t(attention_bias[min(offset + 1, offset_max)]); + let c3 = q_element_t(attention_bias[min(offset + 2, offset_max)]); + let c4 = q_element_t(attention_bias[min(offset + 3, offset_max)]); + return vec4(c1, c2, c3, c4); +} +#else +fn loadAttentionBias(q_idx_global : u32, k_idx_global : u32, head_idx : u32) -> vec4 { + return vec4(0); +} +#endif + +$MAIN { + let head_idx = u32(workgroup_idx / uniforms.num_seq_tile); + let capped_sg_id = min(sg_id, max_k_step - 1u); + let capped_sg_size = min(sg_size, max_k_step); + + // Load Q + let q_idx_global = (workgroup_idx % uniforms.num_seq_tile) * workgroup_size_x + local_idx; + let valid_q = q_idx_global < uniforms.new_sequence_length; + if (valid_q) { + loadq(q_idx_global, head_idx, q_element_t(uniforms.alpha)); + } + + var previous_max : q_element_t = min_value; + var previous_denom : q_element_t = 0; + +#if is_unidirectional + // If attention is unidirectional, set the loop bound to enforce causal masking. 
+ let max_causal_len_for_workgroup = uniforms.past_sequence_length + + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; + let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup); + let seq_causal_length = uniforms.past_sequence_length + q_idx_global + 1; +#else + let loop_bound = uniforms.total_sequence_length; + let seq_causal_length = uniforms.total_sequence_length; +#endif + + for (var k_start = 0u; k_start < loop_bound; k_start += capped_sg_size) { + workgroupBarrier(); + loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); + loadv(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); + workgroupBarrier(); + + // Compute QKt + var qk_1 : vec4; + var qk_2 : vec4; + var qk_3 : vec4; + var qk_4 : vec4; + if (sg_size > 8) { + for (var i : u32 = 0u; i < head_size_vec; i++) { +#if is_qualcomm + var k_local = q_value_t(0); + if (sg_id < max_k_step) { + k_local = k_tile[sg_id][i]; + } +#else + var k_local = k_tile[capped_sg_id][i]; +#endif + var q_own = q_tile[i]; + qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); + qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); + qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); + qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); + qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); + qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); + qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); + qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + qk_3[0] += dot(q_own, subgroupShuffle(k_local, 8)); + qk_3[1] += dot(q_own, subgroupShuffle(k_local, 9)); + qk_3[2] += dot(q_own, subgroupShuffle(k_local, 10)); + qk_3[3] += dot(q_own, subgroupShuffle(k_local, 11)); + qk_4[0] += dot(q_own, subgroupShuffle(k_local, 12)); + qk_4[1] += dot(q_own, subgroupShuffle(k_local, 13)); + qk_4[2] += dot(q_own, subgroupShuffle(k_local, 14)); + qk_4[3] += dot(q_own, subgroupShuffle(k_local, 15)); + } + } else { + for (var i : u32 = 0u; i < head_size_vec; i++) { + var k_local = k_tile[capped_sg_id][i]; + var q_own = q_tile[i]; + qk_1[0] += dot(q_own, subgroupShuffle(k_local, 0)); + qk_1[1] += dot(q_own, subgroupShuffle(k_local, 1)); + qk_1[2] += dot(q_own, subgroupShuffle(k_local, 2)); + qk_1[3] += dot(q_own, subgroupShuffle(k_local, 3)); + qk_2[0] += dot(q_own, subgroupShuffle(k_local, 4)); + qk_2[1] += dot(q_own, subgroupShuffle(k_local, 5)); + qk_2[2] += dot(q_own, subgroupShuffle(k_local, 6)); + qk_2[3] += dot(q_own, subgroupShuffle(k_local, 7)); + } + } + + qk_1 = qk_1 + loadAttentionBias(q_idx_global, k_start, head_idx); + qk_2 = qk_2 + loadAttentionBias(q_idx_global, k_start + 4, head_idx); + if (sg_size > 8) { + qk_3 = qk_3 + loadAttentionBias(q_idx_global, k_start + 8, head_idx); + qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start + 12, head_idx); + } + + // Neuter qk values where K is out of bounds. 
+ qk_1[0] = select(min_value, qk_1[0], k_start + 0 < seq_causal_length); + qk_1[1] = select(min_value, qk_1[1], k_start + 1 < seq_causal_length); + qk_1[2] = select(min_value, qk_1[2], k_start + 2 < seq_causal_length); + qk_1[3] = select(min_value, qk_1[3], k_start + 3 < seq_causal_length); + qk_2[0] = select(min_value, qk_2[0], k_start + 4 < seq_causal_length); + qk_2[1] = select(min_value, qk_2[1], k_start + 5 < seq_causal_length); + qk_2[2] = select(min_value, qk_2[2], k_start + 6 < seq_causal_length); + qk_2[3] = select(min_value, qk_2[3], k_start + 7 < seq_causal_length); + if (sg_size > 8) { + qk_3[0] = select(min_value, qk_3[0], k_start + 8 < seq_causal_length); + qk_3[1] = select(min_value, qk_3[1], k_start + 9 < seq_causal_length); + qk_3[2] = select(min_value, qk_3[2], k_start + 10 < seq_causal_length); + qk_3[3] = select(min_value, qk_3[3], k_start + 11 < seq_causal_length); + qk_4[0] = select(min_value, qk_4[0], k_start + 12 < seq_causal_length); + qk_4[1] = select(min_value, qk_4[1], k_start + 13 < seq_causal_length); + qk_4[2] = select(min_value, qk_4[2], k_start + 14 < seq_causal_length); + qk_4[3] = select(min_value, qk_4[3], k_start + 15 < seq_causal_length); + } + + var local_max_temp = max(qk_1, qk_2); + if (sg_size > 8) { + local_max_temp = max(local_max_temp, qk_3); + local_max_temp = max(local_max_temp, qk_4); + } + let local_max = max(max(local_max_temp.x, local_max_temp.y), max(local_max_temp.z, local_max_temp.w)); + let new_max = max(previous_max, local_max); + qk_1 = q_value_t(exp(vec4(qk_1) - f32(new_max))); + qk_2 = q_value_t(exp(vec4(qk_2) - f32(new_max))); + if (sg_size > 8) { + qk_3 = q_value_t(exp(vec4(qk_3) - f32(new_max))); + qk_4 = q_value_t(exp(vec4(qk_4) - f32(new_max))); + } + let sum_vec = qk_1 + qk_2 + qk_3 + qk_4; + let sum = sum_vec.x + sum_vec.y + sum_vec.z + sum_vec.w; + + // Compute lhs term of update di prime and the compute di prime. 
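Editorial aside (not part of this change): the lines that follow implement the standard flash-attention streaming-softmax update, i.e. rescale the running output by dleft/d and fold in the newly weighted V rows. A minimal scalar C++ reference of the same per-row recurrence, reusing the shader's variable names; fp32 only, for illustration.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct RowState {
  float previous_max = -3.402823e+38f;   // min_value in the shader
  float previous_denom = 0.0f;
  std::vector<float> o_tile;             // running, already-normalized output row
};

// qk: scores against the current K tile; v: the matching V rows (head_size columns each).
void UpdateRow(RowState& s, const std::vector<float>& qk,
               const std::vector<std::vector<float>>& v, std::size_t head_size) {
  if (s.o_tile.empty()) s.o_tile.assign(head_size, 0.0f);

  const float local_max = *std::max_element(qk.begin(), qk.end());
  const float new_max = std::max(s.previous_max, local_max);

  std::vector<float> p(qk.size());
  float sum = 0.0f;
  for (std::size_t j = 0; j < qk.size(); ++j) {
    p[j] = std::exp(qk[j] - new_max);
    sum += p[j];
  }

  const float dleft = s.previous_denom * std::exp(s.previous_max - new_max);
  float d = dleft + sum;
  if (d == 0.0f) d = 1e-7f;              // shader: select(d, 0.0000001, d == 0)
  const float o_ratio = dleft / d;       // rescales the previously accumulated output

  for (std::size_t c = 0; c < head_size; ++c) {
    float acc = 0.0f;
    for (std::size_t j = 0; j < qk.size(); ++j) acc += (p[j] / d) * v[j][c];
    s.o_tile[c] = s.o_tile[c] * o_ratio + acc;
  }
  s.previous_max = new_max;
  s.previous_denom = d;
}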
+ let dleft = previous_denom * exp(previous_max - new_max); + var d = dleft + sum; + d = select(d, q_element_t(0.0000001), d == 0); + qk_1 = qk_1 / d; + qk_2 = qk_2 / d; + if (sg_size > 8) { + qk_3 = qk_3 / d; + qk_4 = qk_4 / d; + } + previous_max = new_max; + previous_denom = d; + let o_ratio = dleft / d; + +#if is_qualcomm + if (sg_size > 8) { + for (var i : u32 = 0; i < half_head_size_vec; i++) { + var val = q_value_t(0); + if (sg_id < max_k_step) { + val = v_tile[sg_id][i]; + } + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + sum += subgroupShuffle(val, 8) * qk_3[0]; + sum += subgroupShuffle(val, 9) * qk_3[1]; + sum += subgroupShuffle(val, 10) * qk_3[2]; + sum += subgroupShuffle(val, 11) * qk_3[3]; + sum += subgroupShuffle(val, 12) * qk_4[0]; + sum += subgroupShuffle(val, 13) * qk_4[1]; + sum += subgroupShuffle(val, 14) * qk_4[2]; + sum += subgroupShuffle(val, 15) * qk_4[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + + if (sg_id < max_k_step) { + val = v_tile[sg_id][half_head_size_vec + i]; + } + sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + sum += subgroupShuffle(val, 8) * qk_3[0]; + sum += subgroupShuffle(val, 9) * qk_3[1]; + sum += subgroupShuffle(val, 10) * qk_3[2]; + sum += subgroupShuffle(val, 11) * qk_3[3]; + sum += subgroupShuffle(val, 12) * qk_4[0]; + sum += subgroupShuffle(val, 13) * qk_4[1]; + sum += subgroupShuffle(val, 14) * qk_4[2]; + sum += subgroupShuffle(val, 15) * qk_4[3]; + o_tile_r[local_idx][i] = o_tile_r[local_idx][i] * o_ratio + sum; + } + } else { + for (var i : u32 = 0; i < half_head_size_vec; i++) { + var val = v_tile[capped_sg_id][i]; + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + + val = v_tile[capped_sg_id][half_head_size_vec + i]; + sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + o_tile_r[local_idx][i] = o_tile_r[local_idx][i] * o_ratio + sum; + } + } + } + + if (valid_q) { + writeo(q_idx_global, head_idx, local_idx); + } +#else + if (sg_size > 8) { + for (var i : u32 = 0; i < head_size_vec; i++) { + var val = v_tile[capped_sg_id][i]; + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += 
subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + sum += subgroupShuffle(val, 8) * qk_3[0]; + sum += subgroupShuffle(val, 9) * qk_3[1]; + sum += subgroupShuffle(val, 10) * qk_3[2]; + sum += subgroupShuffle(val, 11) * qk_3[3]; + sum += subgroupShuffle(val, 12) * qk_4[0]; + sum += subgroupShuffle(val, 13) * qk_4[1]; + sum += subgroupShuffle(val, 14) * qk_4[2]; + sum += subgroupShuffle(val, 15) * qk_4[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + } + } else { + for (var i : u32 = 0; i < head_size_vec; i++) { + var val = v_tile[capped_sg_id][i]; + var sum = subgroupShuffle(val, 0) * qk_1[0]; + sum += subgroupShuffle(val, 1) * qk_1[1]; + sum += subgroupShuffle(val, 2) * qk_1[2]; + sum += subgroupShuffle(val, 3) * qk_1[3]; + sum += subgroupShuffle(val, 4) * qk_2[0]; + sum += subgroupShuffle(val, 5) * qk_2[1]; + sum += subgroupShuffle(val, 6) * qk_2[2]; + sum += subgroupShuffle(val, 7) * qk_2[3]; + o_tile[i] = o_tile[i] * o_ratio + sum; + } + } + } + + if (valid_q) { + writeo(q_idx_global, head_idx); + } +#endif +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template new file mode 100644 index 0000000000000..e4e3730eba808 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul.wgsl.template @@ -0,0 +1,276 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param block_size +#param n_bits +#param has_zero_points + +#include "quantization/dp4a_matmul_common.wgsl.template" + +// This shader implements co-operative matrix multiply. The key idea here is to +// assume there is a primitive for medium size matrix multiply a subgroup can perform, +// using all its lanes and pooling all its registers to keep the values in registry. +// +// The entire workgroup which has N subgroups first loads a tile into shared memory, +// Then each subgroup loads a subtile from shared memory into registers and uses +// the medium size matrix multiply primitive to perform the math. +// The values for tile/subtile size are chosen to conform to the resource limits +// of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - +// therefore there are 16 subgroups and 16 lanes in each subgroup. +// K the hidden dimension is paged in from RAM at k tile size which is 64. +// All this puts the shared memory requirement slightly above 16KB. +// WebGPU limit is 16KB, output is moved to registers instead of SHM to make +// everything fit in shared memory. +// +// Each subgroup performs a 16 x 64 x 16 multiply which is implemented with +// subgroup shuffle as a placeholder for the day the medium matrix mul primitive +// becomes available in WGSL. The registry requirements is ~2KB per subgroup, on +// Alderlake/Tigerlake subgroup has 8KB of registry space pooling the +// 512B of registry from each lane. +// +// The medium size matmul is implemented using dot4I8Packed, so the inputs for +// this shader require A to be int8 quantized with block size 64. B is regular +// matmulnbits input with block size 32. 
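Editorial aside (not part of this change): the header comment above leans on dot4I8Packed, the WGSL builtin that treats each u32 as four packed signed 8-bit lanes and returns their i32 dot product. A plain C++ sketch of that primitive and of the unscaled SDP8AI accumulation pattern built on it, assuming the little-endian lane order of WGSL's pack4xI8/unpack4xI8; illustrative only.

#include <array>
#include <cstdint>

// Dot product of two u32s, each holding four packed signed 8-bit lanes.
int32_t Dot4I8Packed(uint32_t a, uint32_t b) {
  int32_t sum = 0;
  for (int lane = 0; lane < 4; ++lane) {
    const auto av = static_cast<int8_t>((a >> (8 * lane)) & 0xFFu);
    const auto bv = static_cast<int8_t>((b >> (8 * lane)) & 0xFFu);
    sum += static_cast<int32_t>(av) * static_cast<int32_t>(bv);
  }
  return sum;
}

// 32 int8 multiply-accumulates (two vec4<u32> per operand), then one scale:
// the shape of SDP8AI when no zero-point correction is needed.
float SDP8AI(const std::array<uint32_t, 4>& a1, const std::array<uint32_t, 4>& b1,
             const std::array<uint32_t, 4>& a2, const std::array<uint32_t, 4>& b2,
             float scale) {
  int32_t local_sum = 0;
  for (int i = 0; i < 4; ++i) {
    local_sum += Dot4I8Packed(a1[i], b1[i]);
    local_sum += Dot4I8Packed(a2[i], b2[i]);
  }
  return static_cast<float>(local_sum) * scale;
}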
+ +const tile_size = 64; +const subtile_size = 16; +const tile_size_k = 32; +const vec_factor = 4; +const u32_factor = 4; +const tile_size_k_vec = 2; + +// Shared memory +var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 +var scale_A : array; // 64 x 1 +var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 +var scale_B : array; // 64 x 1 + +#if has_zero_points + var zeroes : array; +#endif + +fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) +{ + let a_global = a_global_base + row; + if (a_global >= uniforms.M) + { + return; + } + tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; + if (col == 0) + { + // kidx_v - covers 16 values of k + scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; + } +} + +#if n_bits == 4 + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + let block_idx = kidx_v/(block_size/16); + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; + } + } +#endif + +#if n_bits == 8 + fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) + { + let b_global = b_global_base + row; + if (b_global >= uniforms.N) + { + return; + } + + let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; + tile_B[col][row] = AlignWithZeroPoint(b_value); + if (col == 0) + { + // kidx_v - each kidx_v covers 16 values of k + let block_idx = kidx_v/(block_size/16); + scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; +#if has_zero_points + zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); +#endif + } + } +#endif + +$MAIN { + // During the load phase we use all 256 threads to load 64 rows of A/B. + // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. + let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size; + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + let load_AorB = u32(local_idx/128); + let load_row = u32((local_idx%128)/2); + let load_col = u32(local_idx%2); + + // During the compute phase, we have the 64x64 tile split into + // subtiles of 16x16. We have a grid of 4x4 subtiles. + let subtile_id = u32(local_idx / subtile_size); + let subtile_idx = u32(subtile_id / 4); + let subtile_idy = u32(subtile_id % 4); + let base_A = subtile_idx * 16; + let base_B = subtile_idy * 16; + // For each subtile we have 16 threads assigned. + let a_idx = u32(local_idx % subtile_size); + + var lane_output1: vec4; + var lane_output2: vec4; + var lane_output3: vec4; + var lane_output4: vec4; + // K's vectrorization is 16 items per index. See input_a/input_b. + // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is + // k tile size is 32. In vectorized space that is 32/16 = 2. + for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) + { + // Load Phase: Populate shared memory for the workgroup. + if (load_AorB == 0) + { + loadSHMA(a_global_base, kidx_v, load_row, load_col); + } + else + { + loadSHMB(b_global_base, kidx_v, load_row, load_col); + } + workgroupBarrier(); + + // Compute phase: Perform matmul for this subtile 16 x 32 x 16. 
+ // Step 1: Load from shared memory into registers across entire subgroup. + var own_a0: vec4 = tile_A[0][base_A + a_idx]; + var own_a1: vec4 = tile_A[1][base_A + a_idx]; + var own_scale_a: output_element_t = scale_A[base_A + a_idx]; + +#if has_zero_points && n_bits == 8 + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + var zero = zeroes[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. + lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a, subgroupShuffle(zero, 0)); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a, subgroupShuffle(zero, 1)); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a, subgroupShuffle(zero, 2)); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a, subgroupShuffle(zero, 3)); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a, subgroupShuffle(zero, 4)); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a, subgroupShuffle(zero, 5)); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a, subgroupShuffle(zero, 6)); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a, subgroupShuffle(zero, 7)); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a, subgroupShuffle(zero, 8)); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a, subgroupShuffle(zero, 9)); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a, subgroupShuffle(zero, 10)); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a, subgroupShuffle(zero, 11)); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a, subgroupShuffle(zero, 12)); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a, subgroupShuffle(zero, 13)); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a, subgroupShuffle(zero, 14)); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a, subgroupShuffle(zero, 15)); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. 
+ // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. + lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0], zeroes[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1], zeroes[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2], zeroes[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3], zeroes[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4], zeroes[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5], zeroes[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6], zeroes[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7], zeroes[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8], zeroes[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9], zeroes[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10], zeroes[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11], zeroes[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12], zeroes[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13], zeroes[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]); + } +#else + if (sg_size == 16) + { + var own_b0: vec4 = tile_B[0][base_B + sg_id]; + var own_b1: vec4 = tile_B[1][base_B + sg_id]; + var own_scale_b: output_element_t = scale_B[base_B + sg_id]; + // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. 
+ lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); + lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); + lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); + lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); + + lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); + lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); + lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); + lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); + + lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); + lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); + lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); + lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); + + lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); + lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); + lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); + lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); + } + else + { + // Code for other subgroup sizes, simply doesnt use subgroups at all. + // Relies on reads from single location tile_B[][base_B + col] by all + // being optimized by the hardware. 
+ lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); + lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); + lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); + lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); + + lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); + lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); + lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); + lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); + + lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); + lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); + lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); + lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); + + lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); + lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); + lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); + lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); + } +#endif + workgroupBarrier(); + } + + let a_global = a_global_base + base_A + a_idx; + let b_global = b_global_base + base_B; + let output_idx = ((a_global) * uniforms.N + b_global)/4; + // This creates a shader requirement that uniforms.N % 16 == 0 + if (a_global < uniforms.M && b_global < uniforms.N) + { + output[output_idx] = lane_output1; + output[output_idx+1] = lane_output2; + output[output_idx+2] = lane_output3; + output[output_idx+3] = lane_output4; + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template new file mode 100644 index 0000000000000..38fe0388c5954 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_common.wgsl.template @@ -0,0 +1,80 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
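+
+// Shared helpers for the DP4A MatMulNBits shader templates. Depending on the template
+// parameters this file provides the 4-bit dequantization helper (DequantizedFrom4BitsTo8Bits),
+// the 8-bit zero-point alignment helper (AlignWithZeroPoint) and a matching SDP8AI() variant,
+// and pulls in mm_read_zero() from matmul_nbits_zero_pt.wgsl.template.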
+ +#param n_bits +#param has_zero_points + +#include "quantization/matmul_nbits_zero_pt.wgsl.template" + +#if n_bits == 4 + fn DequantizedFrom4BitsTo8Bits(in: vec2, zero: i32) -> vec4 + { + var out = vec4(0); + var value_lower = vec4(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4(zero); + var value_upper = vec4(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + out[0] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); + out[1] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); + value_lower = vec4(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4(zero); + value_upper = vec4(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + out[2] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); + out[3] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); + return out; + } +#endif + +#if n_bits == 8 + fn AlignWithZeroPoint(in: vec4) -> vec4 + { + var out = vec4(0); + out[0] = pack4xI8(vec4(unpack4xU8(in[0])) - vec4(128)); + out[1] = pack4xI8(vec4(unpack4xU8(in[1])) - vec4(128)); + out[2] = pack4xI8(vec4(unpack4xU8(in[2])) - vec4(128)); + out[3] = pack4xI8(vec4(unpack4xU8(in[3])) - vec4(128)); + return out; + } +#endif + +// For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. +// Then do the scale. Finally, convert to output element type. +#if has_zero_points && n_bits == 8 + // If has_zero_points is true, vec4(unpack4xU8(b_data)) - vec4(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255]. + // To avoid the data overflow when use pack4xI8, we still use |pack4xI8(vec4(unpack4xU8(xxx)) - vec4(128))| to process the b data. In SDP8AI, we use the + // dp4a's result of a and b to subtract dot(vec4(unpack4xI8(a)), vec4(zero - 128)) to get the correct result. + // Scaled dot product of 8 packed unsigned integers. + fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t, zero: i32) -> output_element_t + { + let bias_zero = zero - 128; + var local_sum = dot4I8Packed(a1[0], b1[0]); + var dequantized_a_sum = vec4(unpack4xI8(a1[0])); + local_sum += dot4I8Packed(a1[1], b1[1]); + dequantized_a_sum += vec4(unpack4xI8(a1[1])); + local_sum += dot4I8Packed(a1[2], b1[2]); + dequantized_a_sum += vec4(unpack4xI8(a1[2])); + local_sum += dot4I8Packed(a1[3], b1[3]); + dequantized_a_sum += vec4(unpack4xI8(a1[3])); + local_sum += dot4I8Packed(a2[0], b2[0]); + dequantized_a_sum += vec4(unpack4xI8(a2[0])); + local_sum += dot4I8Packed(a2[1], b2[1]); + dequantized_a_sum += vec4(unpack4xI8(a2[1])); + local_sum += dot4I8Packed(a2[2], b2[2]); + dequantized_a_sum += vec4(unpack4xI8(a2[2])); + local_sum += dot4I8Packed(a2[3], b2[3]); + dequantized_a_sum += vec4(unpack4xI8(a2[3])); + local_sum -= dot(dequantized_a_sum, vec4(bias_zero)); + return output_element_t(f32(local_sum) * f32(scale)); + } +#else + // Scaled dot product of 8 packed unsigned integers. 
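+  // Conceptually (illustrative notation): a1/a2 and b1/b2 each pack 16 signed 8-bit values,
+  // so SDP8AI(a1, b1, a2, b2, scale) = scale * sum_{k=0..31} a[k] * b[k], computed with eight
+  // dot4I8Packed() calls. The i32 accumulator is widened to f32 before scaling so the result
+  // does not overflow when output_element_t is f16.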
+ fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t + { + var local_sum = dot4I8Packed(a1[0], b1[0]); + local_sum += dot4I8Packed(a1[1], b1[1]); + local_sum += dot4I8Packed(a1[2], b1[2]); + local_sum += dot4I8Packed(a1[3], b1[3]); + local_sum += dot4I8Packed(a2[0], b2[0]); + local_sum += dot4I8Packed(a2[1], b2[1]); + local_sum += dot4I8Packed(a2[2], b2[2]); + local_sum += dot4I8Packed(a2[3], b2[3]); + return output_element_t(f32(local_sum) * f32(scale)); + } +#endif diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc index 02d02e824b357..9a8185a778a42 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_nbits.cc @@ -8,131 +8,12 @@ namespace onnxruntime { namespace contrib { namespace webgpu { -namespace { - -std::string CommonFunctions(uint32_t nbits, bool has_zero_points) { - std::stringstream ss; - ss << GenerateZeroPointReadingCode(nbits, has_zero_points, "i32"); - - if (nbits == 4) { - ss << R"ADDNL_FN( - fn DequantizedFrom4BitsTo8Bits(in: vec2, zero: i32) -> vec4 - { - var out = vec4(0); - var value_lower = vec4(unpack4xU8(in[0] & 0x0F0F0F0Fu)) - vec4(zero); - var value_upper = vec4(unpack4xU8((in[0] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); - out[0] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); - out[1] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); - value_lower = vec4(unpack4xU8(in[1] & 0x0F0F0F0Fu)) - vec4(zero); - value_upper = vec4(unpack4xU8((in[1] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); - out[2] = pack4xI8(vec4(value_lower[0], value_upper[0], value_lower[1], value_upper[1])); - out[3] = pack4xI8(vec4(value_lower[2], value_upper[2], value_lower[3], value_upper[3])); - return out; - } - - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(local_sum) * scale; - } - )ADDNL_FN"; - } else { - ss << R"ADDNL_FN( - fn AlignWithZeroPoint(in: vec4) -> vec4 - { - var out = vec4(0); - out[0] = pack4xI8(vec4(unpack4xU8(in[0])) - vec4(128)); - out[1] = pack4xI8(vec4(unpack4xU8(in[1])) - vec4(128)); - out[2] = pack4xI8(vec4(unpack4xU8(in[2])) - vec4(128)); - out[3] = pack4xI8(vec4(unpack4xU8(in[3])) - vec4(128)); - return out; - } - )ADDNL_FN"; - // For 8bits, in case data overflow when converting from int32 (output of dot4I8Packed) to f16, we force it convert to f32. - // Then do the scale. Finally, convert to output element type. - if (has_zero_points) { - // If has_zero_points is true, vec4(unpack4xU8(b_data)) - vec4(zero) may be out of the range [-128, 127] since zero can be any value between [0, 255]. - // To avoid the data overflow when use pack4xI8, we still use |pack4xI8(vec4(unpack4xU8(xxx)) - vec4(128))| to process the b data. In SDP8AI, we use the - // dp4a's result of a and b to subtract dot(vec4(unpack4xI8(a)), vec4(zero - 128)) to get the correct result. - ss << R"ADDNL_FN( - // Scaled dot product of 8 packed unsigned integers. 
- fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t, zero: i32) -> output_element_t - { - let bias_zero = zero - 128; - var local_sum = dot4I8Packed(a1[0], b1[0]); - var dequantized_a_sum = vec4(unpack4xI8(a1[0])); - local_sum += dot4I8Packed(a1[1], b1[1]); - dequantized_a_sum += vec4(unpack4xI8(a1[1])); - local_sum += dot4I8Packed(a1[2], b1[2]); - dequantized_a_sum += vec4(unpack4xI8(a1[2])); - local_sum += dot4I8Packed(a1[3], b1[3]); - dequantized_a_sum += vec4(unpack4xI8(a1[3])); - local_sum += dot4I8Packed(a2[0], b2[0]); - dequantized_a_sum += vec4(unpack4xI8(a2[0])); - local_sum += dot4I8Packed(a2[1], b2[1]); - dequantized_a_sum += vec4(unpack4xI8(a2[1])); - local_sum += dot4I8Packed(a2[2], b2[2]); - dequantized_a_sum += vec4(unpack4xI8(a2[2])); - local_sum += dot4I8Packed(a2[3], b2[3]); - dequantized_a_sum += vec4(unpack4xI8(a2[3])); - local_sum -= dot(dequantized_a_sum, vec4(bias_zero)); - return output_element_t(f32(local_sum) * f32(scale)); - } - )ADDNL_FN"; - } else { - ss << R"ADDNL_FN( - // Scaled dot product of 8 packed unsigned integers. - fn SDP8AI(a1:vec4, b1:vec4, a2:vec4, b2:vec4, scale:output_element_t) -> output_element_t - { - var local_sum = dot4I8Packed(a1[0], b1[0]); - local_sum += dot4I8Packed(a1[1], b1[1]); - local_sum += dot4I8Packed(a1[2], b1[2]); - local_sum += dot4I8Packed(a1[3], b1[3]); - local_sum += dot4I8Packed(a2[0], b2[0]); - local_sum += dot4I8Packed(a2[1], b2[1]); - local_sum += dot4I8Packed(a2[2], b2[2]); - local_sum += dot4I8Packed(a2[3], b2[3]); - return output_element_t(f32(local_sum) * f32(scale)); - } - )ADDNL_FN"; - } - } - return ss.str(); -} - -} // namespace Status DP4AMatMulQuantizeProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); shader.AddOutput("output", ShaderUsage::UseUniform); shader.AddOutput("scales", ShaderUsage::UseUniform); - shader.MainFunctionBody() << R"MAIN_FN( - var local_a : array, 32>; - var max_value:vec4 = vec4(0); - for (var idx:u32=0;idx<32;idx+=1) - { - local_a[idx] = input_a[workgroup_idx*32 + idx]; - max_value = max(max_value, abs(local_a[idx])); - } - var scale = max(max_value.x, max_value.y); - scale = max(scale, max_value.z); - scale = max(scale, max_value.w); - for (var idx:u32=0;idx<32;idx+=1) - { - output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); - } - // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. - scales[workgroup_idx] = scale/127; - )MAIN_FN"; - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_quantize.wgsl.template"); } Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { @@ -144,290 +25,11 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AddInput("zero_points", ShaderUsage::UseUniform); } shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseElementTypeAlias); - - // This shader implements co-operative matrix multiply. The key idea here is to - // assume there is a primitive for medium size matrix multiply a subgroup can perform, - // using all its lanes and pooling all its registers to keep the values in registry. - // - // The entire workgroup which has N subgroups first loads a tile into shared memory, - // Then each subgroup loads a subtile from shared memory into registers and uses - // the medium size matrix multiply primitive to perform the math. 
- // The values for tile/subtile size are chosen to conform to the resource limits - // of an alderlake/tiger lake gpu. A tile is 64x64, workgroup is 256 threads - - // therefore there are 16 subgroups and 16 lanes in each subgroup. - // K the hidden dimension is paged in from RAM at k tile size which is 64. - // All this puts the shared memory requirement slightly above 16KB. - // WebGPU limit is 16KB, output is moved to registers instead of SHM to make - // everything fit in shared memory. - // - // Each subgroup performs a 16 x 64 x 16 multiply which is implemented with - // subgroup shuffle as a placeholder for the day the medium matrix mul primitive - // becomes available in WGSL. The registry requirements is ~2KB per subgroup, on - // Alderlake/Tigerlake subgroup has 8KB of registry space pooling the - // 512B of registry from each lane. - // - // The medium size matmul is implemented using dot4I8Packed, so the inputs for - // this shader require A to be int8 quantized with block size 64. B is regular - // matmulnbits input with block size 32. - - shader.AdditionalImplementation() << CommonFunctions(nbits_, has_zero_points_) - << " const block_size = " << block_size_ << ";"; - - shader.AdditionalImplementation() << R"ADDNL_FN( - const tile_size = 64; - const subtile_size = 16; - const tile_size_k = 32; - const vec_factor = 4; - const u32_factor = 4; - const tile_size_k_vec = 2; - - // Shared memory - var tile_A : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_A : array; // 64 x 1 - var tile_B : array, tile_size>, tile_size_k_vec>; // 64 x 32 - var scale_B : array; // 64 x 1 - )ADDNL_FN"; - if (nbits_ == 8 && has_zero_points_) { - shader.AdditionalImplementation() << " var zeroes : array;"; - } - shader.AdditionalImplementation() << R"ADDNL_FN( - fn loadSHMA(a_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let a_global = a_global_base + row; - if (a_global >= uniforms.M) - { - return; - } - tile_A[col][row] = input_a[a_global*uniforms.K16+kidx_v+col]; - if (col == 0) - { - // kidx_v - covers 16 values of k - scale_A[row] = scales_a[a_global*(uniforms.K/128) + kidx_v/8]; - } - } - )ADDNL_FN"; - if (nbits_ == 4) { - shader.AdditionalImplementation() << R"ADDNL_FN( - fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let b_global = b_global_base + row; - if (b_global >= uniforms.N) - { - return; - } - - let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - let block_idx = kidx_v/(block_size/16); - let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); - tile_B[col][row] = DequantizedFrom4BitsTo8Bits(b_value, zero); - if (col == 0) - { - // kidx_v - each kidx_v covers 16 values of k - scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; - } - } - )ADDNL_FN"; - } else { - ORT_ENFORCE(nbits_ == 8, "Only 4/8 bits are supported for webgpu matmulnbits"); - shader.AdditionalImplementation() << R"ADDNL_FN( - fn loadSHMB(b_global_base:u32, kidx_v:u32, row: u32, col: u32) - { - let b_global = b_global_base + row; - if (b_global >= uniforms.N) - { - return; - } - - let b_value = input_b[b_global*uniforms.K16+kidx_v+col]; - tile_B[col][row] = AlignWithZeroPoint(b_value); - if (col == 0) - { - // kidx_v - each kidx_v covers 16 values of k - let block_idx = kidx_v/(block_size/16); - scale_B[row] = scales_b[b_global*(uniforms.K/block_size) + block_idx]; - )ADDNL_FN"; - if (has_zero_points_) { - shader.AdditionalImplementation() << " zeroes[row] = mm_read_zero(b_global, block_idx, uniforms.N, 
uniforms.zero_blocks_per_col);\n"; - } - shader.AdditionalImplementation() << R"ADDNL_FN( - } - } - )ADDNL_FN"; - } - - shader.MainFunctionBody() << R"MAIN_FN( - // During the load phase we use all 256 threads to load 64 rows of A/B. - // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K. - let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size; - let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; - let load_AorB = u32(local_idx/128); - let load_row = u32((local_idx%128)/2); - let load_col = u32(local_idx%2); - - // During the compute phase, we have the 64x64 tile split into - // subtiles of 16x16. We have a grid of 4x4 subtiles. - let subtile_id = u32(local_idx / subtile_size); - let subtile_idx = u32(subtile_id / 4); - let subtile_idy = u32(subtile_id % 4); - let base_A = subtile_idx * 16; - let base_B = subtile_idy * 16; - // For each subtile we have 16 threads assigned. - let a_idx = u32(local_idx % subtile_size); - - var lane_output1: vec4; - var lane_output2: vec4; - var lane_output3: vec4; - var lane_output4: vec4; - // K's vectrorization is 16 items per index. See input_a/input_b. - // tile_size_k_vec - is the k tile size in vectorized space (1/16). That is - // k tile size is 32. In vectorized space that is 32/16 = 2. - for (var kidx_v:u32 = 0; kidx_v < uniforms.K16; kidx_v+=tile_size_k_vec) - { - // Load Phase: Populate shared memory for the workgroup. - if (load_AorB == 0) - { - loadSHMA(a_global_base, kidx_v, load_row, load_col); - } - else - { - loadSHMB(b_global_base, kidx_v, load_row, load_col); - } - workgroupBarrier(); - - // Compute phase: Perform matmul for this subtile 16 x 32 x 16. - // Step 1: Load from shared memory into registers across entire subgroup. - var own_a0: vec4 = tile_A[0][base_A + a_idx]; - var own_a1: vec4 = tile_A[1][base_A + a_idx]; - var own_scale_a: output_element_t = scale_A[base_A + a_idx]; - )MAIN_FN"; - if (nbits_ == 8 && has_zero_points_) { - shader.MainFunctionBody() << R"MAIN_FN( - if (sg_size == 16) - { - var own_b0: vec4 = tile_B[0][base_B + sg_id]; - var own_b1: vec4 = tile_B[1][base_B + sg_id]; - var own_scale_b: output_element_t = scale_B[base_B + sg_id]; - var zero = zeroes[base_B + sg_id]; - // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. 
- lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a, subgroupShuffle(zero, 0)); - lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a, subgroupShuffle(zero, 1)); - lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a, subgroupShuffle(zero, 2)); - lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a, subgroupShuffle(zero, 3)); - - lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a, subgroupShuffle(zero, 4)); - lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a, subgroupShuffle(zero, 5)); - lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a, subgroupShuffle(zero, 6)); - lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a, subgroupShuffle(zero, 7)); - - lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a, subgroupShuffle(zero, 8)); - lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a, subgroupShuffle(zero, 9)); - lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a, subgroupShuffle(zero, 10)); - lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a, subgroupShuffle(zero, 11)); - - lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a, subgroupShuffle(zero, 12)); - lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a, subgroupShuffle(zero, 13)); - lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a, subgroupShuffle(zero, 14)); - lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a, subgroupShuffle(zero, 15)); - } - else - { - // Code for other subgroup sizes, simply doesnt use subgroups at all. - // Relies on reads from single location tile_B[][base_B + col] by all - // being optimized by the hardware. 
- lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0], zeroes[base_B + 0]); - lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1], zeroes[base_B + 1]); - lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2], zeroes[base_B + 2]); - lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3], zeroes[base_B + 3]); - - lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4], zeroes[base_B + 4]); - lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5], zeroes[base_B + 5]); - lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6], zeroes[base_B + 6]); - lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7], zeroes[base_B + 7]); - - lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8], zeroes[base_B + 8]); - lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9], zeroes[base_B + 9]); - lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10], zeroes[base_B + 10]); - lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11], zeroes[base_B + 11]); - - lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12], zeroes[base_B + 12]); - lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13], zeroes[base_B + 13]); - lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14], zeroes[base_B + 14]); - lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15], zeroes[base_B + 15]); - } - )MAIN_FN"; - } else { - shader.MainFunctionBody() << R"MAIN_FN( - if (sg_size == 16) - { - var own_b0: vec4 = tile_B[0][base_B + sg_id]; - var own_b1: vec4 = tile_B[1][base_B + sg_id]; - var own_scale_b: output_element_t = scale_B[base_B + sg_id]; - // Step 2: Access registers across the subgroup using subgroupShuffle and perform the matmul. 
- lane_output1[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 0), own_a1, subgroupShuffle(own_b1, 0), subgroupShuffle(own_scale_b, 0) * own_scale_a); - lane_output1[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 1), own_a1, subgroupShuffle(own_b1, 1), subgroupShuffle(own_scale_b, 1) * own_scale_a); - lane_output1[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 2), own_a1, subgroupShuffle(own_b1, 2), subgroupShuffle(own_scale_b, 2) * own_scale_a); - lane_output1[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 3), own_a1, subgroupShuffle(own_b1, 3), subgroupShuffle(own_scale_b, 3) * own_scale_a); - - lane_output2[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 4), own_a1, subgroupShuffle(own_b1, 4), subgroupShuffle(own_scale_b, 4) * own_scale_a); - lane_output2[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 5), own_a1, subgroupShuffle(own_b1, 5), subgroupShuffle(own_scale_b, 5) * own_scale_a); - lane_output2[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 6), own_a1, subgroupShuffle(own_b1, 6), subgroupShuffle(own_scale_b, 6) * own_scale_a); - lane_output2[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 7), own_a1, subgroupShuffle(own_b1, 7), subgroupShuffle(own_scale_b, 7) * own_scale_a); - - lane_output3[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 8), own_a1, subgroupShuffle(own_b1, 8), subgroupShuffle(own_scale_b, 8) * own_scale_a); - lane_output3[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 9), own_a1, subgroupShuffle(own_b1, 9), subgroupShuffle(own_scale_b, 9) * own_scale_a); - lane_output3[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 10), own_a1, subgroupShuffle(own_b1, 10), subgroupShuffle(own_scale_b, 10) * own_scale_a); - lane_output3[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 11), own_a1, subgroupShuffle(own_b1, 11), subgroupShuffle(own_scale_b, 11) * own_scale_a); - - lane_output4[0] += SDP8AI(own_a0, subgroupShuffle(own_b0, 12), own_a1, subgroupShuffle(own_b1, 12), subgroupShuffle(own_scale_b, 12) * own_scale_a); - lane_output4[1] += SDP8AI(own_a0, subgroupShuffle(own_b0, 13), own_a1, subgroupShuffle(own_b1, 13), subgroupShuffle(own_scale_b, 13) * own_scale_a); - lane_output4[2] += SDP8AI(own_a0, subgroupShuffle(own_b0, 14), own_a1, subgroupShuffle(own_b1, 14), subgroupShuffle(own_scale_b, 14) * own_scale_a); - lane_output4[3] += SDP8AI(own_a0, subgroupShuffle(own_b0, 15), own_a1, subgroupShuffle(own_b1, 15), subgroupShuffle(own_scale_b, 15) * own_scale_a); - } - else - { - // Code for other subgroup sizes, simply doesnt use subgroups at all. - // Relies on reads from single location tile_B[][base_B + col] by all - // being optimized by the hardware. 
- lane_output1[0] += SDP8AI(own_a0, tile_B[0][base_B + 0], own_a1, tile_B[1][base_B + 0], own_scale_a * scale_B[base_B + 0]); - lane_output1[1] += SDP8AI(own_a0, tile_B[0][base_B + 1], own_a1, tile_B[1][base_B + 1], own_scale_a * scale_B[base_B + 1]); - lane_output1[2] += SDP8AI(own_a0, tile_B[0][base_B + 2], own_a1, tile_B[1][base_B + 2], own_scale_a * scale_B[base_B + 2]); - lane_output1[3] += SDP8AI(own_a0, tile_B[0][base_B + 3], own_a1, tile_B[1][base_B + 3], own_scale_a * scale_B[base_B + 3]); - - lane_output2[0] += SDP8AI(own_a0, tile_B[0][base_B + 4], own_a1, tile_B[1][base_B + 4], own_scale_a * scale_B[base_B + 4]); - lane_output2[1] += SDP8AI(own_a0, tile_B[0][base_B + 5], own_a1, tile_B[1][base_B + 5], own_scale_a * scale_B[base_B + 5]); - lane_output2[2] += SDP8AI(own_a0, tile_B[0][base_B + 6], own_a1, tile_B[1][base_B + 6], own_scale_a * scale_B[base_B + 6]); - lane_output2[3] += SDP8AI(own_a0, tile_B[0][base_B + 7], own_a1, tile_B[1][base_B + 7], own_scale_a * scale_B[base_B + 7]); - - lane_output3[0] += SDP8AI(own_a0, tile_B[0][base_B + 8], own_a1, tile_B[1][base_B + 8], own_scale_a * scale_B[base_B + 8]); - lane_output3[1] += SDP8AI(own_a0, tile_B[0][base_B + 9], own_a1, tile_B[1][base_B + 9], own_scale_a * scale_B[base_B + 9]); - lane_output3[2] += SDP8AI(own_a0, tile_B[0][base_B + 10], own_a1, tile_B[1][base_B + 10], own_scale_a * scale_B[base_B + 10]); - lane_output3[3] += SDP8AI(own_a0, tile_B[0][base_B + 11], own_a1, tile_B[1][base_B + 11], own_scale_a * scale_B[base_B + 11]); - - lane_output4[0] += SDP8AI(own_a0, tile_B[0][base_B + 12], own_a1, tile_B[1][base_B + 12], own_scale_a * scale_B[base_B + 12]); - lane_output4[1] += SDP8AI(own_a0, tile_B[0][base_B + 13], own_a1, tile_B[1][base_B + 13], own_scale_a * scale_B[base_B + 13]); - lane_output4[2] += SDP8AI(own_a0, tile_B[0][base_B + 14], own_a1, tile_B[1][base_B + 14], own_scale_a * scale_B[base_B + 14]); - lane_output4[3] += SDP8AI(own_a0, tile_B[0][base_B + 15], own_a1, tile_B[1][base_B + 15], own_scale_a * scale_B[base_B + 15]); - } - )MAIN_FN"; - } - shader.MainFunctionBody() << R"MAIN_FN( - workgroupBarrier(); - } - - let a_global = a_global_base + base_A + a_idx; - let b_global = b_global_base + base_B; - let output_idx = ((a_global) * uniforms.N + b_global)/4; - // This creates a shader requirement that uniforms.N % 16 == 0 - if (a_global < uniforms.M && b_global < uniforms.N) - { - output[output_idx] = lane_output1; - output[output_idx+1] = lane_output2; - output[output_idx+2] = lane_output3; - output[output_idx+3] = lane_output4; - } - )MAIN_FN"; - - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul.wgsl.template", + WGSL_TEMPLATE_PARAMETER(block_size, block_size_), + WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), + WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), + WGSL_TEMPLATE_PARAMETER(output_type_i32, true)); } // scale_A components = 1, b components = 4, output components = 1 @@ -445,123 +47,13 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co const uint32_t sub_tile_count = WorkgroupSizeX() / tile_size_k_vec_; ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count"); - // This algorithm works to compute dot product of k parallelly, by processing k at each step amongst tile_size_k_vec threads, - // and utilizing the remaining threads in the workgroup to process additional rows of b in parallel (such that the values in shared memory for A can be reused). 
- // For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows - // in order to complete all tile_size b rows in this workgroup and also reusing the loaded in register values of a. - - // 1. Each workgroup handles tile_size_k_vec * k_vectorization_in_b (32) columns and num_concurrent_b_rows of matrix B at a time, - // iterating over the columns to compute a partial dot product. - // 2. Uses vec4 vectorization where each K represents 32 elements of matrix B - - // 1. Workgroup Responsibility: - // - Processes one row of matrix A - // - Handles tile_size rows of matrix B - // - // 2. Computation Process: - // - Reads [tile_size][tile_size_k_vec] block of B data at a time - // - Each thread within workgroup computes dot products of 32 A*B elements since each K represents 32 elements of matrix B - // - Stores intermediate results in shared memory (inter_results) - // - Iterates through columns accumulating results in inter_results - // - Performs final reduction sum in inter_results for output - shader.AdditionalImplementation() << " const tile_size = " << tile_size_ << "u;\n" - << " const tile_size_k_vec = " << tile_size_k_vec_ << "u;\n" - << " const double_tile_size_k_vec = " << 2 * tile_size_k_vec_ << "u;\n" - // sub_tile_count is the number of concurrent b rows processed by the workgroup. - << " const sub_tile_count = " << sub_tile_count << "u;\n"; - - shader.AdditionalImplementation() << CommonFunctions(nbits_, has_zero_points_) - << R"ADDNL_FN( - var inter_results: array, tile_size>; - // Need 2 * tile_size_k_vec to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. - var tile_A : array, double_tile_size_k_vec>; - // double_tile_size_k_vec * 16 / 128 - const scale_a_size_in_tile_a = double_tile_size_k_vec / 8; - var scale_A : array; - fn loadSHMA(a_global: u32, kidx_v: u32, col: u32) - { - let k_offset = kidx_v + col; - if (k_offset >= uniforms.K16) { - return; - } - - tile_A[col] = input_a[a_global*uniforms.K16+k_offset]; - if (col < scale_a_size_in_tile_a) - { - // kidx_v - covers 16 values of k in input_a - scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col]; - } - } - )ADDNL_FN"; - - shader.MainFunctionBody() << R"MAIN_FN( - let a_global = u32(workgroup_idx / uniforms.num_N_tile); - let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; - // Handle each workgroup threads as a block of [sub_tile_count][tile_size_k_vec] - let local_col = local_idx % tile_size_k_vec; - let local_row = local_idx / tile_size_k_vec; - for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec) - { - // Load Phase: Populate shared memory for the workgroup. - if (local_idx < double_tile_size_k_vec) - { - loadSHMA(a_global, kidx_v * 2, local_idx); - } - workgroupBarrier(); - var own_a: vec4 = tile_A[local_col * 2]; - var own_a1: vec4 = tile_A[local_col * 2 + 1]; - var own_scale_a = scale_A[local_col / 4]; - let k_offset = kidx_v + local_col; - // k_offset - covers 32 values of k in input_b - let block_idx = k_offset * 32 / uniforms.block_size; - // calculate intermediate results into inter_results. 
- for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { - let b_global = b_global_base + row_offset + local_row; - if (b_global < uniforms.N && k_offset < uniforms.K32) - { - let b_offset = b_global * uniforms.K32 + k_offset; - let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); - let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + block_idx]; - )MAIN_FN"; - if (nbits_ == 4) { - shader.MainFunctionBody() << R"MAIN_FN( - let b_value = input_b[b_offset]; - let own_b = DequantizedFrom4BitsTo8Bits(b_value.xy, zero); - let own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw, zero); - inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); - )MAIN_FN"; - } else { - shader.MainFunctionBody() << R"MAIN_FN( - let own_b = AlignWithZeroPoint(input_b[b_offset * 2]); - let own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]); - )MAIN_FN"; - if (has_zero_points_) { - shader.MainFunctionBody() << " inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b, zero);\n"; - } else { - shader.MainFunctionBody() << " inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b);\n"; - } - } - shader.MainFunctionBody() << R"MAIN_FN( - } - } - workgroupBarrier(); - } - - if (local_idx < tile_size) { - // Do reduce sum to get final output. - var output_value = output_element_t(0); - for (var b = 0u; b < tile_size_k_vec; b++) { - output_value += inter_results[local_idx][b]; - } - let b_global = b_global_base + local_idx; - let output_idx = a_global * uniforms.N + b_global; - if (b_global < uniforms.N) { - output[output_idx] = output_value; - } - } - )MAIN_FN"; - - return Status::OK(); + return WGSL_TEMPLATE_APPLY(shader, "quantization/dp4a_matmul_small_m.wgsl.template", + WGSL_TEMPLATE_PARAMETER(has_zero_points, has_zero_points_), + WGSL_TEMPLATE_PARAMETER(n_bits, nbits_), + WGSL_TEMPLATE_PARAMETER(output_type_i32, true), + WGSL_TEMPLATE_PARAMETER(sub_tile_count, sub_tile_count), + WGSL_TEMPLATE_PARAMETER(tile_size, tile_size_), + WGSL_TEMPLATE_PARAMETER(tile_size_k_vec, tile_size_k_vec_)); } Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor* scales, diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template new file mode 100644 index 0000000000000..87640aa5adb1b --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_matmul_small_m.wgsl.template @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#param tile_size +#param tile_size_k_vec +#param sub_tile_count +#param n_bits +#param has_zero_points + +#include "quantization/dp4a_matmul_common.wgsl.template" + +// This algorithm works to compute dot product of k in parallel, by processing k at each step amongst tile_size_k_vec threads, +// and utilizing the remaining threads in the workgroup to process additional rows of b in parallel (such that the values in shared memory for A can be reused). +// For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows +// in order to complete all tile_size b rows in this workgroup and also reusing the loaded in register values of a. + +// 1. 
Each workgroup handles tile_size_k_vec * k_vectorization_in_b (32) columns and num_concurrent_b_rows of matrix B at a time, +// iterating over the columns to compute a partial dot product. +// 2. Uses vec4 vectorization where each K represents 32 elements of matrix B + +// 1. Workgroup Responsibility: +// - Processes one row of matrix A +// - Handles tile_size rows of matrix B +// +// 2. Computation Process: +// - Reads [tile_size][tile_size_k_vec] block of B data at a time +// - Each thread within workgroup computes dot products of 32 A*B elements since each K represents 32 elements of matrix B +// - Stores intermediate results in shared memory (inter_results) +// - Iterates through columns accumulating results in inter_results +// - Performs final reduction sum in inter_results for output +// sub_tile_count is the number of concurrent b rows processed by the workgroup. + +const double_tile_size_k_vec = 2 * tile_size_k_vec; + +var inter_results: array, tile_size>; +// Need 2 * tile_size_k_vec to store a tile_A since b is quantized as 4 bits and a is quantized as 8 bits. +var tile_A : array, double_tile_size_k_vec>; +// double_tile_size_k_vec * 16 / 128 +const scale_a_size_in_tile_a = double_tile_size_k_vec / 8; +var scale_A : array; + +fn loadSHMA(a_global: u32, kidx_v: u32, col: u32) +{ + let k_offset = kidx_v + col; + if (k_offset >= uniforms.K16) { + return; + } + + tile_A[col] = input_a[a_global*uniforms.K16+k_offset]; + if (col < scale_a_size_in_tile_a) + { + // kidx_v - covers 16 values of k in input_a + scale_A[col] = scales_a[a_global*(uniforms.K/128) + kidx_v/8 + col]; + } +} + +$MAIN { + let a_global = u32(workgroup_idx / uniforms.num_N_tile); + let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size; + // Handle each workgroup threads as a block of [sub_tile_count][tile_size_k_vec] + let local_col = local_idx % tile_size_k_vec; + let local_row = local_idx / tile_size_k_vec; + for (var kidx_v:u32 = 0; kidx_v < uniforms.K32; kidx_v += tile_size_k_vec) + { + // Load Phase: Populate shared memory for the workgroup. + if (local_idx < double_tile_size_k_vec) + { + loadSHMA(a_global, kidx_v * 2, local_idx); + } + workgroupBarrier(); + var own_a: vec4 = tile_A[local_col * 2]; + var own_a1: vec4 = tile_A[local_col * 2 + 1]; + var own_scale_a = scale_A[local_col / 4]; + let k_offset = kidx_v + local_col; + // k_offset - covers 32 values of k in input_b + let block_idx = k_offset * 32 / uniforms.block_size; + // calculate intermediate results into inter_results. 
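+    // The loop below makes tile_size / sub_tile_count passes; each pass covers sub_tile_count
+    // rows of B using the workgroup's [sub_tile_count][tile_size_k_vec] thread layout, reusing
+    // the A values already held in registers (own_a, own_a1, own_scale_a). For example, with
+    // illustrative values tile_size = 8 and sub_tile_count = 4, the two passes would handle
+    // B rows 0..3 and 4..7 of this workgroup's tile.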
+ for (var row_offset = 0u; row_offset < tile_size; row_offset += sub_tile_count) { + let b_global = b_global_base + row_offset + local_row; + if (b_global < uniforms.N && k_offset < uniforms.K32) + { + let b_offset = b_global * uniforms.K32 + k_offset; + let zero = mm_read_zero(b_global, block_idx, uniforms.N, uniforms.zero_blocks_per_col); + let own_scale_b = scales_b[b_global * uniforms.K / uniforms.block_size + block_idx]; +#if n_bits == 4 + let b_value = input_b[b_offset]; + let own_b = DequantizedFrom4BitsTo8Bits(b_value.xy, zero); + let own_b1 = DequantizedFrom4BitsTo8Bits(b_value.zw, zero); + inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); +#else + let own_b = AlignWithZeroPoint(input_b[b_offset * 2]); + let own_b1 = AlignWithZeroPoint(input_b[b_offset * 2 + 1]); +#if has_zero_points + inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b, zero); +#else + inter_results[row_offset + local_row][local_col] += SDP8AI(own_a, own_b, own_a1, own_b1, own_scale_a * own_scale_b); +#endif + +#endif + } + } + workgroupBarrier(); + } + + if (local_idx < tile_size) { + // Do reduce sum to get final output. + var output_value = output_element_t(0); + for (var b = 0u; b < tile_size_k_vec; b++) { + output_value += inter_results[local_idx][b]; + } + let b_global = b_global_base + local_idx; + let output_idx = a_global * uniforms.N + b_global; + if (b_global < uniforms.N) { + output[output_idx] = output_value; + } + } +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/dp4a_quantize.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_quantize.wgsl.template new file mode 100644 index 0000000000000..3f764f8a602b4 --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/dp4a_quantize.wgsl.template @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +$MAIN { + var local_a : array, 32>; + var max_value:vec4 = vec4(0); + for (var idx:u32=0;idx<32;idx+=1) + { + local_a[idx] = input_a[workgroup_idx*32 + idx]; + max_value = max(max_value, abs(local_a[idx])); + } + var scale = max(max_value.x, max_value.y); + scale = max(scale, max_value.z); + scale = max(scale, max_value.w); + for (var idx:u32=0;idx<32;idx+=1) + { + output[workgroup_idx*32+idx] = pack4x8snorm(vec4(local_a[idx]/scale)); + } + // 127 is the max value of signed int8 [-127,127] used by pack4x8snorm for 1.0f. + scales[workgroup_idx] = scale/127; +} // MAIN diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template index e030f00c084e9..462f9a340c1b8 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_wide_tile.wgsl.template @@ -1,3 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #param has_zero_points #param nbits diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template new file mode 100644 index 0000000000000..0da5bd09609af --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits_zero_pt.wgsl.template @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +#param n_bits +#param has_zero_points +#param output_type_i32 + +#if output_type_i32 + alias output_type = i32; +#else + alias output_type = output_element_t; +#endif + +#if n_bits == 4 + const default_zero_point = 8; + const bit_mask = 0xFu; +#elif n_bits == 8 + const default_zero_point = 128; + const bit_mask = 0xFFu; +#endif + +#if has_zero_points + const elements_in_uint32:u32 = 32u / n_bits; + fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_type + { + if (row < r_dim && col < c_dim) { + let offset = row * c_dim + col; + + // u32 holds elements_in_uint32 packed nbits. + let array_index = offset / elements_in_uint32; + let component_index = offset % elements_in_uint32; + let packed_value = zero_points[array_index]; + + // Extract the nbits component + let shift_amount = component_index * n_bits; + + let masked_value = (packed_value >> shift_amount) & bit_mask; + return output_type(masked_value); + } + return output_type(0); + } +#else + fn mm_read_zero(row : u32, col : u32, r_dim: u32, c_dim: u32) -> output_type { + return output_type(default_zero_point); + } +#endif diff --git a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc index 8caa67f266266..4efaec325292a 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/weight_bias_quantization.cc @@ -13,6 +13,39 @@ namespace onnxruntime { +/** + * Checks whether or not the output path from a given node leads to a QuantizeLinear op, optionally, with no + * branching ReLU or Clip op in between. See also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc. + * + * @param node The starting node to check the output path from. + * @param graph The graph containing the nodes. + * + * @return true if the path exist, false otherwise. + */ +static bool IsNoBranchPathToQuantizeLinear(const Node& node, const Graph& graph) { + const Node* current = &node; + while (true) { + // Conv / ConvTranspose / Gemm produces single output + if (current->OutputDefs().size() != 1) { + return false; + } + const std::vector& consumers = graph.GetConsumerNodes(current->OutputDefs()[0]->Name()); + // Branching or no consumer: not eligible + if (consumers.size() != 1) { + return false; + } + const Node* consumer = consumers[0]; + if (consumer->OpType() == QDQ::QOpName) { + return true; + } + // Allow ReLU or Clip, see also: NodeGroupSelector::GetQDQSelection() in qdq_selectors.cc. + if (consumer->OpType() != "Relu" && consumer->OpType() != "Clip") { + return false; + } + current = consumer; + } +} + Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { const GraphViewer graph_viewer{graph}; @@ -43,11 +76,8 @@ Status WeightBiasQuantization::ApplyImpl(Graph& graph, bool& modified, int graph continue; } - // Require that the node's output is consumed by a single QuantizeLinear node. - // Otherwise, if only the inputs are quantized, but not the output, then this node group would not - // be considered a QDQ node unit anyway. - std::vector children_nodes = graph.GetConsumerNodes(node.OutputDefs()[0]->Name()); - if (children_nodes.size() != 1 || children_nodes[0]->OpType() != QDQ::QOpName) { + // Check if the output path leads to QuantizeLinear with optionally ReLU or Clip op in between. 
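+    // i.e. node groups such as Conv -> Q, Conv -> Relu -> Q or Gemm -> Clip -> Q qualify,
+    // provided every node along the path has exactly one consumer (no branching).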
+ if (!IsNoBranchPathToQuantizeLinear(node, graph)) { continue; } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index e036c7764d041..2c82ca1c31b08 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -321,9 +321,9 @@ DataLayout CUDAExecutionProvider::GetPreferredLayout() const { return this->IsNHWCPreferred() ? DataLayout::NHWC : DataLayout::NCHW; } -std::optional CUDAExecutionProvider::ShouldConvertDataLayoutForOp(std::string_view node_domain, - std::string_view node_op_type, - DataLayout target_data_layout) const { +std::optional CUDAExecutionProvider::ShouldConvertDataLayoutForOp([[maybe_unused]] std::string_view node_domain, + [[maybe_unused]] std::string_view node_op_type, + [[maybe_unused]] DataLayout target_data_layout) const { #if defined(ENABLE_CUDA_NHWC_OPS) if (target_data_layout != DataLayout::NHWC) { return std::nullopt; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 94ace606ac75a..f0c46279c566a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -202,7 +202,7 @@ Status BaseOpBuilder::ProcessInt64Tensors(QnnModelWrapper& qnn_model_wrapper, // Insert cast to int32 if input dtype is int64 if (input_tensorwrapper.GetTensorDataType() == QNN_DATATYPE_INT_64) { const Qnn_TensorType_t tensor_type = QNN_TENSOR_TYPE_NATIVE; - const std::string cast_output_name = input_names[i] + "_cast_int32"; + const std::string cast_output_name = utils::GetUniqueName(input_names[i], "_cast_int32"); if (!qnn_model_wrapper.IsQnnTensorWrapperExist(cast_output_name)) { Qnn_DataType_t qnn_data_type = QNN_DATATYPE_INT_32; const auto& input_i = node_unit.Inputs()[i]; @@ -295,8 +295,8 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, } if (needs_int64_cast) { - std::string cast_node_name = output_name + "_cast_int64"; - std::string cast_input_name = output_name + "_cast_int64_aux"; + const std::string cast_node_name = utils::GetUniqueName(node_unit, "_cast_int64"); + const std::string cast_input_name = utils::GetUniqueName(output_name, "_cast_int64"); QnnQuantParamsWrapper quant_params = output_info.quant_param.Copy(); std::vector cast_output_shape = output_info.shape; @@ -309,10 +309,10 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); output_names.push_back(cast_input_name); // Store the cast node information for later addition - cast_node_info_vec.push_back({cast_node_name, cast_input_name, output_name}); + cast_node_info_vec.emplace_back(CastNodeInfo{cast_node_name, cast_input_name, output_name}); } else if (supported_qnn_data_type != output_info.qnn_data_type && is_graph_output && !do_op_validation) { - std::string cast_node_name = output_name + "_ort_qnn_ep_cast"; - std::string cast_input_name = output_name + "_ort_qnn_ep_aux"; + const std::string cast_node_name = utils::GetUniqueName(node_unit, "_cast"); + const std::string cast_input_name = utils::GetUniqueName(output_name, "_cast"); std::vector cast_output_shape = output_info.shape; QnnTensorWrapper cast_input_tensorwrapper(cast_input_name, QNN_TENSOR_TYPE_NATIVE, @@ -322,7 +322,7 @@ Status 
BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, mem_type); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); output_names.push_back(cast_input_name); - cast_node_info_vec.push_back({cast_node_name, cast_input_name, output_name}); + cast_node_info_vec.emplace_back(CastNodeInfo{cast_node_name, cast_input_name, output_name}); } else { output_info.qnn_data_type = supported_qnn_data_type; output_names.push_back(output_name); @@ -336,9 +336,9 @@ Status BaseOpBuilder::ProcessOutputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); } - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, - qnn_op_type, // Typically GetQnnOpType(), but can be overridden. + qnn_op_type, std::move(input_names), std::move(output_names), std::move(param_tensor_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index a7d993bc54642..83c226115aa84 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -238,7 +238,7 @@ class BaseOpBuilder : public IOpBuilder { } // Onnx Pads is [x1_begin, x2_begin, x1_end, x2_end], QNN requires [x1_begin, x1_end, x2_begin, x2_end] - void ReArranagePads(std::vector& pads) const { + void ReArrangePads(std::vector& pads) const { auto pads_size = pads.size(); auto middle_pos = pads_size / 2; std::vector first_half(pads.begin(), pads.begin() + middle_pos); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc index 5acfcb859b63a..c5c2a3c0150f9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc @@ -68,7 +68,7 @@ Status CastOpBuilder::ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wr } // Build additional static input with value 0. - const std::string& input_name = utils::GetNodeName(node_unit) + "_notequal_zero"; + const std::string& input_name = utils::GetUniqueName(node_unit, "_notequal_zero"); Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED; const auto* type_proto = input.node_arg.TypeAsProto(); @@ -84,7 +84,7 @@ Status CastOpBuilder::ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wr "Failed to add additional input tensor for QNN Cast node that will be replaced by NotEqual."); input_names.push_back(input_name); - LOGS(logger, VERBOSE) << "FP-to-Bool Cast node " << utils::GetNodeName(node_unit) << " is replaced by NotEqual."; + LOGS(logger, VERBOSE) << "FP-to-Bool Cast node " << node_unit.Name() << " is replaced by NotEqual."; return Status::OK(); } @@ -177,7 +177,7 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra const std::string qnn_op_type = IsFpToBoolCast(node_unit) ? 
QNN_OP_ELEMENT_WISE_NOT_EQUAL : GetQnnOpType(node_unit.OpType()); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op_type, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index b80d9db5d3560..541ca5ca7ab14 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -24,7 +24,6 @@ static Status GetOnnxConvType(const std::string& onnx_op_type, OnnxConvType& con } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unsupported ONNX convolution op type: ", onnx_op_type.c_str()); } - return Status::OK(); } @@ -171,7 +170,7 @@ Status ConvOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return ProcessConv2D3DInputs(qnn_model_wrapper, node_unit, logger, input_names, do_op_validation); } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D(rank 5), 2D (rank 4) or 1D (rank 3) inputs."); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN Conv only supports 3D (rank 5), 2D (rank 4) or 1D (rank 3) inputs."); } Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, @@ -199,7 +198,7 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, TensorInfo input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info)); - std::string actual_name = input_info.is_initializer ? input1_name : input1_name + "_ort_qnn_ep_transpose"; + std::string actual_name = input_info.is_initializer ? input1_name : utils::GetUniqueName(input1_name, "_transpose"); input_names.push_back(actual_name); std::vector actual_shape; @@ -309,8 +308,7 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, // Pop Conv weight. Insert Convert op after Weight input_names.pop_back(); - const std::string& conv_output_name = node_unit.Outputs()[0].node_arg.Name(); - std::string convert_output_name = weight_input_name + "_convert_" + conv_output_name; + std::string convert_output_name = utils::GetUniqueName(weight_input_name, "_convert"); ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper, weight_input_name, @@ -345,7 +343,7 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input1_info)); if (input0_info.quant_param.IsPerTensor(/*include_bw*/ true) && input1_info.quant_param.IsQuantized()) { - const std::string bias_name = qnn::utils::GetNodeName(node_unit) + "_implicit_bias_ort_qnn_ep"; + const std::string bias_name = qnn::utils::GetUniqueName(node_unit, "_implicit_bias"); std::vector bias_shape = {input1_info.shape[0]}; ORT_RETURN_IF_ERROR(AddZeroBiasInput(qnn_model_wrapper, input0_info.quant_param, input1_info.quant_param, std::move(bias_shape), bias_name, logger, input_names)); @@ -378,7 +376,7 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input0_info)); const std::string conv_input0_name = input0_info.is_initializer ? 
input0_name - : input0_name + "_ort_qnn_ep_reshape"; + : utils::GetUniqueName(input0_name, "_reshape"); input_names.push_back(conv_input0_name); if (!qnn_model_wrapper.IsQnnTensorWrapperExist(conv_input0_name)) { @@ -435,7 +433,7 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, TensorInfo input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info)); - std::string conv_weight_input_name = input_info.is_initializer ? input1_name : input1_name + "_ort_qnn_ep_transpose"; + std::string conv_weight_input_name = input_info.is_initializer ? input1_name : utils::GetUniqueName(input1_name, "_transpose"); input_names.push_back(conv_weight_input_name); // Create the shape after reshaping. @@ -460,7 +458,7 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Unexpected convolution op type: ", node_unit.OpType().c_str()); } - const std::string reshape_output = input1_name + "_ort_qnn_ep_reshape"; + const std::string reshape_output = utils::GetUniqueName(input1_name, "_reshape"); std::vector unpacked_tensor; if (input_info.is_initializer) { // @@ -713,7 +711,7 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } } - ReArranagePads(pads); + ReArrangePads(pads); uint32_t pad_size = narrow(pads.size() / 2); QnnParamWrapper pad_amount_paramwrapper(node_unit.Index(), node_unit.Name(), QNN_OP_CONV_2D_PARAM_PAD_AMOUNT, {pad_size, 2}, std::move(pads)); @@ -770,11 +768,11 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra output_shape[1], // W output_shape[2], // C }; - const std::string conv_output_name = output_name + "_ort_qnn_ep_conv2d"; + const std::string conv_output_name = utils::GetUniqueName(output_name, "_conv"); QnnTensorWrapper output_tensorwrapper(conv_output_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, output_quantize_param.Copy(), std::vector(output_shape_2d)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, output_node_type, std::move(input_names), @@ -799,7 +797,7 @@ Status ConvOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra QnnTensorWrapper output_tensorwrapper(output_name, tensor_type, qnn_data_type, std::move(output_quantize_param), std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, output_node_type, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index 51c38b4483cb9..a22a331b2453f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -287,8 +287,8 @@ Status CreateMatMulTransposeAll( std::vector input_shape1(input_info1.shape); std::swap(input_shape0[1], input_shape0[2]); std::swap(input_shape1[1], input_shape1[2]); - const std::string input_transpos0 = input_names[0] + "_t0"; - const 
std::string input_transpos1 = input_names[1] + "_t1"; + const std::string input_transpos0 = onnxruntime::qnn::utils::GetUniqueName(input_names[0], "_transpose"); + const std::string input_transpos1 = onnxruntime::qnn::utils::GetUniqueName(input_names[1], "_transpose"); const std::vector transpose_perm{0, 2, 1, 3}; ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddTransposeNode( /*node_index=*/node_unit.Index(), @@ -315,7 +315,7 @@ Status CreateMatMulTransposeAll( onnxruntime::qnn::TensorInfo matmul_output_info{}; const auto& output = node_unit.Outputs()[0]; ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(output, matmul_output_info)); - const std::string matmul_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_matmul"; + const std::string matmul_output_name = onnxruntime::qnn::utils::GetUniqueName(node_unit, "_matmul"); std::vector matmul_output_shape(matmul_output_info.shape); std::swap(matmul_output_shape[1], matmul_output_shape[2]); onnxruntime::qnn::QnnTensorWrapper matmul_output_wrapper( @@ -325,7 +325,7 @@ Status CreateMatMulTransposeAll( node_unit.OpType() + " failed to add tensor."); std::vector param_tensor_names = SetMatMulParamTensorNames( qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/false); - ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(/*qnn_node_name=*/onnxruntime::qnn::utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(/*qnn_node_name=*/onnxruntime::qnn::utils::GetUniqueName(node_unit, QNN_OP_MAT_MUL), /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, /*qnn_node_type=*/QNN_OP_MAT_MUL, /*input_names=*/{input_transpos1, input_transpos0}, @@ -373,7 +373,7 @@ Status CreateReduceSumMulBroadcastX( ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4"); ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3"); const std::vector new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]}; - const std::string reshape_out_name = input_names[0] + "_reshaped"; + const std::string reshape_out_name = onnxruntime::qnn::utils::GetUniqueName(input_names[0], "_reshape"); ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode( /*input_name=*/input_names[0], /*output_name=*/reshape_out_name, @@ -387,7 +387,7 @@ Status CreateReduceSumMulBroadcastX( // Multiply: reshaped in0 * in1 // The output shape of the multiplication is determined by broadcasting the reshaped in0 of // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c). 
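For readers following the Einsum rewrite, the broadcast bookkeeping described in the comment above fits in a few lines. The sketch below is illustrative only (MulBroadcastShape is a made-up name, not a helper from this patch) and assumes the (b, h, w, c) and (w, k, c) layouts stated in that comment.

#include <array>
#include <cstdint>

// in0 is reshaped from (b, h, w, c) to (b, h, w, 1, c); multiplying it with in1 (w, k, c)
// broadcasts along the matching trailing axes and yields (b, h, w, k, c). A ReduceSum over
// the last axis then produces the final result.
std::array<uint32_t, 5> MulBroadcastShape(const std::array<uint32_t, 4>& in0,   // (b, h, w, c)
                                          const std::array<uint32_t, 3>& in1) { // (w, k, c)
  return {in0[0], in0[1], in0[2], in1[1], in0[3]};
}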
- const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul"; + const std::string mul_out_name = onnxruntime::qnn::utils::GetUniqueName(node_unit, "_mul"); std::vector shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]}; onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name, QNN_TENSOR_TYPE_NATIVE, @@ -397,7 +397,7 @@ Status CreateReduceSumMulBroadcastX( ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)), "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper"); ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode( - /*qnn_node_name=*/mul_out_name, + /*qnn_node_name=*/onnxruntime::qnn::utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_MULTIPLY), /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, /*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY, /*input_names=*/{reshape_out_name, input_names[1]}, @@ -444,7 +444,7 @@ Status CreateReduceSumMulBroadcastX( "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper"); ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode( - /*qnn_node_name=*/out_name, + /*qnn_node_name=*/onnxruntime::qnn::utils::GetUniqueName(node_unit, QNN_OP_REDUCE_SUM), /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, /*qnn_node_type=*/QNN_OP_REDUCE_SUM, /*input_names=*/{mul_out_name}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc index 477a2445d9369..69980b8f86dab 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/expand_op_builder.cc @@ -138,7 +138,7 @@ Status ExpandOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, } // if-else const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); - std::string shape_input_name(input_name + "_" + output_name); + std::string shape_input_name = utils::GetUniqueName(input_name, output_name); QnnTensorWrapper input_tensorwrapper(shape_input_name, QNN_TENSOR_TYPE_STATIC, qnn_data_type, std::move(quantize_param), std::move(input_shape), std::move(shape_data)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_nd_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_nd_op_builder.cc index 3125e4e13aa15..0a1e6d68010a2 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_nd_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_nd_op_builder.cc @@ -144,7 +144,7 @@ Status GatherNDOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, std::move(cast_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(indices_cast_tensor)), "Failed to add gather indices cast tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_casted_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(indices_tensor_name, QNN_OP_CAST), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST, {indices_tensor_name}, @@ -254,8 +254,8 @@ Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model // If a cast to int64 is needed, add the cast node if (needs_int64_cast) { - std::string cast_node_name = output_name + "_cast_int64"; - std::string cast_input_name = output_name + "_cast_int64_aux"; + std::string cast_node_name = utils::GetUniqueName(node_unit, "_cast_int64"); + std::string cast_input_name = utils::GetUniqueName(output_name, "_cast_int64"); std::string cast_output_name = output_name; // 
Create the cast input tensor wrapper - use qnn_output_shape for the intermediate tensor @@ -275,9 +275,9 @@ Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model std::string gather_output_name = output_name; if (reshape_required) { - gather_output_name += "_ort_qnn_ep_reshape"; + gather_output_name = utils::GetUniqueName(output_name, "_reshape"); } else if (needs_int64_cast) { - gather_output_name += "_cast_int64_aux"; + gather_output_name = utils::GetUniqueName(output_name, "_cast_int64"); } Qnn_TensorType_t tensor_type = (!reshape_required && is_graph_output) @@ -289,7 +289,7 @@ Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_tensor)), "Failed to add GatherND output tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_GATHER_ND, std::move(input_names), @@ -307,10 +307,10 @@ Status GatherNDOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model std::string node_output_name = output_name; if (needs_int64_cast) { // If needs_int64 is true, the output name should be the input name of the cast node - node_output_name = output_name + "_cast_int64_aux"; + node_output_name = utils::GetUniqueName(output_name, "_cast_int64"); } - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(output_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_RESHAPE), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {gather_output_name}, @@ -341,4 +341,4 @@ void CreateGatherNDOpBuilder(const std::string& op_type, OpBuilderRegistrations& } } // namespace qnn -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc index e39f38fb020dc..8b3089f63723c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc @@ -174,7 +174,7 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper, std::move(cast_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(indices_cast_tensor)), "Failed to add gather indices cast tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(indices_casted_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(indices_tensor_name, QNN_OP_CAST), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST, {indices_tensor_name}, @@ -298,9 +298,9 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w // If a cast to int64 is needed, add the cast node if (needs_int64_cast) { - std::string cast_node_name = output_name + "_cast_int64"; - std::string cast_input_name = output_name + "_cast_int64_aux"; - std::string cast_output_name = output_name; + const std::string cast_node_name = utils::GetUniqueName(node_unit, "_cast_int64"); + const std::string cast_input_name = utils::GetUniqueName(output_name, "_cast_int64"); + const std::string cast_output_name = output_name; // Create the cast input tensor wrapper QnnTensorWrapper cast_input_tensorwrapper(cast_input_name, @@ -310,7 +310,7 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::move(qnn_output_shape)); 
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); - cast_node_info_vec.push_back({cast_node_name, cast_input_name, cast_output_name}); + cast_node_info_vec.emplace_back(CastNodeInfo{cast_node_name, cast_input_name, cast_output_name}); Qnn_TensorType_t cast_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; QnnTensorWrapper cast_output(output_name, cast_tensor_type, qnn_data_type, std::move(quantize_param), std::move(target_output_shape)); @@ -319,16 +319,16 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::string gather_output_name = output_name; if (reshape_required) { - gather_output_name += "_ort_qnn_ep_reshape"; + gather_output_name = utils::GetUniqueName(output_name, "_reshape"); } else if (needs_int64_cast) { - gather_output_name += "_cast_int64_aux"; + gather_output_name = utils::GetUniqueName(output_name, "_cast_int64"); } Qnn_TensorType_t tensor_type = (!reshape_required && is_graph_output) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; QnnTensorWrapper gather_output_wrapper(gather_output_name, tensor_type, qnn_data_type, quantize_param.Copy(), std::move(qnn_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(gather_output_wrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, GetQnnOpType(node_unit.OpType()), std::move(input_names), @@ -347,9 +347,9 @@ Status GatherOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w if (needs_int64_cast) { // If needs_int64 is true, the output name should be the input name of the cast node - node_output_name = output_name + "_cast_int64_aux"; + node_output_name = utils::GetUniqueName(output_name, "_cast_int64"); } - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(output_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_RESHAPE), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {gather_output_name}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc index 03dba4e98a901..4f5cb5ff86c6a 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc @@ -129,7 +129,7 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, input_shape[0] = old_input_shape[1]; input_shape[1] = old_input_shape[0]; const std::string& node_input_name(input_name); - input_tensor_name = input_tensor_name + "_ort_qnn_ep_transpose"; + input_tensor_name = utils::GetUniqueName(input_tensor_name, "_transpose"); std::vector perm{1, 0}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddTransposeNode(node_unit.Index(), node_input_name, input_tensor_name, old_input_shape, perm, input_shape, @@ -178,8 +178,7 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, // Pop FC weight. 
Insert Convert op after Weight input_names.pop_back(); - const std::string& fc_output_name = node_unit.Outputs()[0].node_arg.Name(); - std::string convert_output_name = weight_input_name + "_convert_" + fc_output_name; + std::string convert_output_name = utils::GetUniqueName(weight_input_name, "_convert"); ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper, weight_input_name, @@ -231,35 +230,33 @@ Status GemmOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector gemm_input_0_1; gemm_input_0_1.push_back(input_names[0]); gemm_input_0_1.push_back(input_names[1]); - std::string split_fully_connected_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected"; - std::string split_fully_connected_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected_output"; - QnnTensorWrapper fully_connected_output(split_fully_connected_output_name, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type, + const std::string fc_output_name = onnxruntime::qnn::utils::GetUniqueName(org_output_name, "_fc"); + QnnTensorWrapper fully_connected_output(fc_output_name, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type, QnnQuantParamsWrapper(), std::vector(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(fully_connected_output)), "Failed to add FullyConnected output tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_fully_connected_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_FULLY_CONNECTED), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_FULLY_CONNECTED, std::move(gemm_input_0_1), - {split_fully_connected_output_name}, + {fc_output_name}, {}, do_op_validation), "Failed to add FullyConnected node."); // Create Add Node Qnn_TensorType_t op_output_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; - std::string split_add_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_add"; QnnTensorWrapper op_output_tensor_wrapper(org_output_name, op_output_tensor_type, output_info.qnn_data_type, op_output_quant_param.Copy(), std::vector(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)), "Failed to add ElementWiseAdd output tensor."); std::string bias_name = input_names[2]; - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_add_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_ADD), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, - {split_fully_connected_output_name, bias_name}, // FullyConnected output as input - {org_output_name}, // Original output as output + {fc_output_name, bias_name}, + {org_output_name}, {}, do_op_validation), "Failed to add ElementWiseAdd node."); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc index e370501871e81..5fc2b710c73d0 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/instance_norm_op_builder.cc @@ -103,7 +103,7 @@ Status InstanceNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, input0_info.shape.size() == 3 && input0_info.shape[0] != 1) { const std::string& orig_input0_name = inputs[0].node_arg.Name(); const std::string op_input0_name = input0_info.is_initializer ? 
orig_input0_name - : orig_input0_name + "_ort_qnn_ep_reshape"; + : utils::GetUniqueName(orig_input0_name, "_reshape"); input_names.push_back(op_input0_name); std::vector initializer_data; @@ -170,7 +170,7 @@ Status InstanceNormOpBuilder::ProcessScale(QnnModelWrapper& qnn_model_wrapper, const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get(); if (tensor_info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { std::string convert_input_name = input_names.back(); - std::string convert_output_name = convert_input_name + "_convert_s8_to_u8"; + std::string convert_output_name = utils::GetUniqueName(convert_input_name, "_convert_s8_to_u8"); Status status = utils::InsertConvertOp( qnn_model_wrapper, convert_input_name, @@ -231,7 +231,7 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m // const std::string& orig_output_name = outputs[0].node_arg.Name(); - std::string op_output_name = orig_output_name + "_ort_qnn_ep_reshape"; + std::string op_output_name = utils::GetUniqueName(orig_output_name, "_reshape"); std::vector op_output_shape = { output_info.shape[0], // N @@ -243,7 +243,7 @@ Status InstanceNormOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_m QnnTensorWrapper output_tensorwrapper(op_output_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, output_info.quant_param.Copy(), std::vector(op_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, GetQnnOpType(node_unit.OpType()), std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index fc92f42b376bc..b3deeb6b25db8 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -92,7 +92,7 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[SCALE_IDX], scale_input_info)); if (x_input_info.quant_param.IsPerTensor(/*include_bw*/ true) && scale_input_info.quant_param.IsQuantized()) { - const std::string bias_name = qnn::utils::GetNodeName(node_unit) + "_implicit_bias_ort_qnn_ep"; + const std::string bias_name = qnn::utils::GetUniqueName(node_unit, "_implicit_bias"); std::vector bias_shape = scale_input_info.shape; ORT_RETURN_IF_ERROR(AddZeroBiasInput(qnn_model_wrapper, x_input_info.quant_param, scale_input_info.quant_param, std::move(bias_shape), bias_name, logger, input_names)); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc index f131d58277038..7efe0f1279a4b 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/lstm_op_builder.cc @@ -197,7 +197,7 @@ Status LSTMOpBuilder::AddStridedSliceOrReshape(QnnModelWrapper& qnn_model_wrappe "Failed to add input tensor for inserted StridedSlice or Reshape."); // params - const std::string& node_name = output_name; + const std::string node_name = utils::GetUniqueName(node_unit, QNN_OP_STRIDED_SLICE); // ranges std::vector ranges_data; @@ -314,7 +314,7 @@ Status 
LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(onnx_inputs[i], input_tensor_infos[i])); } } - // becuase QNN LSTM three outputs are mandatory, we should provide them tensor info + // because QNN LSTM three outputs are mandatory, we should provide them tensor info std::vector output_tensor_infos(3); for (size_t i = 0; i < 3; i++) { if (onnx_outputs.size() > i && onnx_outputs[i].node_arg.Exists()) { @@ -389,10 +389,10 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, std::vector qnn_input_indices = {1, 2, 3, 16}; std::vector begins = {2, 3, 1, 0}; std::vector qnn_lstm_weight_name = { - input_names[1] + "_input_to_forget_gate_weight_" + direction, - input_names[1] + "_input_to_cell_gate_weight_" + direction, - input_names[1] + "_input_to_output_gate_weight_" + direction, - input_names[1] + "_input_to_input_gate_weight_" + direction, + utils::GetUniqueName(input_names[1], "_input_to_forget_gate_weight_" + direction), + utils::GetUniqueName(input_names[1], "_input_to_cell_gate_weight_" + direction), + utils::GetUniqueName(input_names[1], "_input_to_output_gate_weight_" + direction), + utils::GetUniqueName(input_names[1], "_input_to_input_gate_weight_" + direction), }; for (size_t i = 0; i < 4; i++) { std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, @@ -432,10 +432,10 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, std::vector qnn_input_indices = {4, 5, 6, 17}; std::vector begins = {2, 3, 1, 0}; std::vector qnn_lstm_weight_name = { - input_names[2] + "_recurrent_to_forget_gate_weight_" + direction, - input_names[2] + "_recurrent_to_cell_gate_weight_" + direction, - input_names[2] + "_recurrent_to_output_gate_weight_" + direction, - input_names[2] + "_recurrent_to_input_gate_weight_" + direction}; + utils::GetUniqueName(input_names[2], "_recurrent_to_forget_gate_weight_" + direction), + utils::GetUniqueName(input_names[2], "_recurrent_to_cell_gate_weight_" + direction), + utils::GetUniqueName(input_names[2], "_recurrent_to_output_gate_weight_" + direction), + utils::GetUniqueName(input_names[2], "_recurrent_to_input_gate_weight_" + direction)}; for (size_t i = 0; i < 4; i++) { std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}, @@ -473,22 +473,22 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, uint32_t new_axes_mask = 0b00U; std::vector output_shape = {hidden_size}; std::vector qnn_lstm_bias_name = { - node_name + "_forget_gate_bias_" + direction, - node_name + "_cell_gate_bias_" + direction, - node_name + "_output_gate_bias_" + direction, - node_name + "_input_gate_bias_" + direction}; + utils::GetUniqueName(node_unit, "_forget_gate_bias_" + direction), + utils::GetUniqueName(node_unit, "_cell_gate_bias_" + direction), + utils::GetUniqueName(node_unit, "_output_gate_bias_" + direction), + utils::GetUniqueName(node_unit, "_input_gate_bias_" + direction)}; std::vector qnn_input_indices = {7, 8, 9, 21}; if (onnx_inputs.size() > 3 && onnx_inputs[3].node_arg.Exists()) { std::vector begins = {2, 3, 1, 0, 6, 7, 5, 4}; std::vector onnx_lstm_bias_name = { - input_names[3] + "_input_to_forget_gate_bias_" + direction, - input_names[3] + "_input_to_cell_gate_bias_" + direction, - input_names[3] + "_input_to_output_gate_bias_" + direction, - input_names[3] + "_input_to_input_gate_bias_" + direction, - input_names[3] + 
"_recurrent_to_forget_gate_bias_" + direction, - input_names[3] + "_recurrent_to_cell_gate_bias_" + direction, - input_names[3] + "_recurrent_to_output_gate_bias_" + direction, - input_names[3] + "_recurrent_to_input_gate_bias_" + direction}; + utils::GetUniqueName(input_names[3], "_input_to_forget_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_input_to_cell_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_input_to_output_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_input_to_input_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_recurrent_to_forget_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_recurrent_to_cell_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_recurrent_to_output_gate_bias_" + direction), + utils::GetUniqueName(input_names[3], "_recurrent_to_input_gate_bias_" + direction)}; for (size_t i = 0; i < 8; i++) { std::vector> ranges = {{direction_idx, direction_idx + 1, 1}, {begins[i] * hidden_size_sign, (begins[i] + 1) * hidden_size_sign, 1}}; @@ -516,14 +516,14 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, input_tensor_infos[3].quant_param.Copy(), std::vector(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(add_output_tensorwrapper)), "QNN EP: Failed to add output tensor for inserted ElementWiseAdd node."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_ADD), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, std::move(add_input_names), {qnn_lstm_bias_name[i]}, {}, do_op_validation), "Failed to create manually inserted ElementWiseAdd node."); qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_bias_name[i]; } } else { // prepare zero bias - std::string zero_bias_name = node_name + "_zero_bias"; + std::string zero_bias_name = utils::GetUniqueName(node_unit, "_zero_bias"); QnnTensorWrapper zero_bias_tensor_wrapper(zero_bias_name, QNN_TENSOR_TYPE_STATIC, input_tensor_infos[0].qnn_data_type, @@ -551,9 +551,9 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, std::vector qnn_input_indices = {18, 19, 20}; std::vector begins = {0, 2, 1}; std::vector qnn_lstm_weight_name = { - input_names[7] + "_cell_to_input_gate_weight_" + direction, - input_names[7] + "_cell_to_forget_gate_weight_" + direction, - input_names[7] + "_cell_to_output_gate_weight_" + direction}; + utils::GetUniqueName(input_names[7], "_cell_to_input_gate_weight_" + direction), + utils::GetUniqueName(input_names[7], "_cell_to_forget_gate_weight_" + direction), + utils::GetUniqueName(input_names[7], "_cell_to_output_gate_weight_" + direction)}; for (size_t i = 0; i < 3; i++) { std::vector> ranges = { {direction_idx, direction_idx + 1, 1}, @@ -595,7 +595,7 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, std::vector output_shape = {batch_size, hidden_size}; for (size_t i = 0; i < 2; i++) { if (onnx_inputs.size() > src_indices[i] && onnx_inputs[src_indices[i]].node_arg.Exists()) { - std::string qnn_lstm_input_name = input_names[src_indices[i]] + "_" + direction; + const std::string qnn_lstm_input_name = utils::GetUniqueName(input_names[src_indices[i]], direction); ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, /*node_unit=*/node_unit, 
/*input_name=*/input_names[src_indices[i]], @@ -615,7 +615,7 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, qnn_lstm_input_names[qnn_input_indices[i]] = qnn_lstm_input_name; } else { // prepare zero initial values - std::string zero_initial_values_name = node_name + "_LSTM_initial_values_" + (i == 0 ? "h" : "c"); + std::string zero_initial_values_name = utils::GetUniqueName(node_name, std::string("_LSTM_initial_values_") + (i == 0 ? "h" : "c")); QnnTensorWrapper zero_bias_tensor_wrapper(zero_initial_values_name, QNN_TENSOR_TYPE_STATIC, input_tensor_infos[0].qnn_data_type, @@ -648,7 +648,7 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, std::vector> ranges = {{SafeInt(sequence_idx), SafeInt(sequence_idx + 1), 1}, {0, SafeInt(batch_size), 1}, {0, SafeInt(input_size), 1}}; - std::string qnn_lstm_input_name = input_names[0] + "_cell_" + std::to_string(sequence_idx) + "_input"; + std::string qnn_lstm_input_name = utils::GetUniqueName(input_names[0], "_cell_" + std::to_string(sequence_idx) + "_input"); std::vector output_shape = {batch_size, input_size}; ORT_RETURN_IF_ERROR(AddStridedSliceOrReshape(/*qnn_model_wrapper=*/qnn_model_wrapper, /*node_unit=*/node_unit, @@ -673,9 +673,9 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, std::vector qnn_lstm_output_shape = {batch_size, hidden_size}; std::vector qnn_lstm_output_names = { - node_name + "_QNN_LSTM_output_all_hidden_state_" + std::to_string(sequence_idx) + "_" + direction, - node_name + "_QNN_LSTM_output_cell_state_" + std::to_string(sequence_idx) + "_" + direction, - node_name + "_QNN_LSTM_output_hidden_state_" + std::to_string(sequence_idx) + "_" + direction}; + utils::GetUniqueName(node_unit, "_QNN_LSTM_output_all_hidden_state_" + std::to_string(sequence_idx) + "_" + direction), + utils::GetUniqueName(node_unit, "_QNN_LSTM_output_cell_state_" + std::to_string(sequence_idx) + "_" + direction), + utils::GetUniqueName(node_unit, "_QNN_LSTM_output_hidden_state_" + std::to_string(sequence_idx) + "_" + direction)}; qnn_lstm_input_names[10] = qnn_lstm_output_names[2]; // update initial_h qnn_lstm_input_names[11] = qnn_lstm_output_names[1]; // update initial_c qnn_all_hidden_state_names[sequence_idx] = qnn_lstm_output_names[2]; @@ -689,7 +689,7 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "QNN EP: Failed to add %ldth output tensor for QNN LSTM.", j); } - std::string lstm_node_name = node_name + "_cell_" + std::to_string(sequence_idx) + "_" + direction; + const std::string lstm_node_name = utils::GetUniqueName(node_unit, "_cell" + std::to_string(sequence_idx) + "_" + direction); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(lstm_node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_LSTM, std::move(qnn_lstm_input_names_i), std::move(qnn_lstm_output_names), std::vector(param_names), do_op_validation), @@ -697,7 +697,7 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, } // pack all timestamp outputs together for onnx output[0] - std::string qnn_pack_output_name = node_name + "_QNN_LSTM_output_hidden_state_all_" + direction; + const std::string qnn_pack_output_name = utils::GetUniqueName(node_unit, "_QNN_LSTM_output_hidden_state_all_" + direction); // add pack for output[0] std::vector pack_param_names; @@ -731,7 +731,7 @@ Status LSTMOpBuilder::AddUnidirectionLSTM(QnnModelWrapper& qnn_model_wrapper, {1, 
batch_size, hidden_size}}; for (size_t i = 0; i < 3; i++) { if (onnx_outputs.size() > i && onnx_outputs[i].node_arg.Exists()) { - const std::string reshape_output_name = is_bidirection ? qnn_reshape_input_names[i] + "_unsqueeze_" + direction : onnx_outputs[i].node_arg.Name(); + const std::string reshape_output_name = is_bidirection ? utils::GetUniqueName(qnn_reshape_input_names[i], "_unsqueeze_" + direction) : onnx_outputs[i].node_arg.Name(); ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(/*input_name=*/qnn_reshape_input_names[i], /*output_name=*/reshape_output_name, /*input_shape=*/qnn_lstm_output_shapes[i], @@ -786,7 +786,7 @@ Status LSTMOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector(output_info.shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(concat_output_tensorwrapper)), "QNN EP: Failed to add output tensor for QNN Concat."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_unit.Name(), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CONCAT, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_CONCAT), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CONCAT, {uni_lstm_output_names_forward[i], uni_lstm_output_names_reverse[i]}, {onnx_output_name}, std::move(concat_param_names), do_op_validation), "QNN EP: Failed to create Qnn Concat node."); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc index fa5e95727e651..e3e5dde74b642 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/matmul_op_builder.cc @@ -92,7 +92,7 @@ Status ProcessInput0(QnnModelWrapper& qnn_model_wrapper, std::string actual_input_0_name = original_input_0_name; if (reshape_input_0) { - actual_input_0_name = original_input_0_name + "_ort_qnn_ep_reshape"; + actual_input_0_name = utils::GetUniqueName(original_input_0_name, "_reshape"); std::vector shape_2d{1, input_0_info.shape[0]}; QnnQuantParamsWrapper quant_param_2d = input_0_info.quant_param.Copy(); ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_0_info.shape, shape_2d)); @@ -178,7 +178,7 @@ Status MatMulOpBuilder::ProcessInputsForQnnMatMul(QnnModelWrapper& qnn_model_wra // Input[1] is a rank 1 tensor that needs to be reshaped. 
std::vector shape_2d; QnnQuantParamsWrapper quant_param_2d = input_info_1.quant_param.Copy(); - input_1_name = org_input_1_name + "_ort_qnn_ep_reshape"; + input_1_name = utils::GetUniqueName(org_input_1_name, "_reshape"); shape_2d = {input_info_1.shape[0], 1}; ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_info_1.shape, shape_2d)); @@ -239,8 +239,7 @@ Status MatMulOpBuilder::ProcessInputsForQnnMatMul(QnnModelWrapper& qnn_model_wra // insert Convert op after input1 std::string convert_input_name = input_names.back(); input_names.pop_back(); - const std::string& matmul_output_name = node_unit.Outputs()[0].node_arg.Name(); - std::string convert_output_name = convert_input_name + "_convert_" + matmul_output_name; + const std::string convert_output_name = utils::GetUniqueName(convert_input_name, "_convert"); std::vector input_1_shape = input_info_1.shape; if (reshape_input_1) { input_1_shape = {input_info_1.shape[0], 1}; @@ -294,14 +293,14 @@ Status MatMulOpBuilder::ProcessInputsForQnnFullyConnected(QnnModelWrapper& qnn_m QnnQuantParamsWrapper quant_param_2d = input_info_1.quant_param.Copy(); if (reshape_input_1) { // Input[1] is a rank 1 tensor that needs to be reshaped. - input_1_name = org_input_1_name + "_ort_qnn_ep_reshape"; + input_1_name = utils::GetUniqueName(org_input_1_name, "_reshape"); // FullyConnected requires input_1's shape to be [n, k]. shape_2d = {1, input_info_1.shape[0]}; ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_info_1.shape, shape_2d)); } else { assert(input_info_1.shape.size() == 2); - input_1_name = org_input_1_name + "_ort_qnn_ep_transpose"; + input_1_name = utils::GetUniqueName(org_input_1_name, "_transpose"); shape_2d = {input_info_1.shape[1], input_info_1.shape[0]}; ORT_RETURN_IF_ERROR(quant_param_2d.HandleTranspose(std::vector({1, 0}))); } @@ -361,8 +360,7 @@ Status MatMulOpBuilder::ProcessInputsForQnnFullyConnected(QnnModelWrapper& qnn_m // Pop Conv weight. Insert Convert op after Weight input_names.pop_back(); - const std::string& conv_output_name = node_unit.Outputs()[0].node_arg.Name(); - std::string convert_output_name = weight_input_name + "_convert_" + conv_output_name; + std::string convert_output_name = utils::GetUniqueName(weight_input_name, "_convert"); ORT_RETURN_IF_ERROR(utils::InsertConvertOp(qnn_model_wrapper, weight_input_name, @@ -417,7 +415,7 @@ Status MatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::vector op_output_shape = output_info.shape; QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy(); if (reshape_output) { - op_output_name = org_output_name + "_ort_qnn_ep_reshape"; + op_output_name = utils::GetUniqueName(org_output_name, "_reshape"); if (use_fully_connected && input_info_0.shape.size() > 2) { op_output_shape = {std::accumulate(input_info_0.shape.begin(), input_info_0.shape.end() - 1, static_cast(1), std::multiplies()), @@ -443,7 +441,7 @@ Status MatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w op_output_quant_param.Copy(), std::vector(op_output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)), "Failed to add output tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, use_fully_connected ? 
QNN_OP_FULLY_CONNECTED : QNN_OP_MAT_MUL, std::move(input_names), {op_output_name}, std::move(param_tensor_names), do_op_validation), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc index 07e73350fde5f..e8aaf800b0092 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/mean_op_builder.cc @@ -50,13 +50,12 @@ Status MeanOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output.node_arg, output_shape), "Failed to get output shape."); std::vector unpackage_data(sizeof(float)); - const std::string add_output = sum_output + "_ort_qnn_ep_add_" + std::to_string(i); + const std::string add_output = utils::GetUniqueName(sum_output, "_add" + std::to_string(i)); QnnTensorWrapper add_tensor(add_output, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type, QnnQuantParamsWrapper(), std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(add_tensor)), "Failed to add Add tensor wrapper."); - const std::string add_op_name = "Mean_Add_" + std::to_string(i); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(add_op_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_ADD), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_ADD, {sum_output, input_names[i]}, @@ -74,7 +73,7 @@ Status MeanOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector divisor_data(sizeof(float)); memcpy(divisor_data.data(), &divisor, sizeof(float)); - const std::string divisor_name = sum_output + "_ort_qnn_ep_divisor"; + const std::string divisor_name = utils::GetUniqueName(sum_output, "_divisor"); QnnTensorWrapper divisor_tensor(divisor_name, QNN_TENSOR_TYPE_STATIC, input_info.qnn_data_type, QnnQuantParamsWrapper(), std::move(scalar_shape), std::move(divisor_data)); @@ -94,8 +93,7 @@ Status MeanOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output tensor wrapper."); std::vector div_inputs = {sum_output, divisor_name}; - const std::string div_node_name = output_name + "_div"; - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(div_node_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_DIVIDE), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_DIVIDE, {sum_output, divisor_name}, @@ -112,4 +110,4 @@ void CreateMeanOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_ } } // namespace qnn -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc index 404d3c402c21e..d2b1434c1c896 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pad_op_builder.cc @@ -193,7 +193,7 @@ Status PadOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrap [](int64_t item) { return SafeInt(item); }); // Onnx format is begin_0, begin_1, ..., end_0, end_1, ... // Qnn format is begin_0, end_0, begin_1, end_1, ... 
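The pad reordering referenced in the comment above, and performed by the renamed ReArrangePads helper, amounts to interleaving the begin/end halves of the ONNX pad list. A minimal standalone sketch follows; InterleavePads is an illustrative name rather than the builder's actual member, and uint32_t pad values are assumed as in these builders.

#include <cstdint>
#include <vector>

// ONNX orders pads as [begin_0, begin_1, ..., end_0, end_1, ...];
// QNN expects [begin_0, end_0, begin_1, end_1, ...].
std::vector<uint32_t> InterleavePads(const std::vector<uint32_t>& pads) {
  const auto half = pads.size() / 2;
  std::vector<uint32_t> result;
  result.reserve(pads.size());
  for (std::size_t i = 0; i < half; ++i) {
    result.push_back(pads[i]);         // begin_i
    result.push_back(pads[half + i]);  // end_i
  }
  return result;
}

// Example: {1, 2, 3, 4} (ONNX order) becomes {1, 3, 2, 4} (QNN order).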
- ReArranagePads(pad_amount); + ReArrangePads(pad_amount); std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[0].node_arg, input_shape), "Cannot get shape of input 0."); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc index 21947a22e2b92..851c65aa1c1a3 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/pool_op_builder.cc @@ -195,7 +195,7 @@ Status PoolOpBuilder::SetCommonPoolParams(const NodeAttrHelper& node_helper, } } } - ReArranagePads(pad_amount); + ReArrangePads(pad_amount); // Param: rounding_mode. rounding_mode = node_helper.Get("ceil_mode", rounding_mode); @@ -235,7 +235,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(reshape_input, reshape_input_info)); bool needs_reshape = false; - const std::string reshape_prior_out = input_names[0] + "_prior_reshape"; + const std::string reshape_prior_out = utils::GetUniqueName(input_names[0], "_reshape"); if (input_shape.size() == 3) { needs_reshape = true; // build new_shape = {N, 1, C, L} @@ -254,7 +254,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_prior_tensor)), "Failed to add reshape prior tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_reshape_prior", + utils::GetUniqueName(node_unit, QNN_OP_RESHAPE), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {input_names[0]}, @@ -445,8 +445,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } const auto& outputs = node_unit.Outputs(); const std::string real_out = outputs[0].node_arg.Name(); - const std::string pool_out = real_out + "_reshape_after"; - const std::string qnn_op = GetQnnOpType(op_type); + const std::string pool_out = utils::GetUniqueName(real_out, "_reshape_after"); TensorInfo output_info{}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); bool is_graph_output = qnn_model_wrapper.IsGraphOutput(real_out); @@ -463,9 +462,9 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra "Failed to add tensor for pool_out"); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_pool2d", + utils::GetUniqueName(node_unit, op_type), QNN_OP_PACKAGE_NAME_QTI_AISW, - qnn_op, + GetQnnOpType(op_type), {reshape_prior_out}, {pool_out}, std::move(param_tensor_names), @@ -483,7 +482,7 @@ Status PoolOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reshape_after_tensor)), "Failed to add reshape after tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit) + "_reshape_after", + utils::GetUniqueName(node_unit, QNN_OP_RESHAPE), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {pool_out}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc index bd55df8650b97..484d50c63a814 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reciprocal_op_builder.cc @@ -51,7 +51,7 @@ Status 
ReciprocalOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mod ORT_UNUSED_PARAMETER(logger); // Create a constant tensor for the divisor (1.0) - std::string divisor_name = node_unit.Name() + "_divisor"; + std::string divisor_name = utils::GetUniqueName(node_unit, "_divisor"); std::vector divisor_shape{1}; std::vector divisor_data; @@ -94,7 +94,7 @@ Status ReciprocalOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mod ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add output tensor."); ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( - utils::GetNodeName(node_unit), + utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_DIVIDE, {divisor_name, input_names[0]}, @@ -111,4 +111,4 @@ void CreateReciprocalOpBuilder(const std::string& op_type, OpBuilderRegistration } } // namespace qnn -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc index 25763f23d442c..e1d3a202f58f1 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc @@ -256,23 +256,33 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w // Step 1: y_pow2 = x * x, using ElementWiseMultiply instead of ElementWisePower so we don't need to add a new // initializer tensor for the power value. The performance difference is negligible. - const std::string pow2_name = input_name + "_ort_qnn_ep_pow2"; - QnnTensorWrapper pow2_tensorwrapper(pow2_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), + const std::string pow2_output_name = utils::GetUniqueName(input_name, "_pow2"); + QnnTensorWrapper pow2_tensorwrapper(pow2_output_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), std::move(input_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(pow2_tensorwrapper)), "AddTensorWrapper failed"); ORT_RETURN_IF_NOT( - qnn_model_wrapper.CreateQnnNode(pow2_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_MULTIPLY, - {input_name, input_name}, {pow2_name}, {}, do_op_validation), + qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_MULTIPLY), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_MULTIPLY, + {input_name, input_name}, + {pow2_output_name}, + {}, + do_op_validation), "CreateQnnNode failed"); // Step 2: y_pow2_sum = ReduceSum(y_pow2) - const std::string reduce_name = input_name + "_ort_qnn_ep_pow2_sum"; - QnnTensorWrapper reduce_tensorwrapper(reduce_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), + const std::string reduce_output_name = utils::GetUniqueName(input_name, "_sum"); + QnnTensorWrapper reduce_tensorwrapper(reduce_output_name, QNN_TENSOR_TYPE_NATIVE, qnn_data_type, QnnQuantParamsWrapper(), std::vector(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(reduce_tensorwrapper)), "AddTensorWrapper failed"); ORT_RETURN_IF_NOT( - qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_REDUCE_SUM, - {pow2_name}, {reduce_name}, std::move(param_tensor_names), do_op_validation), + qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_REDUCE_SUM), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_REDUCE_SUM, + {pow2_output_name}, + 
{reduce_output_name}, + std::move(param_tensor_names), + do_op_validation), "CreateQnnNode failed"); // Step 3: y = Sqrt(y_pow2_sum) @@ -281,9 +291,13 @@ Status ReduceOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w QnnTensorWrapper sqrt_tensorwrapper(output.node_arg.Name(), output_tensor_type, qnn_data_type, QnnQuantParamsWrapper(), std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sqrt_tensorwrapper)), "AddTensorWrapper failed"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(input_name + "_ort_qnn_ep_pow2_sum_sqrt", - QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_ELEMENT_WISE_SQUARE_ROOT, - {reduce_name}, {output.node_arg.Name()}, {}, do_op_validation), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_ELEMENT_WISE_SQUARE_ROOT), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_SQUARE_ROOT, + {reduce_output_name}, + {output.node_arg.Name()}, + {}, + do_op_validation), "CreateQnnNode failed"); } else { ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc index ff3f476cbc4dc..c2211ce35ff59 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/softmax_op_builder.cc @@ -92,7 +92,7 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, Given an input with shape=(3, 4, 5) and axis=0. Its behavior is to reshape the input to (1, 60), perform softmax, and then reshape back to (3, 4, 5). */ - std::string reshape_output_name = input_name + "_ort_qnn_ep_reshape"; + std::string reshape_output_name = utils::GetUniqueName(input_name, "_reshape"); std::vector reshape_output_shape = FlattenShapeFromAxis(input_info.shape, axis); // Input is dynamic, so add reshape node before input. @@ -114,7 +114,7 @@ Status SoftmaxOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, input dimension. QNN EP is able to support arbitrary axis attribute by wrapping transposes around the operator. 
*/ - std::string transpose_output_name = input_name + "_ort_qnn_ep_transpose"; + std::string transpose_output_name = utils::GetUniqueName(input_name, "_transpose"); std::vector transpose_perm; ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast(axis), static_cast(input_rank), @@ -168,7 +168,7 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_ size_t output_rank = output_info.shape.size(); if (opset_version < 13) { - std::string reshape_input_name = orig_output_name + "_ort_qnn_ep_reshape"; + std::string reshape_input_name = utils::GetUniqueName(orig_output_name, "_reshape"); std::vector reshape_input_shape = FlattenShapeFromAxis(output_info.shape, axis); if (axis == 0) { @@ -184,7 +184,7 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_ QnnTensorWrapper output_tensorwrapper(reshape_input_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, output_info.quant_param.Copy(), std::vector(reshape_input_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, GetQnnOpType(node_unit.OpType()), std::move(input_names), @@ -203,7 +203,7 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_ false, is_graph_output)); } else if (is_npu_backend && axis != static_cast(output_rank) - 1) { - std::string transpose_input_name = orig_output_name + "_ort_qnn_ep_transpose"; + std::string transpose_input_name = utils::GetUniqueName(orig_output_name, "_transpose"); std::vector transpose_input_shape = output_info.shape; transpose_input_shape[output_rank - 1] = output_info.shape[axis]; @@ -220,7 +220,7 @@ Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_ QnnTensorWrapper output_tensorwrapper(transpose_input_name, QNN_TENSOR_TYPE_NATIVE, output_info.qnn_data_type, output_info.quant_param.Copy(), std::vector(transpose_input_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, GetQnnOpType(node_unit.OpType()), std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index b87cdd4e25f08..78bd4ab41155e 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -96,7 +96,7 @@ Status TopKOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, } // Add Transpose to permute axis to the last. - std::string transpose_output_name = input_names[0] + "_ort_qnn_ep_transpose"; + const std::string transpose_output_name = utils::GetUniqueName(input_names[0], "_transpose"); std::vector transpose_perm; ORT_RETURN_IF_ERROR(utils::GetPermToLastAxis(static_cast(axis), static_cast(input_rank), @@ -175,7 +175,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra // Since user may not be aware of the additional Transpose, the original output name of TopK node must be used by // the additional Transpose node which has the same output as original TopK node. 
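Both the Softmax and TopK builders above handle an arbitrary axis by transposing that axis to the last position, running the QNN op, and transposing back. The sketch below shows the kind of permutation involved; it is illustrative only and not the actual utils::GetPermToLastAxis implementation.

#include <cstdint>
#include <numeric>
#include <vector>

// Build a permutation that moves `axis` to the last position,
// e.g. axis=1, rank=4 -> {0, 2, 3, 1}. Transposing back with the
// inverse permutation afterwards restores the original layout.
std::vector<uint32_t> PermMovingAxisToLast(uint32_t axis, uint32_t rank) {
  std::vector<uint32_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0u);
  perm.erase(perm.begin() + axis);
  perm.push_back(axis);
  return perm;
}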
const std::string& output_name = output.node_arg.Name(); - std::string transpose_input_name = output_name + "_ort_qnn_ep_transpose"; + const std::string transpose_input_name = utils::GetUniqueName(output_name, "_transpose"); transpose_input_names.push_back(std::move(transpose_input_name)); // Since the input of TopK node is permuted, its output shape must be manually calculated. @@ -197,7 +197,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra } // Add TopK node. - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, GetQnnOpType(node_unit.OpType()), std::move(input_names), @@ -228,7 +228,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra bool is_cast_required = output_idx == 1 && output_info.qnn_data_type == QNN_DATATYPE_INT_64 && is_graph_output; std::string cast_input_name = ""; if (is_cast_required) { - cast_input_name = transpose_output_name + "_ort_qnn_ep_cast"; + cast_input_name = utils::GetUniqueName(transpose_output_name, "_cast"); // For the same reason described above, the original output name is now used by this Cast. transpose_output_name = cast_input_name; // Since additional Cast is added, below Transpose is no longer graph output. @@ -255,7 +255,7 @@ Status TopKOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector(output_info.shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_output_tensorwrapper)), "Failed to add tensor."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(cast_input_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit, QNN_OP_CAST), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST, {cast_input_name}, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc index 3498aa92032f3..e27cf3ad11530 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/transpose_op_builder.cc @@ -111,8 +111,8 @@ Status TransposeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mode // If a cast to int64 is needed, add the cast node if (needs_int64_cast) { - std::string cast_node_name = output_name + "_cast_int64"; - std::string cast_input_name = output_name + "_cast_int64_aux"; + std::string cast_node_name = utils::GetUniqueName(node_unit, "_cast_int64"); + std::string cast_input_name = utils::GetUniqueName(output_name, "_cast_int64"); std::string cast_output_name = output_name; // Create the cast input tensor wrapper @@ -123,7 +123,7 @@ Status TransposeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mode std::move(output_shape)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(cast_input_tensorwrapper)), "Failed to add tensor."); - cast_node_info_vec.push_back({cast_node_name, cast_input_name, cast_output_name}); + cast_node_info_vec.emplace_back(CastNodeInfo{cast_node_name, cast_input_name, cast_output_name}); } // Transpose output uses same data type and quantization parameter with input @@ -140,7 +140,7 @@ Status TransposeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_mode ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); output_names.push_back(output_name); - 
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_TRANSPOSE, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index e1a74b9e35370..1e4ba6afe6f0b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -158,7 +158,7 @@ bool QnnModelWrapper::CreateQnnInputOutputTensors(const std::string& qnn_node_na return false; } - // During graph patitioning, we only need to do op validation, it's not required to create Qnn graph tensor + // During graph partitioning, we only need to do op validation, it's not required to create Qnn graph tensor // We only need to create the Qnn graph tensor during Compile to create Qnn graph if (!do_op_validation) { std::string error_string; @@ -514,7 +514,7 @@ Status QnnModelWrapper::AddReshapeNode(const std::string& input_name, const std: ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "QNN EP: Failed to add output tensor for inserted Reshape."); - ORT_RETURN_IF_NOT(CreateQnnNode(output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {input_name}, + ORT_RETURN_IF_NOT(CreateQnnNode(utils::GetUniqueName(output_name, QNN_OP_RESHAPE), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_RESHAPE, {input_name}, {output_name}, {}, do_op_validation), "QNN EP: Failed to create manually inserted Qnn Reshape node."); @@ -564,8 +564,7 @@ Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index, uint32_t perm_size = static_cast(transpose_perm.size()); std::vector perm_dim{perm_size}; std::vector transpose_perm_copy = transpose_perm; - const std::string& node_name = output_name; - QnnParamWrapper transpose_param(node_index, node_name, QNN_OP_TRANSPOSE_PARAM_PERM, std::move(perm_dim), std::move(transpose_perm_copy)); + QnnParamWrapper transpose_param(node_index, output_name, QNN_OP_TRANSPOSE_PARAM_PERM, std::move(perm_dim), std::move(transpose_perm_copy)); std::string param_tensor_name(transpose_param.GetParamTensorName()); ORT_RETURN_IF_NOT(AddParamWrapper(std::move(transpose_param)), "Failed to add tensor."); Qnn_TensorType_t tensor_type = (false == is_for_output) ? 
QNN_TENSOR_TYPE_NATIVE : QNN_TENSOR_TYPE_APP_READ; @@ -576,11 +575,9 @@ Status QnnModelWrapper::AddTransposeNode(NodeIndex node_index, quantize_param.Copy(), std::move(output_shape_copy)); ORT_RETURN_IF_NOT(AddTensorWrapper(std::move(output_tensorwrapper)), "Failed to add tensor."); - const static std::string qnn_node_type = "Transpose"; - - ORT_RETURN_IF_NOT(CreateQnnNode(output_name, + ORT_RETURN_IF_NOT(CreateQnnNode(utils::GetUniqueName(output_name, QNN_OP_TRANSPOSE), QNN_OP_PACKAGE_NAME_QTI_AISW, - qnn_node_type, + QNN_OP_TRANSPOSE, {input_name}, {output_name}, {param_tensor_name}, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index a44163563b430..f0d145c2938c8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -186,7 +186,7 @@ class QnnModelWrapper { bool is_for_input = true, bool is_for_output = false); - // Tranpose NCHW->HWCN for QNN weight + // Transpose NCHW->HWCN for QNN weight Status AddNchwToHwcnTranspose(NodeIndex node_index, const std::string& input_name, const std::string& output_name, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.cc index bad5a3c038cf9..a92df85eb69d3 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/channel_shuffle_fusion.cc @@ -181,7 +181,7 @@ Status CreateOrValidateOnQnn( if (!validate) { ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(channel_shuffle_input)), "Failed to add input"); ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(channel_shuffle_output)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(transpose_tail->Name(), + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(onnxruntime::qnn::utils::GetUniqueName(*transpose_tail), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CHANNEL_SHUFFLE, {cs_input_def.node_arg.Name()}, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc index 3af2fdd1f0276..9c981ba916418 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/dq_q_fusion.cc @@ -89,7 +89,7 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& q_node_unit, bool validate) { assert(dq_node_unit.OpType() == DEQUANTIZE_LINEAR && q_node_unit.OpType() == QUANTIZE_LINEAR); - const auto& node_name = utils::GetNodeName(dq_node_unit); + const auto& node_name = utils::GetUniqueName(dq_node_unit); const NodeUnitIODef& input_def = dq_node_unit.Inputs()[0]; const NodeUnitIODef& output_def = q_node_unit.Outputs()[0]; @@ -109,7 +109,7 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, } else { ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(q_node_unit), + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CONVERT, {input_def.node_arg.Name()}, diff --git 
a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc index 5094ad96724f5..af00a7bfb4439 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/hardsigmoid_mul_fusion.cc @@ -106,7 +106,7 @@ static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& mul_node_unit, bool validate) { assert(hardsigmoid_node_unit.OpType() == "HardSigmoid" && mul_node_unit.OpType() == "Mul"); - const auto& node_name = utils::GetNodeName(hardsigmoid_node_unit); + const auto& node_name = utils::GetUniqueName(hardsigmoid_node_unit); const NodeUnitIODef& input_def = hardsigmoid_node_unit.Inputs()[0]; const NodeUnitIODef& output_def = mul_node_unit.Outputs()[0]; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc index 99ea79e028b0c..41c06b0f63663 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.cc @@ -191,7 +191,7 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, act_dql_node_unit.OpType() == "DequantizeLinear" && gemm_node_unit.OpType() == "Gemm" && output_ql_node_unit.OpType() == "QuantizeLinear"); - const auto& node_name = utils::GetNodeName(gemm_node_unit); + const auto& node_name = utils::GetUniqueName(gemm_node_unit); const NodeUnitIODef& act_dql_input_1_def = act_dql_node_unit.Inputs()[0]; const NodeUnitIODef& w_dql_input_1_def = w_dql_node_unit.Inputs()[0]; const NodeUnitIODef& w_ql_input_1_def = w_ql_node_unit.Inputs()[0]; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc index 92e0f28b0307c..96394492299c6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.cc @@ -127,7 +127,7 @@ Status ProcessInput0(QnnModelWrapper& qnn_model_wrapper, std::string actual_input_0_name = original_input_0_name; if (reshape_input_0) { - actual_input_0_name = original_input_0_name + "_ort_qnn_ep_reshape"; + actual_input_0_name = utils::GetUniqueName(original_input_0_name, "_reshape"); std::vector shape_2d{1, input_0_info.shape[0]}; QnnQuantParamsWrapper quant_param_2d = input_0_info.quant_param.Copy(); ORT_RETURN_IF_ERROR(quant_param_2d.HandleUnsqueeze(input_0_info.shape, shape_2d)); @@ -329,7 +329,7 @@ Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, w_ql_node_unit.OpType() == "QuantizeLinear" && matmul_node_unit.OpType() == "MatMul"); - const auto& node_name = utils::GetNodeName(matmul_node_unit); + const auto& node_name = utils::GetUniqueName(matmul_node_unit); std::vector input_names; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc index b851ac9ea9c47..34b680866c7a9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/reshape_gemm_fusion.cc @@ -68,7 +68,7 @@ bool CheckShape(const Node& reshape_node) { Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& reshape_node_unit, 
const NodeUnit& gemm_node_unit, bool validate) { assert(reshape_node_unit.OpType() == "Reshape" && gemm_node_unit.OpType() == "Gemm"); - const auto& node_name = utils::GetNodeName(gemm_node_unit); + const auto& node_name = utils::GetUniqueName(gemm_node_unit); const NodeUnitIODef& input_def = reshape_node_unit.Inputs()[0]; const NodeUnitIODef& weight_def = gemm_node_unit.Inputs()[1]; const NodeUnitIODef* bias_def_ptr = nullptr; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc index 5c7091b3be3cc..2022fb08d905f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/scale_softmax_fusion.cc @@ -152,8 +152,9 @@ Status CreateOrValidateOnQnn( ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(mul_input_other, fused_softmax_input)); ORT_RETURN_IF_ERROR(qnn_model_wrapper->MakeTensorWrapper(softmax_output, fused_softmax_output)); + const std::string node_name = onnxruntime::qnn::utils::GetUniqueName(*softmax); if (validate) { - ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(softmax->Name(), + ORT_RETURN_IF_ERROR(qnn_model_wrapper->ValidateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_SOFTMAX, {fused_softmax_input.GetQnnTensor()}, @@ -162,7 +163,7 @@ Status CreateOrValidateOnQnn( } else { ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_input)), "Failed to add input"); ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(fused_softmax_output)), "Failed to add output"); - ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(softmax->Name(), + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode(node_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_SOFTMAX, {mul_input_other.node_arg.Name()}, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc index 9c5b82bcfd68f..2be7433fceca4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/udo_fusion.cc @@ -99,7 +99,7 @@ static Status CreateOrValidateOnQnn( const logging::Logger& logger) { ORT_UNUSED_PARAMETER(do_op_validation); - std::string node_name = utils::GetNodeName(node_unit); + const std::string node_name = utils::GetUniqueName(node_unit); // get qnn inputs const auto& inputs = node_unit.Inputs(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 1e1a183c5c444..afa5e3bdbb6d1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -790,13 +790,33 @@ Status GetQnnDataType(const bool is_quantized_tensor, const ONNX_NAMESPACE::Type return Status::OK(); } -const std::string& GetNodeName(const NodeUnit& node_unit) { - const std::string& node_name = node_unit.Name(); - if (node_name.empty()) { - return node_unit.Outputs()[0].node_arg.Name(); +std::string GetUniqueName(const std::string& base, std::string_view suffix) { + std::string name = base; + if (!suffix.empty()) { + name += suffix; } + { + static std::unordered_map counter; + static std::mutex counter_mutex; + std::lock_guard lock(counter_mutex); + + int& count = counter[name]; + if (count++ > 0) { + return name + "_" + std::to_string(count); + } + } + return name; +} - return node_name; +std::string 
GetUniqueName(const NodeUnit& node_unit, std::string_view suffix) { + // Preserve node name when exist. Otherwise, use op type with index + std::string base; + if (!node_unit.Name().empty()) { + base = node_unit.Name(); + } else { + base = node_unit.OpType() + std::to_string(node_unit.Index()); + } + return GetUniqueName(base, suffix); } bool OnnxDataTypeToQnnDataType(const int32_t onnx_data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized) { @@ -1380,10 +1400,9 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, QnnQuantParamsWrapper(scale, offset), std::move(output_shape_copy)); ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(convert_output_tensorwrapper)), "Failed to add tensor."); - - ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(convert_output_name, + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(convert_output_name, QNN_OP_CONVERT), QNN_OP_PACKAGE_NAME_QTI_AISW, - "Convert", + QNN_OP_CONVERT, {convert_input_name}, {convert_output_name}, {}, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index f8e8f9c8409ee..b234f7df375e9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -90,7 +90,11 @@ std::ostream& operator<<(std::ostream& out, const QnnOpConfigWrapper& op_conf_wr Status GetQnnDataType(const bool is_quantized_tensor, const ONNX_NAMESPACE::TypeProto* type_proto, Qnn_DataType_t& tensor_data_type); -const std::string& GetNodeName(const NodeUnit& node_unit); +// Returns an unique name string based on a base string and an optional suffix. +std::string GetUniqueName(const std::string& base, std::string_view suffix = {}); + +// Returns an unique name string from its name or op type and index, plus an optional suffix. 
+std::string GetUniqueName(const NodeUnit& node_unit, std::string_view suffix = {}); bool OnnxDataTypeToQnnDataType(const int32_t data_type, Qnn_DataType_t& qnn_data_type, bool is_quantized = false); diff --git a/onnxruntime/core/providers/webgpu/math/gemm.cc b/onnxruntime/core/providers/webgpu/math/gemm.cc index c833938f9ad30..4fb512001381a 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm.cc @@ -72,9 +72,7 @@ Status GemmNaiveProgram::GenerateShaderCode(ShaderHelper& shader) const { } // Calculate Alpha - if (alpha_) { - shader.MainFunctionBody() << " value = value * output_value_t(uniforms.alpha);\n"; - } + shader.MainFunctionBody() << " value = value * output_value_t(uniforms.alpha);\n"; // Calculate Bias if (need_handle_bias_) { @@ -126,7 +124,7 @@ Status Gemm::ComputeInternal(ComputeContext& context) const { if (M <= 8 && N <= 8 && K <= 8) { // Use naive implementation for small matrices - GemmNaiveProgram program{transA_, transB_, alpha_, need_handle_bias, need_handle_matmul}; + GemmNaiveProgram program{transA_, transB_, need_handle_bias, need_handle_matmul}; if (need_handle_matmul) { program.AddInputs({{A, ProgramTensorMetadataDependency::Type}, {B, ProgramTensorMetadataDependency::Type}}); @@ -136,7 +134,7 @@ Status Gemm::ComputeInternal(ComputeContext& context) const { program.AddInput({C, ProgramTensorMetadataDependency::Rank}); } - program.CacheHint(alpha_, transA_, transB_) + program.CacheHint(transA_, transB_) .AddOutputs({{Y, ProgramTensorMetadataDependency::Type}}) .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) .SetWorkgroupSize(WORKGROUP_SIZE) diff --git a/onnxruntime/core/providers/webgpu/math/gemm.h b/onnxruntime/core/providers/webgpu/math/gemm.h index 06e6587050604..581b36e379ac1 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm.h +++ b/onnxruntime/core/providers/webgpu/math/gemm.h @@ -12,11 +12,10 @@ namespace webgpu { class GemmNaiveProgram final : public Program { public: - GemmNaiveProgram(bool transA, bool transB, float alpha, bool need_handle_bias, bool need_handle_matmul) + GemmNaiveProgram(bool transA, bool transB, bool need_handle_bias, bool need_handle_matmul) : Program{"GemmNaive"}, transA_{transA}, transB_{transB}, - alpha_{alpha}, need_handle_bias_{need_handle_bias}, need_handle_matmul_{need_handle_matmul} {} @@ -33,7 +32,6 @@ class GemmNaiveProgram final : public Program { private: bool transA_; bool transB_; - float alpha_; bool need_handle_bias_; bool need_handle_matmul_; }; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index a1adb1a016b73..2993bfcebb1da 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -72,7 +72,7 @@ void MatMulReadFnSource(ShaderHelper& shader, ? ", batch_indices: batch_dims_indices_t" : "") << ") -> " << type_string << " {\n " - << " var value = " << type_string << "(0.0);\n" + << " var value = " << type_string << "(0);\n" << " let col = colIn * " << components << ";\n"; if (transA) { shader.AdditionalImplementation() << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n"; @@ -98,7 +98,7 @@ void MatMulReadFnSource(ShaderHelper& shader, ? 
", batch_indices: batch_dims_indices_t" : "") << ") -> " << type_string << " {\n " - << " var value = " << type_string << "(0.0);\n" + << " var value = " << type_string << "(0);\n" << " let col = colIn * " << components << ";\n"; if (transB) { diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json index df1940ed6416b..1c64eaf2596e2 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package-lock.json @@ -9,13 +9,13 @@ "version": "1.0.0", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.13" + "@fs-eire/wgsl-template": "^0.1.14" } }, "node_modules/@fs-eire/wgsl-template": { - "version": "0.1.13", - "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.13.tgz", - "integrity": "sha512-SOQjVCQCUmXb9qYr2E3CKNs88/FzINuhFJiobBEkSAsyKtJby9oFWGZnrEO+hIl/oDTLA01LbjiDxuf6TGHE/w==", + "version": "0.1.14", + "resolved": "https://registry.npmjs.org/@fs-eire/wgsl-template/-/wgsl-template-0.1.14.tgz", + "integrity": "sha512-tEECRz7gAWH+NWGRWGNk9IIs4LRpYN9wsy9jzJklopvDcHpg3If8XY3tphggnpr9Otq4jUiqCLZCfQIBg515Mg==", "license": "MIT", "dependencies": { "minimist": "^1.2.8" diff --git a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json index 246e7365531e0..29112c5c66f6c 100644 --- a/onnxruntime/core/providers/webgpu/wgsl_templates/package.json +++ b/onnxruntime/core/providers/webgpu/wgsl_templates/package.json @@ -10,6 +10,6 @@ "author": "", "license": "MIT", "dependencies": { - "@fs-eire/wgsl-template": "^0.1.13" + "@fs-eire/wgsl-template": "^0.1.14" } } diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index dfb2e33f8cb32..39b785c327d56 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -182,11 +182,6 @@ Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool shared_ort_allocators_.erase(it2); } - // also remove an arena wrapped allocator from an EP if the user called CreateSharedAllocator to create one - if (auto it3 = arena_ort_allocators_.find(&mem_info); it3 != arena_ort_allocators_.end()) { - arena_ort_allocators_.erase(it3); - } - if (found_shared_allocator) { shared_allocators_.erase(it); } @@ -436,6 +431,10 @@ Environment::~Environment() { // instance and will call Release on it. If the plugin EP has been freed the Release will fail. shared_allocators_.clear(); + // and as any OrtAllocator instances in shared_ort_allocators_ were owned by values in shared_allocators_ and have + // now been released we need to clear that too before calling UnregisterExecutionProviderLibrary(). + shared_ort_allocators_.clear(); + #if !defined(ORT_MINIMAL_BUILD) // unregister any remaining EP libraries so they're cleaned up in a determistic way. 
while (!ep_libraries_.empty()) { @@ -673,11 +672,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, shared_ort_allocators_.erase(it); } - // if a previous call created an arena wrapped allocator for the EP's memory_info we also need to remove that - if (auto it = arena_ort_allocators_.find(&memory_info); it != arena_ort_allocators_.end()) { - arena_ort_allocators_.erase(it); - } - // we only want one shared allocator for an OrtDevice in the shared_allocators_ so that it's deterministic which // one will be used for an inference session. ignore the name so that is the case. if (auto it = FindExistingAllocator(shared_allocators_, memory_info, /*match_name*/ false); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1baa6e529cbde..4196ed280a993 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -5444,8 +5444,59 @@ TEST(QDQTransformerTests, WeightBiasQuantization_Conv_Weight_Bias) { #endif } +// Tests that the WeightBiasQuantization optimizer still processes nodes that contain a type-preserving no +// branch ReLU op to QuantizeLinear e.g., Q -> DQ -> Conv (w/ float weight initializer) -> ReLU -> Q -> DQ +TEST(QDQTransformerTests, WeightBiasQuantization_ConvWithReLU) { + auto test_case = [](bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + NodeArg* input_fp32 = builder.MakeInput({1, 1, 4, 4}, -1.0f, 1.0f); + NodeArg* weight_fp32 = builder.MakeInitializer({2, 1, 3, 3}, -1.0f, 1.0f); + NodeArg* input_q = builder.MakeIntermediate(); + NodeArg* input_dq = builder.MakeIntermediate(); + NodeArg* conv_fp32 = builder.MakeIntermediate(); + NodeArg* relu_fp32 = builder.MakeIntermediate(); + NodeArg* relu_q = builder.MakeIntermediate(); + NodeArg* relu_dq = builder.MakeOutput(); + builder.AddQuantizeLinearNode(input_fp32, 0.18f, static_cast(127), input_q, use_contrib_qdq); + builder.AddDequantizeLinearNode(input_q, 0.18f, static_cast(127), input_dq, use_contrib_qdq); + auto& conv_node = builder.AddNode("Conv", {input_dq, weight_fp32}, {conv_fp32}); + conv_node.AddAttribute("dilations", std::vector{1, 1}); + conv_node.AddAttribute("kernel_shape", std::vector{3, 3}); + conv_node.AddAttribute("strides", std::vector{1, 1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0, 0, 0}); + builder.AddNode("Relu", {conv_fp32}, {relu_fp32}); + builder.AddQuantizeLinearNode(relu_fp32, 0.69f, static_cast(127), relu_q, use_contrib_qdq); + builder.AddDequantizeLinearNode(relu_q, 0.69f, static_cast(127), relu_dq, use_contrib_qdq); + }; + + // Conv's weights should be quantized and folded, one additional Q/DQ pair inserted for weight + auto check_transformed_graph = [](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_to_count["QuantizeLinear"] + op_to_count["com.microsoft.QuantizeLinear"], 2 + 1); + EXPECT_EQ(op_to_count["DequantizeLinear"] + op_to_count["com.microsoft.DequantizeLinear"], 2 + 1); + EXPECT_EQ(op_to_count["Conv"], 1); + EXPECT_EQ(op_to_count["Relu"], 1); + }; + + TransformerTester(build_test_case, + check_transformed_graph, + TransformerLevel::Default, + TransformerLevel::Level1, + /*opset_version=*/20, + /*per_sample_tolerance=*/0.01, + /*relative_per_sample_tolerance=*/0.01, + /*transformer=*/std::make_unique()); + }; + + test_case(false); +#if !defined(DISABLE_CONTRIB_OPS) + test_case(true); 
+#endif +} + // Tests that the WeightBiasQuantization optimizer does not process nodes that do not -// already have an output that is consumed by a single QuantizeLinear node. +// already have an output that is consumed by a valid path to QuantizeLinear node. TEST(QDQTransformerTests, WeightBiasQuantization_SkipIfOutputNotQuantized) { auto test_case = [](bool add_final_reshape) { auto build_test_case = [&](ModelTestBuilder& builder) { diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index f8f38e5b4e76c..d709e71815839 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -28,7 +28,7 @@ auto initialize_matrix = [](int64_t rows, int64_t cols) { std::vector data; data.reserve(rows * cols); for (int64_t i = 0; i < rows * cols; ++i) { - data.push_back(static_cast((i % 7) + 1)); + data.push_back(((i % 7) + 1)); } return data; }; @@ -40,7 +40,6 @@ enum class BiasType { MNBias, // C shape is {M,N} NBias // C shape is {N} }; - // Helper function to initialize bias data for Gemm tests auto initialize_bias = [](BiasType bias_type, int64_t M, int64_t N) { std::pair, std::vector> result; @@ -52,7 +51,7 @@ auto initialize_bias = [](BiasType bias_type, int64_t M, int64_t N) { case BiasType::MBias: shape = {M, 1}; for (int64_t i = 0; i < M; ++i) { - data.push_back(static_cast((i % 7) + 1)); + data.push_back(((i % 7) + 1)); } break; case BiasType::ScalarBias: @@ -62,13 +61,13 @@ auto initialize_bias = [](BiasType bias_type, int64_t M, int64_t N) { case BiasType::MNBias: shape = {M, N}; for (int64_t i = 0; i < M * N; ++i) { - data.push_back(static_cast((i % 7) + 1)); + data.push_back(((i % 7) + 1)); } break; case BiasType::NBias: shape = {N}; for (int64_t i = 0; i < N; ++i) { - data.push_back(static_cast((i % 7) + 1)); + data.push_back((i % 7) + 1); } break; } @@ -703,56 +702,406 @@ TYPED_TEST(GemmOpTypedTests, TestGemmTransB_1) { } TYPED_TEST(GemmOpTypedTests, TestGemmAlpha) { - OpTester test("Gemm"); + // Test case 1: 2x4 * 4x3 + { + OpTester test("Gemm"); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 0.5f); - test.AddAttribute("beta", 1.0f); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.5f); + test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), - static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); - test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); - test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); - test.AddOutput("Y", {2, 3}, - {static_cast(6.0f), static_cast(6.0f), static_cast(6.0f), - static_cast(-4.0f), static_cast(-4.0f), static_cast(-4.0f)}); - // test.AddOutput("Y", {2, 3}, - // {5.0f, 5.0f, 5.0f, - // -5.0f, -5.0f, -5.0f}); + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); + test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); + test.AddOutput("Y", {2, 3}, + {static_cast(6.0f), static_cast(6.0f), static_cast(6.0f), + static_cast(-4.0f), static_cast(-4.0f), static_cast(-4.0f)}); #if defined(OPENVINO_CONFIG_GPU) - test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: 
Temporarily disabled due to accuracy issues + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues #else - test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser #endif - test.Config(run_with_tunable_op) - .RunWithConfig(); + test.Config(run_with_tunable_op) + .RunWithConfig(); + } + + // Test case 2: 64x64 * 64x64 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.5f); + test.AddAttribute("beta", 1.0f); + + // Create 64x64 matrices with simple pattern + std::vector A_data(64 * 64); + std::vector B_data(64 * 64); + std::vector C_data(64 * 64); + std::vector Y_data(64 * 64); + + // Fill A matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + A_data[i] = static_cast((i % 7) + 1); + } + + // Fill B matrix with ones + for (int i = 0; i < 64 * 64; ++i) { + B_data[i] = static_cast(1.0f); + } + + // Fill C matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + C_data[i] = static_cast((i % 3) + 1); + } + + // Calculate expected output: Y = alpha * A * B + beta * C + // Since B is all ones, A * B results in row sums of A + for (int i = 0; i < 64; ++i) { + TypeParam row_sum = static_cast(0.0f); + for (int k = 0; k < 64; ++k) { + row_sum += A_data[i * 64 + k]; + } + for (int j = 0; j < 64; ++j) { + Y_data[i * 64 + j] = static_cast(0.5f) * row_sum + static_cast(1.0f) * C_data[i * 64 + j]; + } + } + + test.AddInput("A", {64, 64}, A_data); + test.AddInput("B", {64, 64}, B_data); + test.AddInput("C", {64, 64}, C_data); + test.AddOutput("Y", {64, 64}, Y_data); + +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } } TYPED_TEST(GemmOpTypedTests, TestGemmBeta) { - OpTester test("Gemm"); + // Test case 1: 2x4 * 4x3 + { + OpTester test("Gemm"); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 2.0f); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 2.0f); - test.AddInput("A", {2, 4}, - {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), - static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); - test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); - test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); - test.AddOutput("Y", {2, 3}, - {static_cast(12.0f), static_cast(12.0f), static_cast(12.0f), - static_cast(-8.0f), static_cast(-8.0f), static_cast(-8.0f)}); + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); + test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); + test.AddOutput("Y", {2, 3}, + {static_cast(12.0f), static_cast(12.0f), static_cast(12.0f), + static_cast(-8.0f), static_cast(-8.0f), static_cast(-8.0f)}); #if defined(OPENVINO_CONFIG_GPU) - test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // 
OpenVINO: Temporarily disabled due to accuracy issues + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues #else - test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser #endif - test.Config(run_with_tunable_op) - .RunWithConfig(); + test.Config(run_with_tunable_op) + .RunWithConfig(); + } + + // Test case 2: 64x64 * 64x64 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 2.0f); + + // Create 64x64 matrices with simple pattern + std::vector A_data(64 * 64); + std::vector B_data(64 * 64); + std::vector C_data(64 * 64); + std::vector Y_data(64 * 64); + + // Fill A matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + A_data[i] = static_cast((i % 7) + 1); + } + + // Fill B matrix with ones + for (int i = 0; i < 64 * 64; ++i) { + B_data[i] = static_cast(1.0f); + } + + // Fill C matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + C_data[i] = static_cast((i % 3) + 1); + } + + // Calculate expected output: Y = alpha * A * B + beta * C + // Since B is all ones, A * B results in row sums of A + for (int i = 0; i < 64; ++i) { + TypeParam row_sum = static_cast(0.0f); + for (int k = 0; k < 64; ++k) { + row_sum += A_data[i * 64 + k]; + } + for (int j = 0; j < 64; ++j) { + Y_data[i * 64 + j] = static_cast(1.0f) * row_sum + static_cast(2.0f) * C_data[i * 64 + j]; + } + } + + test.AddInput("A", {64, 64}, A_data); + test.AddInput("B", {64, 64}, B_data); + test.AddInput("C", {64, 64}, C_data); + test.AddOutput("Y", {64, 64}, Y_data); + +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } +} + +TYPED_TEST(GemmOpTypedTests, TestGemmZeroAlpha) { + // Test case 1: 2x4 * 4x3, alpha=0, beta=2.0 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.0f); + test.AddAttribute("beta", 2.0f); + + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); + test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); + test.AddOutput("Y", {2, 3}, + {static_cast(2.0f), static_cast(2.0f), static_cast(2.0f), + static_cast(2.0f), static_cast(2.0f), static_cast(2.0f)}); +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } + + // Test case 2: 64x64 * 64x64, alpha=0, beta=2.0 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.0f); + test.AddAttribute("beta", 2.0f); + + // Create 64x64 matrices with simple pattern + std::vector A_data(64 * 64); + std::vector B_data(64 * 64); + std::vector C_data(64 * 64); + std::vector Y_data(64 * 
64); + + // Fill A matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + A_data[i] = static_cast((i % 7) + 1); + } + + // Fill B matrix with ones + for (int i = 0; i < 64 * 64; ++i) { + B_data[i] = static_cast(1.0f); + } + + // Fill C matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + C_data[i] = static_cast((i % 3) + 1); + } + + // Calculate expected output: Y = alpha * A * B + beta * C + // Since alpha=0, Y = beta * C = 2.0 * C + for (int i = 0; i < 64 * 64; ++i) { + Y_data[i] = static_cast(2.0f) * C_data[i]; + } + + test.AddInput("A", {64, 64}, A_data); + test.AddInput("B", {64, 64}, B_data); + test.AddInput("C", {64, 64}, C_data); + test.AddOutput("Y", {64, 64}, Y_data); + +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } +} + +TYPED_TEST(GemmOpTypedTests, TestGemmZeroBeta) { + // Test case 1: 2x4 * 4x3, alpha=2.0, beta=0 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 2.0f); + test.AddAttribute("beta", 0.0f); + + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); + test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); + test.AddOutput("Y", {2, 3}, + {static_cast(20.0f), static_cast(20.0f), static_cast(20.0f), + static_cast(-20.0f), static_cast(-20.0f), static_cast(-20.0f)}); +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } + + // Test case 2: 64x64 * 64x64, alpha=2.0, beta=0 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 2.0f); + test.AddAttribute("beta", 0.0f); + + // Create 64x64 matrices with simple pattern + std::vector A_data(64 * 64); + std::vector B_data(64 * 64); + std::vector C_data(64 * 64); + std::vector Y_data(64 * 64); + + // Fill A matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + A_data[i] = static_cast((i % 7) + 1); + } + + // Fill B matrix with ones + for (int i = 0; i < 64 * 64; ++i) { + B_data[i] = static_cast(1.0f); + } + + // Fill C matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + C_data[i] = static_cast((i % 3) + 1); + } + + // Calculate expected output: Y = alpha * A * B + beta * C + // Since beta=0, Y = alpha * A * B = 2.0 * A * B + // Since B is all ones, A * B results in row sums of A + for (int i = 0; i < 64; ++i) { + TypeParam row_sum = static_cast(0.0f); + for (int k = 0; k < 64; ++k) { + row_sum += A_data[i * 64 + k]; + } + for (int j = 0; j < 64; ++j) { + Y_data[i * 64 + j] = static_cast(2.0f) * row_sum; + } + } + + test.AddInput("A", {64, 64}, A_data); + test.AddInput("B", {64, 64}, B_data); + test.AddInput("C", {64, 64}, C_data); + test.AddOutput("Y", {64, 64}, Y_data); + +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy 
issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } +} + +TYPED_TEST(GemmOpTypedTests, TestGemmZeroAlphaBeta) { + // Test case 1: 2x4 * 4x3, alpha=0, beta=0 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.0f); + test.AddAttribute("beta", 0.0f); + + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); + test.AddInput("C", {3}, std::vector(3, static_cast(1.0f))); + test.AddOutput("Y", {2, 3}, std::vector(6, static_cast(0.0f))); +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } + + // Test case 2: 64x64 * 64x64, alpha=0, beta=0 + { + OpTester test("Gemm"); + + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 0.0f); + test.AddAttribute("beta", 0.0f); + + // Create 64x64 matrices with simple pattern + std::vector A_data(64 * 64); + std::vector B_data(64 * 64); + std::vector C_data(64 * 64); + std::vector Y_data(64 * 64, static_cast(0.0f)); // All zeros + + // Fill A matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + A_data[i] = static_cast((i % 7) + 1); + } + + // Fill B matrix with ones + for (int i = 0; i < 64 * 64; ++i) { + B_data[i] = static_cast(1.0f); + } + + // Fill C matrix with pattern + for (int i = 0; i < 64 * 64; ++i) { + C_data[i] = static_cast((i % 3) + 1); + } + + // Expected output: Y = alpha * A * B + beta * C = 0 * A * B + 0 * C = 0 + + test.AddInput("A", {64, 64}, A_data); + test.AddInput("B", {64, 64}, B_data); + test.AddInput("C", {64, 64}, C_data); + test.AddOutput("Y", {64, 64}, Y_data); + +#if defined(OPENVINO_CONFIG_GPU) + test.ConfigExcludeEps({kOpenVINOExecutionProvider}); // OpenVINO: Temporarily disabled due to accuracy issues +#else + test.ConfigExcludeEps({kTensorrtExecutionProvider}); // TensorRT: Seg fault in parser +#endif + test.Config(run_with_tunable_op) + .RunWithConfig(); + } } TYPED_TEST(GemmOpTypedTests, TestGemmNaN) { @@ -893,22 +1242,45 @@ TYPED_TEST(GemmOpTypedTests, ZeroKWithBias) { } TYPED_TEST(GemmOpTypedTests, ZeroKWithNoBias) { - OpTester test("Gemm", 13); + // Test case 1: 4x4 + { + OpTester test("Gemm", 13); - test.AddAttribute("transA", static_cast(0)); - test.AddAttribute("transB", static_cast(0)); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", .0f); + test.AddAttribute("transA", static_cast(0)); + test.AddAttribute("transB", static_cast(0)); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", .0f); - test.AddInput("A", {4, 0}, {}); - test.AddInput("B", {0, 4}, {}); - test.AddOutput("Y", {4, 4}, std::vector(16, static_cast(0.0f))); + test.AddInput("A", {4, 0}, {}); + test.AddInput("B", {0, 4}, {}); + test.AddOutput("Y", {4, 4}, std::vector(16, static_cast(0.0f))); - test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider, - kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, - kOpenVINOExecutionProvider}) - .Config(run_with_tunable_op) - 
.RunWithConfig(); + test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider, + kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, + kOpenVINOExecutionProvider}) + .Config(run_with_tunable_op) + .RunWithConfig(); + } + + // Test case 2: 64x64 with K=0 + { + OpTester test("Gemm", 13); + + test.AddAttribute("transA", static_cast(0)); + test.AddAttribute("transB", static_cast(0)); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", .0f); + + test.AddInput("A", {64, 0}, {}); + test.AddInput("B", {0, 64}, {}); + test.AddOutput("Y", {64, 64}, std::vector(64 * 64, static_cast(0.0f))); + + test.ConfigExcludeEps({kCoreMLExecutionProvider, kNnapiExecutionProvider, + kDmlExecutionProvider, kDnnlExecutionProvider, kQnnExecutionProvider, + kOpenVINOExecutionProvider}) + .Config(run_with_tunable_op) + .RunWithConfig(); + } } TYPED_TEST(GemmOpTypedTests, MissingBias) { @@ -1038,224 +1410,273 @@ TEST(GemmOpTest, SharedPrepackedWeights) { } #endif -TEST(GemmOpTest, GemmOptimizePacked) { - auto run_test = [](int64_t M, int64_t K, int64_t N, BiasType bias_type) { - OpTester test("Gemm", 13); +// Common helper function for GEMM optimize packed tests +auto run_gemm_optimize_packed_test = [](int64_t M, int64_t K, int64_t N, BiasType bias_type, bool transA, bool transB) { + OpTester test("Gemm", 13); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + test.AddAttribute("transA", static_cast(transA ? 1 : 0)); + test.AddAttribute("transB", static_cast(transB ? 1 : 0)); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + + // Initialize matrices based on transpose settings + std::vector a_data, b_data; + std::vector a_shape, b_shape; - std::vector a_data = initialize_matrix(M, K); - std::vector b_data = initialize_matrix(K, N); + if (transA) { + a_data = initialize_matrix(K, M); + a_shape = {K, M}; + } else { + a_data = initialize_matrix(M, K); + a_shape = {M, K}; + } - auto [c_data, c_shape] = initialize_bias(bias_type, M, N); - bool has_bias = !c_data.empty(); + if (transB) { + b_data = initialize_matrix(N, K); + b_shape = {N, K}; + } else { + b_data = initialize_matrix(K, N); + b_shape = {K, N}; + } - test.AddInput("A", {M, K}, a_data); - test.AddInput("B", {K, N}, b_data); - if (has_bias) { - test.AddInput("C", c_shape, c_data); - } + // Initialize bias with appropriate shape + auto [c_data, c_shape] = initialize_bias(bias_type, M, N); + bool has_bias = !c_data.empty(); + + test.AddInput("A", a_shape, a_data); + test.AddInput("B", b_shape, b_data); + if (has_bias) { + test.AddInput("C", c_shape, c_data); + } + + // Calculate expected output based on transpose settings + std::vector expected_data(M * N, 0.0f); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + float a_val, b_val; + + if (transA) { + a_val = a_data[k * M + i]; // A^T[i][k] = A[k][i] + } else { + a_val = a_data[i * K + k]; // A[i][k] + } - // Calculate expected output - std::vector expected_data(M * N, 0.0f); - for (int64_t i = 0; i < M; ++i) { - for (int64_t j = 0; j < N; ++j) { - float sum = 0.0f; - for (int64_t k = 0; k < K; ++k) { - sum += a_data[i * K + k] * b_data[k * N + j]; + if (transB) { + b_val = b_data[j * K + k]; // B^T[k][j] = B[j][k] + } else { + b_val = b_data[k * N + j]; // B[k][j] } - expected_data[i * N + j] = sum + get_bias_value(c_data, bias_type, i, j, N); + + sum += a_val * 
b_val; } + float matmul_result = sum; + float bias_value = get_bias_value(c_data, bias_type, i, j, N); + expected_data[i * N + j] = matmul_result + bias_value; } + } - test.AddOutput("Y", {M, N}, expected_data); - test.ConfigExcludeEps({kQnnExecutionProvider}) - .Config(run_with_tunable_op) - .RunWithConfig(); - }; - - // Test different matrix sizes with all bias types - std::vector> test_sizes = { - {32, 32, 32}, {64, 64, 64}, {60, 16, 92}, {8, 8, 8}, {128, 128, 128}, {128, 32, 64}, {96, 24, 48}, {48, 48, 120}, {72, 80, 84}, {33, 67, 99}, {1, 1, 1}, {31, 31, 31}}; + test.AddOutput("Y", {M, N}, expected_data); - std::vector bias_types = { - BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, - BiasType::MNBias, BiasType::NBias}; + test.ConfigExcludeEps({kQnnExecutionProvider}) + .Config(run_with_tunable_op) + .RunWithConfig(); +}; - // Run tests with different combinations of matrix sizes and bias types - for (const auto& size : test_sizes) { - for (const auto& bias_type : bias_types) { - run_test(std::get<0>(size), std::get<1>(size), std::get<2>(size), bias_type); - } - } -} +// Parameterized test for GEMM optimize packed variants +struct GemmOptimizePackedParams { + int64_t M, K, N; + BiasType bias_type; + bool transA, transB; -TEST(GemmOpTest, GemmOptimizePackedTransA) { - auto run_test = [](int64_t M, int64_t K, int64_t N, BiasType bias_type) { - OpTester test("Gemm", 13); + // Helper for readable test names + std::string ToString() const { + std::string name = std::to_string(M) + "x" + std::to_string(K) + "x" + std::to_string(N); - test.AddAttribute("transA", (int64_t)1); // A is transposed - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + // Bias type names + const char* bias_names[] = {"noBias", "MBias", "ScalarBias", "MNBias", "NBias"}; + name += "_" + std::string(bias_names[static_cast(bias_type)]); - std::vector a_data = initialize_matrix(K, M); - std::vector b_data = initialize_matrix(K, N); + name += (transA ? "_transA" : ""); + name += (transB ? 
"_transB" : ""); + return name; + } +}; - // Initialize bias with appropriate shape - auto [c_data, c_shape] = initialize_bias(bias_type, M, N); - bool has_bias = !c_data.empty(); +class GemmOptimizePackedTest : public ::testing::TestWithParam {}; - test.AddInput("A", {K, M}, a_data); - test.AddInput("B", {K, N}, b_data); - if (has_bias) { - test.AddInput("C", c_shape, c_data); - } +TEST_P(GemmOptimizePackedTest, TestVariants) { + const auto& params = GetParam(); + run_gemm_optimize_packed_test(params.M, params.K, params.N, params.bias_type, + params.transA, params.transB); +} - // Calculate expected output for transposed A - std::vector expected_data(M * N, 0.0f); - for (int64_t i = 0; i < M; ++i) { - for (int64_t j = 0; j < N; ++j) { - float sum = 0.0f; - for (int64_t k = 0; k < K; ++k) { - sum += a_data[k * M + i] * b_data[k * N + j]; - } - expected_data[i * N + j] = sum + get_bias_value(c_data, bias_type, i, j, N); - } - } +// Test parameter generation +std::vector GenerateGemmParams() { + std::vector params; - test.AddOutput("Y", {M, N}, expected_data); - test.ConfigExcludeEps({kQnnExecutionProvider}) - .Config(run_with_tunable_op) - .RunWithConfig(); - }; + std::vector> test_sizes = {{1, 1, 1}, {1, 64, 448}, {2, 3, 4}, {8, 8, 8}, {31, 31, 31}, {32, 32, 32}, {33, 67, 99}, {37, 64, 256}, {48, 48, 120}, {60, 16, 92}, {63, 64, 65}, {64, 64, 64}, {64, 64, 65}, {72, 80, 84}, {96, 24, 48}, {128, 32, 64}, {128, 128, 128}, {129, 129, 129}, {256, 64, 1024}}; - std::vector> test_sizes = { - {32, 32, 32}, {64, 64, 64}, {60, 16, 92}, {8, 8, 8}, {128, 128, 128}, {128, 32, 64}, {96, 24, 48}, {48, 48, 120}, {72, 80, 84}, {33, 67, 99}, {1, 1, 1}, {31, 31, 31}, {2, 3, 4}, {63, 64, 65}, {129, 129, 129}}; + std::vector + bias_types = {BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, BiasType::MNBias, BiasType::NBias}; - std::vector bias_types = { - BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, - BiasType::MNBias, BiasType::NBias}; + // Test all four transpose combinations: (transA, transB) + std::vector> transpose_combinations = { + {false, false}, // No transpose + {true, false}, // Transpose A + {false, true}, // Transpose B + {true, true} // Transpose A and B + }; - // Run tests with different combinations - for (const auto& size : test_sizes) { - for (const auto& bias_type : bias_types) { - run_test(std::get<0>(size), std::get<1>(size), std::get<2>(size), bias_type); + // Generate all combinations + for (const auto& [transA, transB] : transpose_combinations) { + for (const auto& size : test_sizes) { + for (const auto& bias_type : bias_types) { + params.push_back({std::get<0>(size), std::get<1>(size), std::get<2>(size), + bias_type, transA, transB}); + } } } + return params; } -TEST(GemmOpTest, GemmOptimizePackedTransB) { - auto run_test = [](int64_t M, int64_t K, int64_t N, BiasType bias_type) { - OpTester test("Gemm", 13); +INSTANTIATE_TEST_SUITE_P( + GemmOptimizePackedVariants, + GemmOptimizePackedTest, + ::testing::ValuesIn(GenerateGemmParams()), + [](const ::testing::TestParamInfo& info) { + return info.param.ToString(); + }); + +#if defined(USE_WEBGPU) +// Test int32 with M=128, K=128, N=128, transA=True +TEST(GemmOpTest, GemmTransA_int32_128x128x128) { + OpTester test("Gemm", 13); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)1); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + test.AddAttribute("transA", (int64_t)1); // transposeA = 1 + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + 
test.AddAttribute("beta", 1.0f); - std::vector a_data = initialize_matrix(M, K); - std::vector b_data = initialize_matrix(N, K); + const int64_t M = 128, K = 128, N = 128; - // Initialize bias with appropriate shape - auto [c_data, c_shape] = initialize_bias(bias_type, M, N); - bool has_bias = !c_data.empty(); + // Initialize input matrices with int values + std::vector A_data(K * M); // A shape is {K, M} because transposeA=1 + std::vector B_data(K * N); + std::vector C_data(M * N); - test.AddInput("A", {M, K}, a_data); - test.AddInput("B", {N, K}, b_data); - if (has_bias) { - test.AddInput("C", c_shape, c_data); - } + // Fill A matrix with pattern (will be transposed) + for (int64_t i = 0; i < K * M; ++i) { + A_data[i] = static_cast((i % 7) + 1); + } - // Calculate expected output - std::vector expected_data(M * N, 0.0f); - for (int64_t i = 0; i < M; ++i) { - for (int64_t j = 0; j < N; ++j) { - float sum = 0.0f; - for (int64_t k = 0; k < K; ++k) { - sum += a_data[i * K + k] * b_data[j * K + k]; - } - expected_data[i * N + j] = sum + get_bias_value(c_data, bias_type, i, j, N); + // Fill B matrix with pattern + for (int64_t i = 0; i < K * N; ++i) { + B_data[i] = static_cast((i % 5) + 1); + } + + // Fill C matrix (bias) with small values + for (int64_t i = 0; i < M * N; ++i) { + C_data[i] = static_cast((i % 3) + 1); + } + + // Calculate expected output: Y = alpha * A^T * B + beta * C + std::vector Y_data(M * N, 0); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + int64_t sum = 0; + for (int64_t k = 0; k < K; ++k) { + // A is transposed, so A^T[i][k] = A[k][i] + sum += static_cast(A_data[k * M + i]) * static_cast(B_data[k * N + j]); } + Y_data[i * N + j] = static_cast(sum + C_data[i * N + j]); // alpha=1.0, beta=1.0 } + } - test.AddOutput("Y", {M, N}, expected_data); - test.ConfigExcludeEps({kQnnExecutionProvider}) - .Config(run_with_tunable_op) - .RunWithConfig(); - }; - - std::vector> test_sizes = { - {32, 32, 32}, {64, 64, 64}, {60, 16, 92}, {8, 8, 8}, {128, 128, 128}, {128, 32, 64}, {96, 24, 48}, {48, 48, 120}, {72, 80, 84}, {33, 67, 99}, {1, 1, 1}, {31, 31, 31}, {2, 3, 4}, {63, 64, 65}, {129, 129, 129}}; + test.AddInput("A", {K, M}, A_data); // A shape is {K, M} because transA=True + test.AddInput("B", {K, N}, B_data); + test.AddInput("C", {M, N}, C_data); + test.AddOutput("Y", {M, N}, Y_data); - std::vector bias_types = { - BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, - BiasType::MNBias, BiasType::NBias}; + test.ConfigExcludeEps({kQnnExecutionProvider, kCpuExecutionProvider}) + .Config(run_with_tunable_op) + .RunWithConfig(); +} +#endif // defined(USE_WEBGPU) - // Run tests with different combinations - for (const auto& size : test_sizes) { - for (const auto& bias_type : bias_types) { - run_test(std::get<0>(size), std::get<1>(size), std::get<2>(size), bias_type); - } +// Test f16 with M=32, K=32, N=128 +TEST(GemmOpTest, GemmTransB_f16_32x32x128) { +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support FP16"; + return; } -} +#endif -TEST(GemmOpTest, GemmOptimizePackedTransAB) { - auto run_test = [](int64_t M, int64_t K, int64_t N, BiasType bias_type) { - OpTester test("Gemm", 13); + // 32x32, 32x128 matrix multiplication test with transB=True, alpha=1.0, beta=1.0 + const int64_t M = 32, K = 32, N = 128; - test.AddAttribute("transA", (int64_t)1); - test.AddAttribute("transB", (int64_t)1); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 
1.0f); + // Initialize input matrices with simple pattern + std::vector A_f32(M * K); + std::vector B_f32(N * K); // Note: B is N×K because transB=True + std::vector C_f32(M * N); - std::vector a_data = initialize_matrix(M, K); - std::vector b_data = initialize_matrix(N, K); + // Fill A matrix with pattern + for (int64_t i = 0; i < M * K; ++i) { + A_f32[i] = ((i % 7) + 1) * 0.1f; + } - // Initialize bias with appropriate shape - auto [c_data, c_shape] = initialize_bias(bias_type, M, N); - bool has_bias = !c_data.empty(); + // Fill B matrix with pattern (will be transposed) + for (int64_t i = 0; i < N * K; ++i) { + B_f32[i] = ((i % 5) + 1) * 0.1f; + } - test.AddInput("A", {K, M}, a_data); - test.AddInput("B", {N, K}, b_data); - if (has_bias) { - test.AddInput("C", c_shape, c_data); - } + // Fill C matrix (bias) with small values + for (int64_t i = 0; i < M * N; ++i) { + C_f32[i] = ((i % 3) + 1) * 0.01f; + } - // Calculate expected output for both matrices transposed - std::vector expected_data(M * N, 0.0f); - for (int64_t i = 0; i < M; ++i) { - for (int64_t j = 0; j < N; ++j) { - float sum = 0.0f; - for (int64_t k = 0; k < K; ++k) { - sum += a_data[k * M + i] * b_data[j * K + k]; - } - expected_data[i * N + j] = sum + get_bias_value(c_data, bias_type, i, j, N); + // Convert to MLFloat16 + std::vector f_A(M * K); + std::vector f_B(N * K); + std::vector f_C(M * N); + + ConvertFloatToMLFloat16(A_f32.data(), f_A.data(), M * K); + ConvertFloatToMLFloat16(B_f32.data(), f_B.data(), N * K); + ConvertFloatToMLFloat16(C_f32.data(), f_C.data(), M * N); + + // Calculate expected output: Y = alpha * A * B^T + beta * C + std::vector Y_f32(M * N, 0.0f); + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + // B is transposed, so B^T[k][j] = B[j][k] + sum += A_f32[i * K + k] * B_f32[j * K + k]; } + Y_f32[i * N + j] = 1.0f * sum + 1.0f * C_f32[i * N + j]; // alpha=1.0, beta=1.0 } + } - test.AddOutput("Y", {M, N}, expected_data); - test.ConfigExcludeEps({kQnnExecutionProvider}) - .Config(run_with_tunable_op) - .RunWithConfig(); - }; - - std::vector> test_sizes = { - {32, 32, 32}, {64, 64, 64}, {60, 16, 92}, {8, 8, 8}, {128, 128, 128}, {128, 32, 64}, {96, 24, 48}, {48, 48, 120}, {72, 80, 84}, {33, 67, 99}, {1, 1, 1}, {31, 31, 31}, {2, 3, 4}, {63, 64, 65}, {64, 64, 65}, {129, 129, 129}}; - - std::vector bias_types = { - BiasType::noBias, BiasType::MBias, BiasType::ScalarBias, - BiasType::MNBias, BiasType::NBias}; + // Convert expected output to MLFloat16 + std::vector f_Y(M * N); + ConvertFloatToMLFloat16(Y_f32.data(), f_Y.data(), M * N); - // Run tests with different combinations - for (const auto& size : test_sizes) { - for (const auto& bias_type : bias_types) { - run_test(std::get<0>(size), std::get<1>(size), std::get<2>(size), bias_type); - } - } + OpTester test("Gemm", 13); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)1); // transB = True + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); + test.AddInput("A", {M, K}, f_A); + test.AddInput("B", {N, K}, f_B); // B shape is {N, K} because transB=True + test.AddInput("C", {M, N}, f_C); + test.AddOutput("Y", {M, N}, f_Y); + test.SetOutputTolerance(0.01f); + test.ConfigExcludeEps({kTensorrtExecutionProvider}) // TensorRT: fp16 is not supported + .Config(run_with_tunable_op) + .RunWithConfig(); } } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc 
b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc index 773e6146f102f..b349e0c40882f 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc @@ -107,7 +107,12 @@ ProviderOptions GetProviderOptions() { } // namespace +#if defined(_WIN32) +// Graph fails to compose on ARM64 Windows since QNN 2.37.0 +TEST_F(QnnHTPBackendTests, DISABLED_LPBQGemmFusion) { +#else TEST_F(QnnHTPBackendTests, LPBQGemmFusion) { +#endif ProviderOptions provider_options = GetProviderOptions(); RunQnnModelTest(BuildLPBQGemmTestCase(), provider_options, diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc index 92c0895cd691e..8f63ccd5f2cd1 100644 --- a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc @@ -106,7 +106,12 @@ ProviderOptions GetProviderOptions() { } // namespace +#if defined(_WIN32) +// Graph fails to compose on ARM64 Windows since QNN 2.37.0 +TEST_F(QnnHTPBackendTests, DISABLED_LPBQMatMulFusion) { +#else TEST_F(QnnHTPBackendTests, LPBQMatMulFusion) { +#endif ProviderOptions provider_options = GetProviderOptions(); RunQnnModelTest(BuildLPBQMatMulTestCase(), provider_options, diff --git a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc index 706bd3c0fce62..95c5a0ab97728 100644 --- a/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc +++ b/onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc @@ -685,7 +685,7 @@ TEST(TensorrtExecutionProviderTest, TRTPluginsCustomOpTest) { auto cuda_provider = DefaultCudaExecutionProvider(); auto cpu_allocator = cuda_provider->CreatePreferredAllocators()[1]; std::vector<int64_t> dims_op_x = {12, 256, 256}; - std::vector<float> values_op_x(1.0f, 786432); // 786432=12*256*256 + std::vector<float> values_op_x(786432, 1.0f); // 786432=12*256*256 OrtValue ml_value_x; CreateMLValue<float>(cpu_allocator, dims_op_x, values_op_x, &ml_value_x); OrtValue ml_value_y;
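The expanded Gemm tests earlier in this patch all validate against the same reference formula, Y = alpha * op(A) * op(B) + beta * C, where op() applies the transA/transB attributes and C is broadcast according to the chosen bias shape. A minimal float-only sketch of that reference computation, with illustrative names and the bias simplified to a full MxN matrix, is:

#include <cstdint>
#include <vector>

// Reference Gemm: Y = alpha * op(A) * op(B) + beta * C.
// A is MxK (KxM when trans_a), B is KxN (NxK when trans_b); C is assumed MxN here.
std::vector<float> GemmReference(const std::vector<float>& a, const std::vector<float>& b,
                                 const std::vector<float>& c, int64_t m, int64_t k, int64_t n,
                                 float alpha, float beta, bool trans_a, bool trans_b) {
  std::vector<float> y(m * n, 0.0f);
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < n; ++j) {
      float sum = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        const float a_val = trans_a ? a[p * m + i] : a[i * k + p];  // op(A)[i][p]
        const float b_val = trans_b ? b[j * k + p] : b[p * n + j];  // op(B)[p][j]
        sum += a_val * b_val;
      }
      y[i * n + j] = alpha * sum + beta * c[i * n + j];
    }
  }
  return y;
}

The parameterized GemmOptimizePackedTest above computes essentially this, additionally handling the {M,1}, {N}, {M,N}, and scalar bias shapes through get_bias_value.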