diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 57ae8c354abb7..049dcac970396 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -87,7 +87,8 @@ Status CreateNodeArgs(const std::vector& names, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - QnnModelLookupTable& qnn_models) { + QnnModelLookupTable& qnn_models, + uint32_t total_context_size) { ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node."); NodeAttrHelper node_helper(main_context_node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); @@ -96,7 +97,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), main_context_node.Name(), - qnn_models); + qnn_models, + total_context_size); } std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); @@ -145,17 +147,19 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), main_context_node.Name(), - qnn_models); + qnn_models, + total_context_size); } Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, QnnModelLookupTable& qnn_models, - const logging::Logger& logger) { + const logging::Logger& logger, + uint32_t total_context_size) { ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!"); Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, - qnn_models); + qnn_models, total_context_size); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index f308a7456d46c..a7b25592c0802 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -49,13 +49,15 @@ bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - QnnModelLookupTable& qnn_models); + QnnModelLookupTable& qnn_models, + uint32_t total_context_size); Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, QnnModelLookupTable& qnn_models, - const logging::Logger& logger); + const logging::Logger& logger, + uint32_t total_context_size); Status CreateEPContextNodes(Model* model, unsigned char* buffer, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index bfc2102bdaac2..bc505aeb05ff5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -8,6 +8,7 @@ #include #include "QnnOpDef.h" #include "HTP/QnnHtpPerfInfrastructure.h" +#include "HTP/QnnHtpSystemContext.h" #include "CPU/QnnCpuCommon.h" // TODO: not exist for Windows yet // #include "GPU/QnnGpuCommon.h" @@ -531,11 +532,11 @@ Status QnnBackendManager::CreateContext() { } QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT; - QnnHtpContext_CustomConfig_t customConfig; - customConfig.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; - customConfig.weightSharingEnabled = enable_htp_weight_sharing_; + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED; + custom_config.weightSharingEnabled = enable_htp_weight_sharing_; context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; - context_config_weight_sharing.customConfig = &customConfig; + context_config_weight_sharing.customConfig = &custom_config; QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config)); @@ -616,7 +617,8 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - QnnModelLookupTable& qnn_models) { + QnnModelLookupTable& qnn_models, + uint32_t total_context_size) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; @@ -657,13 +659,48 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count; - ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, - "Invalid function pointer for contextCreateFromBinary."); + // HTP spill fill buffer only works for multiple QNN contexts generated after QNN v2.28 + if (total_context_size > 1 && max_spill_fill_buffer_ == 0) { + for (uint32_t i = 0; i < graph_count; ++i) { + if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + auto htp_graph_info = reinterpret_cast(graphs_info[i].graphInfoV3.graphBlobInfo); + if (htp_graph_info->version == QNN_SYSTEM_CONTEXT_HTP_GRAPH_INFO_BLOB_VERSION_V1) { + auto spill_fill_buffer_size = htp_graph_info->contextBinaryGraphBlobInfoV1.spillFillBufferSize; + max_spill_fill_buffer_ = spill_fill_buffer_size > max_spill_fill_buffer_ ? spill_fill_buffer_size : max_spill_fill_buffer_; + } else { + LOGS(*logger_, VERBOSE) << "Unknown context binary graph info blob version."; + } + } else if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_2 || + graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + LOGS(*logger_, VERBOSE) << "Skip retrieve spill file buffer size, it is not supported with graph info v1 & v2."; + } else { + LOGS(*logger_, VERBOSE) << "Unknown context binary graph info version."; + } + } + } QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); - const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr}; + // Register spill fill buffer for multi context + QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT; + QnnHtpContext_CustomConfig_t custom_config; + custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; + QnnHtpContext_GroupRegistration_t group_info; + size_t current_contexts_size = GetQnnContextSize(); + // set to 0x0 (new group) if this is the first context, otherwise point to the first context handle + group_info.firstGroupHandle = current_contexts_size > 0 ? GetQnnContext(0) : 0x0; + group_info.maxSpillFillBuffer = max_spill_fill_buffer_; // Max spill-fill buffer across contexts. Must be >0 + custom_config.groupRegistration = group_info; + spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + spill_fill_config.customConfig = &custom_config; + QnnContext_Config_t* spill_fill_config_pointer = + (total_context_size > 1 && max_spill_fill_buffer_ > 0) ? &spill_fill_config : nullptr; + + const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr}; + + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, + "Invalid function pointer for contextCreateFromBinary."); Qnn_ContextHandle_t context = nullptr; rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, @@ -672,7 +709,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, &context, profile_backend_handle_); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); contexts_.push_back(context); if (1 == graph_count) { // in case the EPContext node is generated from script @@ -932,20 +969,6 @@ Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_ return Status::OK(); } -void QnnBackendManager::Split(std::vector& split_string, - const std::string& tokenized_string, - const char separator) { - split_string.clear(); - std::istringstream tokenized_string_stream(tokenized_string); - while (!tokenized_string_stream.eof()) { - std::string value; - getline(tokenized_string_stream, value, separator); - if (!value.empty()) { - split_string.push_back(value); - } - } -} - Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 43007d4a5c244..f2e14e4c86970 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -93,7 +93,8 @@ class QnnBackendManager { Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, std::string node_name, - std::unordered_map>& qnn_models); + std::unordered_map>& qnn_models, + uint32_t total_context_size); Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); @@ -112,6 +113,10 @@ class QnnBackendManager { return contexts_[index]; } + size_t GetQnnContextSize() { + return contexts_.size(); + } + const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; } const Qnn_ProfileHandle_t& GetQnnProfileHandle() { return profile_backend_handle_; } @@ -145,8 +150,6 @@ class QnnBackendManager { void ReleaseResources(); - void Split(std::vector& split_string, const std::string& tokenized_string, const char separator); - Status ExtractBackendProfilingInfo(); Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled); @@ -268,6 +271,7 @@ class QnnBackendManager { QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; bool enable_htp_weight_sharing_ = false; + uint64_t max_spill_fill_buffer_ = 0; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 6735528bebbf9..b17741594e164 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -363,18 +363,19 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_fp16_precision: " << enable_HTP_FP16_precision_; } + bool enable_htp_weight_sharing = false; static const std::string QNN_HTP_WEIGHT_SHARING_ENABLED = "enable_htp_weight_sharing"; auto htp_weight_sharing_enabled_pos = provider_options_map.find(QNN_HTP_WEIGHT_SHARING_ENABLED); if (htp_weight_sharing_enabled_pos != provider_options_map.end()) { if ("1" == htp_weight_sharing_enabled_pos->second) { - enable_htp_weight_sharing_ = true; + enable_htp_weight_sharing = true; } else if ("0" == htp_weight_sharing_enabled_pos->second) { - enable_htp_weight_sharing_ = false; + enable_htp_weight_sharing = false; } else { - LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing_ + LOGS_DEFAULT(VERBOSE) << "Invalid enable_htp_weight_sharing: " << enable_htp_weight_sharing << " only 0 or 1 allowed. Set to 0."; } - LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing_; + LOGS_DEFAULT(VERBOSE) << "User specified enable_htp_weight_sharing: " << enable_htp_weight_sharing; } model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", false, @@ -396,7 +397,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio device_id_, htp_arch, soc_model, - enable_htp_weight_sharing_); + enable_htp_weight_sharing); #ifdef _WIN32 auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); @@ -934,6 +935,7 @@ Status QNNExecutionProvider::Compile(const std::vector& fused std::vector main_context_pos_list; ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, main_context_pos_list)); + uint32_t total_context_size = SafeInt(main_context_pos_list.size()); for (auto main_context_pos : main_context_pos_list) { const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); @@ -942,7 +944,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused context_cache_path, qnn_backend_manager_.get(), qnn_models, - logger)); + logger, + total_context_size)); } for (auto fused_node_and_graph : fused_nodes_and_graphs) { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 35c061de6132c..40d03135c4bbe 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -141,7 +141,6 @@ class QNNExecutionProvider : public IExecutionProvider { std::string context_node_name_prefix_ = ""; bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; - bool enable_htp_weight_sharing_ = false; int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_;