Skip to content

Commit f80e6f4

Browse files
Enable VTCM Back Up Buffer Sharing (microsoft#24962)
### Description Enabling the VTCM backup buffering feature on QNN EP, assuming all graphs are running sequentially where the input of the next graph is the output of the current graph. Under these assumptions, rather than allocate buffers for all inputs and outputs, only a single buffer can be shared between all graphs. ### Motivation and Context This will allow larger LLM models to be run
1 parent 6cffd1a commit f80e6f4

File tree

5 files changed

+424
-57
lines changed

5 files changed

+424
-57
lines changed

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

Lines changed: 211 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "core/providers/qnn/ort_api.h"
2525
#include "core/providers/qnn/qnn_allocator.h"
2626
#include "core/providers/qnn/qnn_telemetry.h"
27+
#include "core/providers/qnn/shared_context.h"
2728
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
2829
#include "core/providers/qnn/builder/qnn_configs_helper.h"
2930
#include "core/providers/qnn/builder/qnn_utils.h"
@@ -709,6 +710,135 @@ Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t
709710
return Status::OK();
710711
}
711712

713+
// callback required to add context handles to class list
714+
// when using contextCreateFromBinaryListAsync()
715+
void ContextCreateAsyncCallback(Qnn_ContextHandle_t context,
716+
Qnn_GraphHandle_t graph,
717+
const char* graphName,
718+
QnnContext_createFromBinaryAsyncNotifyType_t notifyType,
719+
void* notifyParam,
720+
Qnn_ErrorHandle_t status) {
721+
auto qnn_backend_manager = SharedContext::GetInstance().GetSharedQnnBackendManager();
722+
723+
if (context) {
724+
qnn_backend_manager->ProcessContextFromBinListAsync(context, notifyParam);
725+
}
726+
727+
if (nullptr == graphName || graph || notifyType || status) {
728+
// Avoid compilation unused var warning error
729+
}
730+
}
731+
732+
void QnnBackendManager::ProcessContextFromBinListAsync(Qnn_ContextHandle_t context, void* notifyParam) {
733+
std::lock_guard<std::mutex> guard(ep_context_handle_map_mutex_);
734+
if (!notifyParam) {
735+
LOGS(*logger_, WARNING) << "No known node names associated with context handle: " << context;
736+
return;
737+
}
738+
739+
std::vector<std::string>* ep_node_names = reinterpret_cast<std::vector<std::string>*>(notifyParam);
740+
for (const auto& node_name : *ep_node_names) {
741+
if (!(ep_context_handle_map_.emplace(node_name, context).second)) {
742+
LOGS(*logger_, VERBOSE) << "Unable to map " << context << " to " << node_name;
743+
}
744+
}
745+
746+
auto s = AddQnnContextHandle(context);
747+
if (s != Status::OK()) {
748+
LOGS(*logger_, WARNING) << "Unable to add context " << context;
749+
}
750+
}
751+
752+
Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map<std::string, std::unique_ptr<std::vector<std::string>>>& context_bin_map) {
753+
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26)
754+
QnnContext_Config_t context_config_resource_sharing = QNN_CONTEXT_CONFIG_INIT;
755+
QnnHtpContext_CustomConfig_t resource_sharing_custom_config;
756+
resource_sharing_custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES;
757+
resource_sharing_custom_config.shareResources = true;
758+
context_config_resource_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
759+
context_config_resource_sharing.customConfig = &resource_sharing_custom_config;
760+
761+
QnnHtpContext_CustomConfig_t context_config_resource_sharing_opt_type;
762+
context_config_resource_sharing_opt_type.option = QNN_HTP_CONTEXT_CONFIG_OPTION_SHARE_RESOURCES_OPTIMIZATION_TYPE;
763+
context_config_resource_sharing_opt_type.shareResOptType = SEQUENTIAL_WITHOUT_VA_OPTIMIZATION;
764+
QnnContext_Config_t resource_sharing_opt_type_config;
765+
resource_sharing_opt_type_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
766+
resource_sharing_opt_type_config.customConfig = &context_config_resource_sharing_opt_type;
767+
768+
QnnContext_Config_t context_config_weight_sharing = QNN_CONTEXT_CONFIG_INIT;
769+
QnnHtpContext_CustomConfig_t custom_config;
770+
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_WEIGHT_SHARING_ENABLED;
771+
custom_config.weightSharingEnabled = true;
772+
context_config_weight_sharing.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
773+
context_config_weight_sharing.customConfig = &custom_config;
774+
#else
775+
LOGS(*logger_, WARNING) << "Called CreateContextVtcmBackupBufferSharingEnabled() but QNN API version is older than 2.26!";
776+
#endif
777+
QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
778+
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config));
779+
780+
const QnnContext_Config_t* configs[] = {&context_priority_config,
781+
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26)
782+
&context_config_resource_sharing,
783+
&resource_sharing_opt_type_config,
784+
&context_config_weight_sharing,
785+
#endif
786+
nullptr};
787+
788+
std::vector<QnnContext_Params_t> context_params_list;
789+
std::vector<QnnContext_ParamsV1_t> context_paramsv1_list;
790+
std::vector<const QnnContext_Params_t*> context_params_ptr_list(context_bin_map.size() + 1);
791+
std::vector<std::unique_ptr<char[]>> buffer_list;
792+
793+
size_t idx = 0;
794+
for (auto& it : context_bin_map) {
795+
auto context_bin_filepath = it.first;
796+
797+
std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary);
798+
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to retrieve context binary from: ", context_bin_filepath);
799+
800+
cache_file.seekg(0, cache_file.end);
801+
size_t buffer_size = static_cast<size_t>(cache_file.tellg());
802+
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
803+
804+
cache_file.seekg(0, cache_file.beg);
805+
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(buffer_size);
806+
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
807+
const auto& read_result = cache_file.read(buffer.get(), buffer_size);
808+
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
809+
810+
cache_file.close();
811+
QnnContext_ParamsV1_t context_params_v1 = {nullptr,
812+
buffer.get(),
813+
buffer_size,
814+
nullptr,
815+
ContextCreateAsyncCallback,
816+
it.second.get()};
817+
818+
QnnContext_Params_t context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_1,
819+
context_params_v1};
820+
821+
buffer_list.push_back(std::move(buffer));
822+
context_params_list.push_back(std::move(context_params));
823+
context_paramsv1_list.push_back(std::move(context_params_v1));
824+
context_params_ptr_list[idx++] = &context_params_list.back();
825+
}
826+
context_params_ptr_list[idx] = nullptr;
827+
auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_,
828+
device_handle_,
829+
context_params_ptr_list.data(),
830+
configs,
831+
nullptr);
832+
833+
context_params_ptr_list.clear();
834+
context_paramsv1_list.clear();
835+
context_params_list.clear();
836+
buffer_list.clear();
837+
838+
ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result);
839+
return Status::OK();
840+
}
841+
712842
Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) {
713843
if (true == context_created_) {
714844
LOGS_DEFAULT(INFO) << "Context created already.";
@@ -728,6 +858,7 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) {
728858
const QnnContext_Config_t* npu_context_configs[] = {&context_priority_config,
729859
&context_config_weight_sharing,
730860
nullptr};
861+
731862
const QnnContext_Config_t* empty_context_configs[] = {nullptr};
732863

733864
const QnnContext_Config_t** configs = nullptr;
@@ -751,12 +882,14 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) {
751882
}
752883

753884
Qnn_ContextHandle_t context = nullptr;
754-
Qnn_ErrorHandle_t result = qnn_interface_.contextCreate(backend_handle_,
755-
device_handle_,
756-
configs,
757-
&context);
885+
Qnn_ErrorHandle_t result = 0;
758886

759-
ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result));
887+
result = qnn_interface_.contextCreate(backend_handle_,
888+
device_handle_,
889+
configs,
890+
&context);
891+
892+
ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result);
760893

