diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d8de7756bae22..ddf37cfded77d 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -258,7 +258,8 @@ Do not modify directly.* |||12|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 11]|**T** = tensor(double), tensor(float)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| -|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +|QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**TS** = tensor(float)| +|||[10, 20]|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|21+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 424bee63511ad..a8284e4d88693 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -379,8 +379,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn QuantizeLinear); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, uint8_t, + QLinearMatMul); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 20, int8_t, + QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, MatMulInteger); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ConvInteger); @@ -1108,6 +1110,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Int4x2, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, UInt4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, uint8_t, QLinearMatMul); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, int8_t, QLinearMatMul); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FN, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 21, Float8E4M3FNUZ, DequantizeLinear); @@ -1691,10 +1695,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { uint8_t, QuantizeLinear)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc index cb162ade44559..be448455194f6 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear_matmul.cc @@ -14,10 +14,11 @@ namespace onnxruntime { // uint8_t kernel supports weight being either uint8_t or int8_t -ONNX_OPERATOR_TYPED_KERNEL_EX( +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( QLinearMatMul, kOnnxDomain, 10, + 20, uint8_t, kCpuExecutionProvider, KernelDefBuilder() @@ -26,21 +27,45 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), QLinearMatMul); +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 21, + uint8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TS", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); + // int8_t kernel only supports weight being int8_t -#define REGISTER_QLINEARMATMUL_INT8_KERNEL() \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - QLinearMatMul, \ - kOnnxDomain, \ - 10, \ - int8_t, \ - kCpuExecutionProvider, \ - KernelDefBuilder() \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T3", DataTypeImpl::GetTensorType()), \ - QLinearMatMul); - -REGISTER_QLINEARMATMUL_INT8_KERNEL(); +ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 10, + 20, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QLinearMatMul, + kOnnxDomain, + 21, + int8_t, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("TS", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + QLinearMatMul); Status QLinearMatMul::Compute(OpKernelContext* ctx) const { const auto* a = ctx->Input(IN_A); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index eaffe1e2ac224..34dcbd1d77fca 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -302,13 +302,21 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } Status QnnBackendManager::ResetQnnLogLevel() { - auto ort_log_level = logger_->GetSeverity(); - LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; - return UpdateQnnLogLevel(ort_log_level); + std::lock_guard lock(logger_mutex_); + + if (backend_setup_completed_ && logger_ != nullptr) { + auto ort_log_level = logger_->GetSeverity(); + LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; + return UpdateQnnLogLevel(ort_log_level); + } + return Status::OK(); } Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. Invalid QNN log handle."); + ORT_RETURN_IF(false == backend_setup_completed_, "Unable to update QNN Log Level. Backend setup not completed."); + ORT_RETURN_IF(nullptr == logger_, "Unable to update QNN Log Level. Invalid logger."); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; @@ -686,6 +694,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_from_cached_context) { + std::lock_guard lock(logger_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; return Status::OK(); @@ -972,6 +981,7 @@ void QnnBackendManager::ReleaseResources() { ORT_THROW("Failed to ShutdownBackend."); } + std::lock_guard lock(logger_mutex_); result = TerminateQnnLog(); if (Status::OK() != result) { ORT_THROW("Failed to TerminateQnnLog."); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index b80f1374fcdc7..43007d4a5c244 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -12,9 +12,11 @@ #endif #include +#include #include #include #include + #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" #include "QnnTypes.h" @@ -233,6 +235,7 @@ class QnnBackendManager { private: const std::string backend_path_; + std::mutex logger_mutex_; const logging::Logger* logger_ = nullptr; QNN_INTERFACE_VER_TYPE qnn_interface_ = QNN_INTERFACE_VER_TYPE_INIT; QNN_SYSTEM_INTERFACE_VER_TYPE qnn_sys_interface_ = QNN_SYSTEM_INTERFACE_VER_TYPE_INIT; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 4cd5d403e95b8..ed193904fe7a8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -258,49 +258,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } -#ifdef _WIN32 - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - // Register callback for ETW capture state (rundown) - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this]( - LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - ORT_UNUSED_PARAMETER(SourceId); - ORT_UNUSED_PARAMETER(MatchAnyKeyword); - ORT_UNUSED_PARAMETER(MatchAllKeyword); - ORT_UNUSED_PARAMETER(FilterData); - ORT_UNUSED_PARAMETER(CallbackContext); - - if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); - } - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { - if (Level != 0) { - // Commenting out Dynamic QNN Profiling for now - // There seems to be a crash in 3rd party QC QnnHtp.dll with this. - // Repro Scenario - start ETW tracing prior to session creation. - // Then disable/enable ETW Tracing with the code below uncommented a few times - // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); - // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); - } - } - } - - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); - (void)qnn_backend_manager_->ResetQnnLogLevel(); - } - }); - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); -#endif - // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { @@ -440,6 +397,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio htp_arch, soc_model, enable_htp_weight_sharing_); + +#ifdef _WIN32 + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. + // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + } + } + } + + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(); + } + }); + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); +#endif } QNNExecutionProvider::~QNNExecutionProvider() { @@ -453,7 +453,9 @@ QNNExecutionProvider::~QNNExecutionProvider() { // Unregister the ETW callback #ifdef _WIN32 - logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + if (callback_ETWSink_provider_ != nullptr) { + logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); + } #endif } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 246ab1d5a6608..9422e54bd0035 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -151,7 +151,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; #ifdef _WIN32 - onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_; + onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 45aaca1ceae56..6b9b20faf8697 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1026,7 +1026,13 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"dequantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"dequantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, {"quantizelinear_int4", "Bug with model input name 'zero_point' not matching node's input name", {}}, - {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}}); + {"quantizelinear_uint4", "Bug with model input name 'zero_point' not matching node's input name", {}}, + {"qlinearmatmul_2D_int8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_2D_int8_float32", "result diff", {}}, + {"qlinearmatmul_2D_uint8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_3D_int8_float16", "fp16 type ont supported by CPU EP", {}}, + {"qlinearmatmul_3D_int8_float32", "result diff", {}}, + {"qlinearmatmul_3D_uint8_float16", "fp16 type ont supported by CPU EP", {}}}); // Some EPs may fail to pass some specific testcases. // For example TenosrRT EP may fail on FLOAT16 related testcases if GPU doesn't support float16. diff --git a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc index 8cdb837712e83..096263792727a 100644 --- a/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/quantize_linear_matmul_test.cc @@ -126,8 +126,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul3D_S8S8) { } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) { - auto run_test = [](bool only_t1_not_initializer) { - OpTester test("QLinearMatMul", 10); + auto run_test = [](bool only_t1_not_initializer, int opset_version) { + OpTester test("QLinearMatMul", opset_version); test.AddInput("T1", {2, 4}, {208, 236, 0, 238, 3, 214, 255, 29}); @@ -155,10 +155,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8U8) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); }; - run_test(false); + run_test(false, 10); + run_test(false, 21); // NNAPI will require all inputs except T1 to be initializers - run_test(true); + run_test(true, 10); + run_test(true, 21); } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) { @@ -197,8 +199,8 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_U8S8) { } TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) { - auto run_test = [](bool only_t1_not_initializer) { - OpTester test("QLinearMatMul", 10); + auto run_test = [](bool only_t1_not_initializer, int opset_version) { + OpTester test("QLinearMatMul", opset_version); test.AddInput("T1", {2, 4}, {80, -2, -128, 110, -125, 86, 127, -99}); @@ -225,10 +227,12 @@ TEST(QuantizeLinearMatmulOpTest, QLinearMatMul2D_S8S8) { test.Run(); }; - run_test(false); + run_test(false, 10); + run_test(false, 21); // NNAPI will require all inputs except T1 to be initializers - run_test(true); + run_test(true, 10); + run_test(true, 21); } static void QLinearMatMul2DTest(bool only_t1_not_initializer) {