diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index 3e70f848675cb..d5b8961cf8c5a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -96,79 +96,6 @@ void ComputeJob( } } -void ComputeJob( - const MLFloat16* input_data, - const MLFloat16* skip_data, - const float* prepacked_skip_fp32_data, - const float* gamma_float_ptr, - const float* beta_float_ptr, - const float* bias_float_ptr, - float* output_float_ptr, - ptrdiff_t task_idx, - int hidden_size, - int64_t skip_size, - float epsilon, - bool simplified, - MLFloat16* output_data, - MLFloat16* skip_input_bias_add_output_data, - AllocatorPtr alloc) { - auto offset = task_idx * hidden_size; - const MLFloat16* p_input = input_data + offset; - MLFloat16* p_output = output_data + offset; - MLFloat16* p_skip_input_bias_add_output = skip_input_bias_add_output_data == nullptr ? nullptr : skip_input_bias_add_output_data + offset; - - float mean(0.0f); - float mean_square(0.0f); - const size_t num_elems = static_cast(hidden_size); - - IAllocatorUniquePtr input_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems); - - IAllocatorUniquePtr skip_float_uptr = nullptr; - if (prepacked_skip_fp32_data == nullptr && skip_data) { - const MLFloat16* p_skip = skip_data + (offset % skip_size); - skip_float_uptr = IAllocator::MakeUniquePtr(alloc, num_elems); - MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems); - } - - const float* input_float_ptr = input_float_uptr.get(); - const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data : skip_float_uptr.get(); - for (size_t h = 0; h < num_elems; h++) { - float val = input_float_ptr[h] + skip_float_ptr[h]; - - if (bias_float_ptr) { - val += bias_float_ptr[h]; - } - - output_float_ptr[h] = val; - mean += val; - mean_square += val * val; - } - - if (nullptr != p_skip_input_bias_add_output) { - MlasConvertFloatToHalfBuffer(output_float_ptr, p_skip_input_bias_add_output, num_elems); - } - - mean = mean / hidden_size; - if (simplified) { - mean_square = sqrt(mean_square / hidden_size + epsilon); - } else { - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon); - } - - for (size_t h = 0; h < num_elems; h++) { - if (simplified) { - output_float_ptr[h] = output_float_ptr[h] / mean_square * gamma_float_ptr[h]; - } else if (nullptr == beta_float_ptr) { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h]; - } else { - output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * gamma_float_ptr[h] + beta_float_ptr[h]; - } - } - - MlasConvertFloatToHalfBuffer(output_float_ptr, p_output, num_elems); -} - void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, IAllocatorUniquePtr& dest, bool& is_packed) { if (tensor.GetElementType() == utils::ToTensorProtoElementType()) { auto tensor_data_ptr = tensor.Data(); @@ -200,8 +127,8 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input(1); const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input(2); - const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3); - const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(4); + const Tensor* beta = simplified ? nullptr : (prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input(3)); + const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input(simplified ? 3 : 4); Tensor* output = p_ctx->Output(0, input->Shape()); // For inferencing, we support one more optional output which is the sum of the input and skip tensors Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape()); @@ -232,56 +159,93 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { // For inferencing, we support one more optional output which is the sum of the input and skip tensors T* skip_input_bias_add_output_data = skip_input_bias_add_output == nullptr ? nullptr : skip_input_bias_add_output->MutableData(); - const int64_t skip_size = skip ? skip->Shape().Size() : prepacked_skip_fp32_size_; - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); - - IAllocatorUniquePtr output_fp32; - IAllocatorUniquePtr gamma_fp32; - IAllocatorUniquePtr beta_fp32; - IAllocatorUniquePtr bias_fp32; - if constexpr (std::is_same_v) { + const size_t total_data_size = static_cast(input->Shape().Size()); + + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc)); + + IAllocatorUniquePtr input_fp32; + IAllocatorUniquePtr output_fp32; + IAllocatorUniquePtr skip_input_bias_add_output_fp32; + IAllocatorUniquePtr skip_fp32; + IAllocatorUniquePtr gamma_fp32; + IAllocatorUniquePtr beta_fp32; + IAllocatorUniquePtr bias_fp32; + + const float* input_data_f = nullptr; + const float* skip_data_f = nullptr; + const float* gamma_data_f = nullptr; + const float* beta_data_f = nullptr; + const float* bias_data_f = nullptr; + float* output_data_f = nullptr; + float* skip_input_bias_add_output_data_f = nullptr; + const size_t num_elems = static_cast(hidden_size); - output_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); + input_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + MlasConvertHalfToFloatBuffer(input_data, input_fp32.get(), total_data_size); + input_data_f = input_fp32.get(); + + output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + output_data_f = output_fp32.get(); + + skip_input_bias_add_output_fp32 = IAllocator::MakeUniquePtr(alloc, total_data_size); + skip_input_bias_add_output_data_f = skip_input_bias_add_output_fp32.get(); - if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) { + if (skip_data) { + skip_fp32 = IAllocator::MakeUniquePtr(alloc, static_cast(skip_size)); + MlasConvertHalfToFloatBuffer(skip_data, skip_fp32.get(), static_cast(skip_size)); + skip_data_f = skip_fp32.get(); + } else if (prepacked_skip_fp32_data_) { + skip_data_f = prepacked_skip_fp32_data_.get(); + } + + if (gamma_data) { gamma_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems); + gamma_data_f = gamma_fp32.get(); + } else if (prepacked_gamma_fp32_data_) { + gamma_data_f = prepacked_gamma_fp32_data_.get(); } - if (prepacked_beta_fp32_data_ == nullptr && beta_data) { + if (beta_data) { beta_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(beta_data, beta_fp32.get(), num_elems); + beta_data_f = beta_fp32.get(); + } else if (prepacked_beta_fp32_data_) { + beta_data_f = prepacked_beta_fp32_data_.get(); } - if (prepacked_bias_fp32_data_ == nullptr && bias_data) { + if (bias_data) { bias_fp32 = IAllocator::MakeUniquePtr(alloc, num_elems); MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems); + bias_data_f = bias_fp32.get(); + } else if (prepacked_bias_fp32_data_) { + bias_data_f = prepacked_bias_fp32_data_.get(); } - } - concurrency::ThreadPool::TryBatchParallelFor( - p_ctx->GetOperatorThreadPool(), static_cast(task_count), - [&](ptrdiff_t task_idx) { - if constexpr (std::is_same_v) { - ComputeJob(input_data, skip_data, - prepacked_skip_fp32_data_.get(), - prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(), - prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(), - prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(), - output_fp32.get(), - task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, - skip_input_bias_add_output_data, alloc); - } else { + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { + ComputeJob(input_data_f, skip_data_f, gamma_data_f, beta_data_f, bias_data_f, task_idx, hidden_size, skip_size, + epsilon_, simplified, output_data_f, skip_input_bias_add_output_data_f); + }, + 0); + MlasConvertFloatToHalfBuffer(output_data_f, output_data, total_data_size); + if (skip_input_bias_add_output_data != nullptr) + MlasConvertFloatToHalfBuffer(skip_input_bias_add_output_data_f, skip_input_bias_add_output_data, total_data_size); + } else { + concurrency::ThreadPool::TryBatchParallelFor( + p_ctx->GetOperatorThreadPool(), static_cast(task_count), + [&](ptrdiff_t task_idx) { ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data, skip_input_bias_add_output_data); - } - }, - 0); + }, + 0); + } return Status::OK(); } @@ -290,16 +254,22 @@ template Status SkipLayerNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights) { ORT_UNUSED_PARAMETER(prepacked_weights); - is_packed = false; if (input_idx == 1) { // skip prepacked_skip_fp32_size_ = tensor.Shape().Size(); ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_skip_fp32_data_, is_packed); } else if (input_idx == 2) { // gamma ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_gamma_fp32_data_, is_packed); - } else if (input_idx == 3) { // beta - ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); + } else if (input_idx == 3) { + if constexpr (simplified) { + // bias + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); + } else { + // beta + ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_beta_fp32_data_, is_packed); + } } else if (input_idx == 4) { // bias + ORT_ENFORCE(!simplified, "SkipSimplifiedLayerNormalization should only has 4 inputs (input, skip, gamma, and beta). Got 5."); ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc index d089235ceaa02..d1a0e88686f39 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/layer_norm_op_builder.cc @@ -87,10 +87,10 @@ Status LayerNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[BIAS_IDX], logger, input_names)); } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 17) +#if QNN_API_VERSION_MAJOR == 2 && QNN_API_VERSION_MINOR >= 17 && QNN_API_VERSION_MINOR <= 20 if (!has_bias_input && IsNpuBackend(qnn_model_wrapper.GetQnnBackendType())) { - // Bias is implicit. QNN SDK 2.24+ (QNN API version 2.17+) has a validation bug for implicit bias inputs, - // so provide an explicit bias of all 0 (quantized int32). + // Bias is implicit. QNN SDK 2.24 to 2.27 (QNN API version 2.17 to 2.20) has a validation bug for + // implicit bias inputs, so provide an explicit bias of all 0 (quantized int32). TensorInfo x_input_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[X_IDX], x_input_info)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index bfc2102bdaac2..f37c91aa0413b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -14,6 +14,7 @@ #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" #include "HTP/QnnHtpContext.h" +#include "Saver/QnnSaver.h" #include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" @@ -1040,7 +1041,14 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { const QnnProfile_EventId_t* profile_events{nullptr}; uint32_t num_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEvents(profile_backend_handle_, &profile_events, &num_events); - ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + if (!qnn_saver_path_.empty()) { // Using QNN Saver backend + // QNN SDK 2.28.2 returns QNN_SAVER_ERROR_DUMMY_RETVALUE, but previous QNN versions return QNN_PROFILE_NO_ERROR. + // We accept both values. + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result && QNN_SAVER_ERROR_DUMMY_RETVALUE != result, + "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + } else { + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile events. Error: ", QnnErrorHandleToString(result)); + } if (num_events > 0) { LOGS(*logger_, VERBOSE) << "profile_events: " << profile_events << " num_events: " << num_events; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 97d88786e4bcd..66e8e175cb37d 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1725,6 +1725,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_))); } + trt_version_ = getInferLibVersion(); + + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT version is " << trt_version_; + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: " << "device_id: " << device_id_ << ", trt_max_partition_iterations: " << max_partition_iterations_ @@ -2462,10 +2466,30 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, std::vector nodes_vector(number_of_ort_nodes); std::iota(std::begin(nodes_vector), std::end(nodes_vector), 0); - std::vector filtered_nodes_vector; + std::set exclude_ops_set; + + /* + * There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. + * TRT EP automatically excludes DDS ops from running on TRT. + */ + if (trt_version_ >= 100000 && trt_version_ < 110000) { + exclude_ops_set.insert("NonMaxSuppression"); + exclude_ops_set.insert("NonZero"); + exclude_ops_set.insert("RoiAlign"); + LOGS_DEFAULT(VERBOSE) << "There is a known performance issue with the DDS ops (NonMaxSuppression, NonZero and RoiAlign) in TRT 10. TRT EP automatically excludes DDS ops from running on TRT, if applicable"; + } + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; const std::vector& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); + bool new_subgraph = true; + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. It's a DDS op. + */ for (const auto& index : nodes_vector) { const auto& node = graph.GetNode(node_index[index]); + bool supported_node = true; /* If current node is control flow op, we take different approach based on following four cases: * @@ -2477,29 +2501,43 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. */ if (control_flow_op_set_.find(node->OpType()) != control_flow_op_set_.end()) { - auto sub_graphs = node->GetSubgraphs(); - if (sub_graphs.size() != 0) { - bool all_subgraphs_are_supported = true; - for (auto sub_graph : sub_graphs) { - // TRT EP should consider the empty subgraph is fully supported by TRT. - if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { - continue; - } - if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { - all_subgraphs_are_supported = false; - break; + auto supported_control_flow_op = [&](const Node* node) { + auto sub_graphs = node->GetSubgraphs(); + if (sub_graphs.size() != 0) { + for (auto sub_graph : sub_graphs) { + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (sub_graph->CreateGraphViewer()->NumberOfNodes() == 0) { + continue; + } + if (!AllNodesAssignedToSpecificEP(*(sub_graph->CreateGraphViewer()), kTensorrtExecutionProvider)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } } } - if (!all_subgraphs_are_supported) { - // if not all its subgraphs are supported, we need to exclude this control flow op - continue; - } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_ops_set.find(node->OpType()) != exclude_ops_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; } - filtered_nodes_vector.push_back(index); } - SubGraphCollection_t supported_nodes_vector, parser_nodes_vector = {{filtered_nodes_vector, false}}; bool early_termination = false; supported_nodes_vector = GetSupportedList(parser_nodes_vector, 0, max_partition_iterations_, graph, &early_termination); if (early_termination) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 97c9367b0bb61..0e9c11f7de968 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -329,6 +329,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool cuda_graph_enable_ = false; std::string cache_prefix_; bool engine_hw_compatible_ = false; + std::string op_types_to_exclude_; + + // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH + int32_t trt_version_; // The OrtAllocator object will be get during ep compute time // and should be kept for the lifetime of TRT EP object. diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 019d619f9be49..55177cc7ed131 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -132,6 +132,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { } // disabled for QNN 2.28.0.241029 failed for accuracy validation +// Also fails on QNN 2.28.2. // qdq@QNN_EP val: 3.6094117164611816 (err: 1.3094117641448975, err/output_range: 22.19342041015625%) // qdq@CPU_EP val: 2.2905881404876709 (err: 0.0094118118286132812, err/output_range: 0.15952222049236298%) // abs(qdq@QNN_EP - qdq@CPU_EP) / output_range = 22.033897399902344% diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 2773568dde717..947ac19be40a8 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -188,15 +188,11 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_StaticBias_AU8_WU8_B ExpectedEPNodeAssignment::All); } -// QNN 2.27 accuracy issue -// Inaccuracy detected for output 'output_0', element 0 -// output_range=1.2245157957077026, tolerance=0.40000000596046448%. -// Expected val (f32@CPU_EP): -0 -// qdq@QNN_EP val: 0.19133351743221283 (err: 0.19133351743221283, err/output_range: 15.625238418579102%) -// qdq@CPU_EP val: 0 (err: 0, err/output_range: 0%) -TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { - // QNN 2.24 LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide an - // explicit bias of all zeros to get around this bug. +TEST_F(QnnHTPBackendTests, LayerNorm1D_QNN2_24_ImplicitBias_ValidationBug) { + // QNN 2.24 to 2.27: LayerNorm fails validation (intermittent) if the bias input is not provided. QNN EP will provide + // an explicit bias of all zeros to get around this bug. + // QNN 2.28.0: Validation bug is fixed, but get accuracy errors. + // QNN 2.28.2: All fixed. for (size_t i = 0; i < 15; i++) { // Run it multiple times since this is an intermittent bug. RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 1.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), @@ -207,14 +203,9 @@ TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_QNN2_24_ImplicitBias_ValidationB } } -// Test accuracy of 16-bit QDQ LayerNorm with a static scale input. -// QNN 2.27 accuracy issue -// Inaccuracy detected for output 'output_0', element 0 -// output_range=1.224743127822876, tolerance=0.40000000596046448%. -// Expected val (f32@CPU_EP): -0 -// qdq@QNN_EP val: 0.19136904180049896 (err: 0.19136904180049896, err/output_range: 15.625238418579102%) -// qdq@CPU_EP val: 0 (err: 0, err/output_range: 0%) -TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { +TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { + // QNN 2.28.0: Get accuracy errors. + // QNN 2.28.2: All fixed. RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, true, GetFloatDataInRange(0.0f, 1.0f, 3)), // Static TestInputDef(), @@ -225,7 +216,7 @@ TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { // Test accuracy of 8-bit QDQ LayerNorm with a dynamic scale input. // -// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. +// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. Still fails on QNN SDK 2.28.2. // Verbose logs: // Starting stage: Graph Transformations and Optimizations // C:\...\QNN\HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::flat_to_vtcm diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index 800457d906940..5c6967761b1db 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -273,7 +273,7 @@ TEST_F(QnnHTPBackendTests, MatMulOp_PerChannel_A16_WeightUInt4) { } // Test QDQ per-channel MatMul with int8 act, int4 weights (static) -// QNN 2.27 regression +// QNN 2.27 regression. Also fails on QNN 2.28.2. // Failed to finalize QNN graph. Error code: 1002 TEST_F(QnnHTPBackendTests, DISABLED_MatMulOp_PerChannel_AS8_WeightInt4) { std::vector input0_data = GetFloatDataInRange(-5.0f, 5.0f, 6); diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 05731976c453f..7541d94bac0c6 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -230,6 +230,7 @@ TEST_F(QnnHTPBackendTests, UnaryOp_Tanh) { } // disabled for QNN 2.28.0.241029 backendValidateOpConfig failed +// still fails on QNN 2.28.2. // QnnDsp [4294967295] has incorrect Value -32768, expected equal to 0. // QnnDsp validateNativeOps node_token_6:qti.aisw:Tanh htp op validator failed 3110 // QnnDsp registered validator failed => 3110 diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 20252220da8f9..c3dbee336b69d 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index c1e469509b9bd..00d622fc23c40 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -62,7 +62,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 03859b1548fd2..d3826d90f9073 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 0a18343eee33d..74d13dec582a8 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -69,7 +69,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.28.0.241029 + default: 2.28.2.241116 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index f2c0561368a9e..d54b8018c232a 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 97ca94e7ab516..3c64ee04eb913 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.28.0.241029' + default: '2.28.2.241116' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index 6b318664d1b12..9df8b249f681e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.28.0.241029' + default: '2.28.2.241116' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml index d2ce7c84aa40d..b1cec2284df65 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-qnn.yml @@ -26,7 +26,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 jobs: - job: Linux_py_qnn_Wheels_x64 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 2a59e9de9908f..1aba5437c0618 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -73,7 +73,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' - default: 2.28.0.241029 + default: 2.28.2.241116 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 6adc35568b034..44bb554c20e79 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 0a58874d1d478..a67828dd18bb3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index 1114477c84454..bf9b791de0c22 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 24abf7f6d0872..51122bc1653a6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.28.0.241029' + QnnSdk: '2.28.2.241116' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 59a8dac9b1988..5c013fae6be0b 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 6645c9b1f78f3..f442f706fae0b 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -33,7 +33,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.28.0.241029 + default: 2.28.2.241116 jobs: - job: 'build'