761894
ORT_RETURN_IF_ERROR(AddQnnContextHandle(context));
762895

@@ -936,43 +1069,60 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
9361069
ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context.");
9371070
LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count;
9381071

939-
QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
940-
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
1072+
Qnn_ContextHandle_t context = nullptr;
1073+
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26)
1074+
if (vtcm_backup_buffer_sharing_enabled_) {
1075+
if (ep_context_handle_map_.find(node_name) != ep_context_handle_map_.end()) {
1076+
context = ep_context_handle_map_.at(node_name);
1077+
}
1078+
ORT_RETURN_IF(nullptr == context, "Failed to retrieve context for ", node_name);
1079+
1080+
} else {
1081+
#endif
1082+
QnnContext_Config_t qnn_context_config = QNN_CONTEXT_CONFIG_INIT;
1083+
ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config));
9411084

942-
// Register spill fill buffer for multi context
943-
QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
1085+
// Register spill fill buffer for multi context
1086+
QnnContext_Config_t spill_fill_config = QNN_CONTEXT_CONFIG_INIT;
9441087

945-
// The spill fill buffer is available since 2.28, API version starts from 2.21
1088+
// The spill fill buffer is available since 2.28, API version starts from 2.21
9461089
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21)
947-
QnnHtpContext_CustomConfig_t custom_config;
948-
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
949-
QnnHtpContext_GroupRegistration_t group_info;
950-
size_t current_contexts_size = GetQnnContextSize();
951-
// set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
952-
// note that we already move the context with max spill fill size to the beginning of the list
953-
group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
954-
group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
955-
custom_config.groupRegistration = group_info;
956-
spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
957-
spill_fill_config.customConfig = &custom_config;
1090+
QnnHtpContext_CustomConfig_t custom_config;
1091+
custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;
1092+
QnnHtpContext_GroupRegistration_t group_info;
1093+
size_t current_contexts_size = GetQnnContextSize();
1094+
// set to 0x0 (new group) if this is the first context, otherwise point to the first context handle
1095+
// note that we already move the context with max spill fill size to the beginning of the list
1096+
group_info.firstGroupHandle = (max_spill_fill_size > 0 && current_contexts_size > 0) ? GetQnnContext(0) : 0x0;
1097+
group_info.maxSpillFillBuffer = max_spill_fill_size; // Max spill-fill buffer across contexts. Must be >0
1098+
custom_config.groupRegistration = group_info;
1099+
spill_fill_config.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM;
1100+
spill_fill_config.customConfig = &custom_config;
1101+
9581102
#endif
959-
QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
960-
LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;
9611103

962-
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
1104+
QnnContext_Config_t* spill_fill_config_pointer = max_spill_fill_size > 0 ? &spill_fill_config : nullptr;
1105+
LOGS(*logger_, VERBOSE) << "Max spill fill buffer size:" << max_spill_fill_size;
1106+
1107+
const QnnContext_Config_t* context_configs[] = {&qnn_context_config, spill_fill_config_pointer, nullptr};
1108+
1109+
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
1110+
"Invalid function pointer for contextCreateFromBinary.");
1111+
1112+
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
1113+
device_handle_,
1114+
context_configs,
1115+
static_cast<void*>(buffer),
1116+
buffer_length,
1117+
&context,
1118+
profile_backend_handle_);
1119+
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
1120+
ORT_RETURN_IF_ERROR(AddQnnContextHandle(context));
1121+
1122+
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26)
1123+
}
1124+
#endif
9631125

964-
ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary,
965-
"Invalid function pointer for contextCreateFromBinary.");
966-
Qnn_ContextHandle_t context = nullptr;
967-
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
968-
device_handle_,
969-
context_configs,
970-
static_cast<void*>(buffer),
971-
buffer_length,
972-
&context,
973-
profile_backend_handle_);
974-
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
975-
ORT_RETURN_IF_ERROR(AddQnnContextHandle(context));
9761126
if (1 == graph_count) {
9771127
// in case the EPContext node is generated from script
9781128
// the graph name from the context binary may not match the EPContext node name
@@ -1002,13 +1152,33 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t
10021152
Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
10031153
bool load_from_cached_context,
10041154
bool need_load_system_lib,
1005-
bool share_ep_contexts) {
1155+
bool share_ep_contexts,
1156+
bool enable_vtcm_backup_buffer_sharing,
1157+
std::unordered_map<std::string, std::unique_ptr<std::vector<std::string>>>& context_bin_map) {
10061158
std::lock_guard<std::recursive_mutex> lock(logger_recursive_mutex_);
10071159
if (backend_setup_completed_) {
10081160
LOGS(logger, VERBOSE) << "Backend setup already!";
1161+
1162+
#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26)
1163+
if (vtcm_backup_buffer_sharing_enabled_) {
1164+
LOGS(logger, VERBOSE) << "Mapping contexts to new EP main context nodes";
1165+
1166+
for (auto& it : context_bin_map) {
1167+
auto context_bin_filepath = it.first;
1168+
auto ep_node_names = *(it.second);
1169+
1170+
auto context = ep_context_handle_map_.at(context_bin_filepath);
1171+
for (auto node_name : ep_node_names) {
1172+
ep_context_handle_map_.emplace(node_name, context);
1173+
}
1174+
}
1175+
}
1176+
#endif
10091177
return Status::OK();
10101178
}
10111179

1180+
vtcm_backup_buffer_sharing_enabled_ = enable_vtcm_backup_buffer_sharing;
1181+
10121182
Status status = Status::OK();
10131183
if (!qnn_serializer_config_) {
10141184
status = LoadBackend();
@@ -1071,10 +1241,10 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger,
10711241
#endif
10721242
}
10731243

1074-
if (!load_from_cached_context) {
1075-
if (status.IsOK()) {
1076-
status = CreateContext(enable_htp_weight_sharing);
1077-
}
1244+
if (status.IsOK() && (vtcm_backup_buffer_sharing_enabled_ || !load_from_cached_context)) {
1245+
status = vtcm_backup_buffer_sharing_enabled_ ? CreateContextVtcmBackupBufferSharingEnabled(context_bin_map)
1246+
: CreateContext(enable_htp_weight_sharing);
1247+
10781248
if (status.IsOK()) {
10791249
LOGS(logger, VERBOSE) << "CreateContext succeed.";
10801250
}

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,9 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
149149
// Initializes handles to QNN resources (device, logger, etc.).
150150
// NOTE: This function locks the internal `logger_recursive_mutex_`.
151151
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context,
152-
bool need_load_system_lib, bool share_ep_contexts);
152+
bool need_load_system_lib, bool share_ep_contexts,
153+
bool enable_vtcm_backup_buffer_sharing,
154+
std::unordered_map<std::string, std::unique_ptr<std::vector<std::string>>>& context_bin_map);
153155

154156
Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id);
155157

@@ -209,6 +211,13 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
209211

210212
QnnSerializerConfig* GetQnnSerializerConfig();
211213

214+
// Handler to be called upon successful context creation via contextCreateFromBinaryListAsync()
215+
// This handler is expected to be called in the callback ContextCreateAsyncCallback() in the .cc file
216+
// Takes in the context and the notifyParam objects received by the callback function
217+
// notifyParam is expected to be a pointer to a vector of node names associated with that context handle
218+
// For each node name, a mapping to the context handle will be created
219+
void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam);
220+
212221
private:
213222
Status LoadBackend();
214223

@@ -226,6 +235,9 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
226235

227236
Status CreateContext(bool enable_htp_weight_sharing);
228237

238+
Status CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map<std::string,
239+
std::unique_ptr<std::vector<std::string>>>& context_bin_map);
240+
229241
Status ReleaseContext();
230242

231243
// Sets the ORT logger and creates a corresponding QNN logger with the same log level.
@@ -310,7 +322,7 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
310322
#endif
311323

312324
// Adds a new QNN context.
313-
// Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance.
325+
// Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance
314326
Status AddQnnContextHandle(Qnn_ContextHandle_t context_handle);
315327

316328
private:
@@ -407,6 +419,10 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
407419
// HtpSharedMemoryAllocator allocation cleanup callback.
408420
std::unordered_map<Qnn_ContextHandle_t, std::shared_ptr<QnnContextHandleRecord>> context_map_;
409421

422+
// Map of EP Main Context Node names to Qnn_ContextHandle_t
423+
std::mutex ep_context_handle_map_mutex_;
424+
std::unordered_map<std::string, Qnn_ContextHandle_t> ep_context_handle_map_;
425+
410426
// Vector of Qnn_ContextHandle_t. The context handles are owned by context_map_.
411427
std::vector<Qnn_ContextHandle_t> contexts_;
412428

@@ -418,6 +434,7 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
418434
bool device_created_ = false;
419435
bool context_created_ = false;
420436
bool backend_setup_completed_ = false;
437+
bool vtcm_backup_buffer_sharing_enabled_ = false;
421438
// NPU backend requires quantized model
422439
QnnBackendType qnn_backend_type_ = QnnBackendType::CPU;
423440
Qnn_ProfileHandle_t profile_backend_handle_ = nullptr;

0 commit comments

Comments
 (0)