Skip to content

Commit 269f6fe

Browse files
authored
Add support for session option ep.stop_context_sharing (#655)
* Add function to query external initializer file name * Decouple external weight processing from shared context and add support for stop context sharing
1 parent 7401335 commit 269f6fe

File tree

10 files changed

+152
-99
lines changed

10 files changed

+152
-99
lines changed

onnxruntime/core/providers/openvino/backend_manager.cc

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,23 @@ BackendManager::BackendManager(SessionContext& session_context,
8383
}
8484
std::string device_type = session_context_.device_type;
8585

86-
auto& sw = shared_context_.shared_weights;
87-
if (session_context_.so_share_ep_contexts) {
88-
std::filesystem::path weight_filename = session_context_.onnx_model_path_name.parent_path();
89-
if (sw.external_weight_filename.empty() && !sw.metadata.empty()) {
90-
// Reasonable assumption that all metadata entries have the same external file location
91-
sw.external_weight_filename = sw.metadata.begin()->second.location;
92-
}
93-
weight_filename /= sw.external_weight_filename;
94-
std::ifstream weight_file(weight_filename);
86+
// Check if model is using external weights
87+
if (auto filename = backend_utils::GetExternalWeightFilename(subgraph)) {
88+
std::filesystem::path weights_filepath = session_context_.onnx_model_path_name.parent_path() / filename.value();
9589

96-
if (weight_file) {
97-
if (!sw.mapped_weights) {
98-
sw.mapped_weights = std::make_unique<SharedContext::SharedWeights::WeightsFile>(weight_filename);
99-
}
100-
backend_utils::CreateOVTensors(session_context_.device_type, sw.metadata, *sw.mapped_weights);
90+
// Initialize external weights with fully qualified path
91+
if (!std::filesystem::exists(weights_filepath)) {
92+
ORT_THROW("Error: Failed to locate weight file at ", weights_filepath.string());
10193
}
94+
95+
external_weights_.emplace(weights_filepath);
96+
}
97+
98+
if (session_context_.so_share_ep_contexts) {
99+
ORT_ENFORCE(external_weights_.has_value(), "Expected external weight object to be valid");
100+
backend_utils::CreateOVTensors(session_context_.device_type,
101+
shared_context_.shared_weights.metadata,
102+
external_weights_.value());
102103
}
103104

104105
if (ModelHasSymbolicInputDims(subgraph)) {
@@ -324,7 +325,7 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) {
324325
static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name,
325326
[[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto,
326327
[[maybe_unused]] const onnxruntime::Node& fused_node) {
327-
#ifndef RELEASE
328+
#ifdef NOT_RELEASE
328329
if (openvino_ep::backend_utils::IsDebugEnabled()) {
329330
auto model_name = onnx_model_path_name.empty() ? "unknown.onnx" : onnx_model_path_name.filename();
330331

@@ -384,7 +385,12 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
384385
if (session_context_.device_type.find("NPU") != std::string::npos &&
385386
(enable_ovep_qdq_optimizer || session_context_.so_share_ep_contexts)) {
386387
std::unique_ptr<onnxruntime::Model> model;
387-
Status status = CreateModelWithStrippedQDQNodes(subgraph, logger, session_context_.so_share_ep_contexts, enable_ovep_qdq_optimizer, model, shared_context_.shared_weights);
388+
Status status = CreateModelWithStrippedQDQNodes(subgraph,
389+
logger,
390+
session_context_.so_share_ep_contexts,
391+
enable_ovep_qdq_optimizer,
392+
model,
393+
shared_context_.shared_weights.metadata);
388394
auto model_proto = model->ToProto();
389395
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
390396
print_model_proto_duration();

onnxruntime/core/providers/openvino/backend_manager.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class BackendManager {
5454
EPCtxHandler& ep_ctx_handle_;
5555
SessionContext& session_context_;
5656
SharedContext& shared_context_;
57+
std::optional<fs::path> external_weights_;
5758
};
5859

5960
} // namespace openvino_ep

onnxruntime/core/providers/openvino/backend_utils.cc

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <sstream>
55
#include <fstream>
66
#include <utility>
7+
#include <string>
78

89
#include <filesystem>
910
#include <stdexcept>
@@ -20,22 +21,7 @@ using Exception = ov::Exception;
2021
namespace onnxruntime {
2122
namespace openvino_ep {
2223

23-
SharedContext::SharedWeights::WeightsFile::WeightsFile(std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) {
24-
try {
25-
file_.exceptions(std::ifstream::failbit | std::ifstream::badbit);
26-
weights_size_ = file_.seekg(0, std::ios::end).tellg();
27-
} catch (std::ifstream::failure& e) {
28-
ORT_THROW("Error: Failed to open weight file at ", filename.string(), " ", e.what());
29-
}
30-
}
31-
32-
void SharedContext::SharedWeights::WeightsFile::load_weights(size_t file_offset, void* data, size_t size) {
33-
ORT_ENFORCE(file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), "Error: File offset is out of bounds.");
34-
file_.seekg(file_offset);
35-
file_.read(reinterpret_cast<char*>(data), size);
36-
}
37-
38-
std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
24+
std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) {
3925
try {
4026
stream << metadata.size();
4127

@@ -69,14 +55,14 @@ std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeight
6955
return stream;
7056
}
7157

72-
std::istream& operator>>(std::istream& stream, SharedContext::SharedWeights::Metadata::Map& metadata) {
58+
std::istream& operator>>(std::istream& stream, Metadata::Map& metadata) {
7359
size_t map_size{0};
7460
try {
7561
stream >> map_size;
7662

7763
while (!stream.eof()) {
78-
SharedContext::SharedWeights::Metadata::Key key;
79-
SharedContext::SharedWeights::Metadata::Value value;
64+
Metadata::Key key;
65+
Metadata::Value value;
8066
stream >> key.name;
8167
stream >> value.location;
8268
stream >> value.data_offset;
@@ -399,8 +385,19 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt
399385

400386
// Function to handle tensor creation from external data
401387
void CreateOVTensors(const std::string& device_name,
402-
SharedContext::SharedWeights::Metadata::Map& metadata_map,
403-
SharedContext::SharedWeights::WeightsFile& weights) {
388+
Metadata::Map& metadata_map,
389+
std::filesystem::path& weights_filepath) {
390+
// File is guaranteed to exist at this point
391+
std::ifstream file(weights_filepath, std::ios::in | std::ios::binary);
392+
file.exceptions(std::ifstream::failbit | std::ifstream::badbit);
393+
size_t weights_size = std::filesystem::file_size(weights_filepath);
394+
395+
const auto load_weights = [&file, weights_size](size_t file_offset, void* data, size_t size) {
396+
ORT_ENFORCE(file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), "Error: File offset is out of bounds.");
397+
file.seekg(file_offset);
398+
file.read(reinterpret_cast<char*>(data), size);
399+
};
400+
404401
for (auto& [key, value] : metadata_map) {
405402
if (value.tensor) continue;
406403

@@ -416,18 +413,18 @@ void CreateOVTensors(const std::string& device_name,
416413
auto&& remote_tensor = npu_context.create_l0_host_tensor(ov_elementType, value.dimensions, ov::intel_npu::TensorType::INPUT);
417414

418415
// Copy data to remote tensor
419-
weights.load_weights(value.data_offset, remote_tensor.get(), value.size);
416+
load_weights(value.data_offset, remote_tensor.get(), value.size);
420417
value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
421418
} else {
422419
// Use vanilla tensors
423420
value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions);
424-
weights.load_weights(value.data_offset, value.tensor->data(), value.size);
421+
load_weights(value.data_offset, value.tensor->data(), value.size);
425422
}
426423
ORT_ENFORCE(value.tensor->get_byte_size() == value.size, "Unexpected tensor size mismatch");
427424
}
428425
}
429426

430-
void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map) {
427+
void DestroyOVTensors(Metadata::Map& metadata_map) {
431428
for (auto& [key, value] : metadata_map) {
432429
if (value.tensor) {
433430
value.tensor.reset();
@@ -436,6 +433,51 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map)
436433
metadata_map.clear();
437434
}
438435

436+
std::optional<std::string> GetExternalWeightFilename(const GraphViewer& graph) {
437+
auto get_external_location = [](const ONNX_NAMESPACE::TensorProto& proto) -> std::optional<std::string> {
438+
using mutable_proto_t = ONNX_NAMESPACE::TensorProto*;
439+
auto& mutable_proto = *const_cast<mutable_proto_t>(&proto);
440+
auto* entry_protos = mutable_proto.mutable_external_data();
441+
442+
if (proto.has_data_location() && proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
443+
for (int i = 0; i < entry_protos->size(); i++) {
444+
auto& string_entry_proto{entry_protos->at(i)};
445+
const auto& pb_key{*(string_entry_proto.mutable_key())};
446+
const auto& pb_value{*(string_entry_proto.mutable_value())};
447+
if (pb_key == "location") {
448+
return std::make_optional<std::string>(pb_value);
449+
}
450+
}
451+
}
452+
453+
return std::nullopt;
454+
};
455+
456+
// Handle constant initializers
457+
auto& initializers = graph.GetAllInitializedTensors();
458+
for (const auto& it : initializers) {
459+
if (auto result = get_external_location(*it.second)) {
460+
return result;
461+
}
462+
}
463+
464+
// Handle outer-scope constant initializers
465+
for (auto& node_idx : graph.GetNodesInTopologicalOrder()) {
466+
const auto& node = graph.GetNode(node_idx);
467+
for (const auto& input : node->InputDefs()) {
468+
if (graph.IsConstantInitializer(input->Name(), true)) {
469+
const auto& initializer_tensor = *graph.GetConstantInitializer(input->Name(), true);
470+
471+
if (auto result = get_external_location(initializer_tensor)) {
472+
return result;
473+
}
474+
}
475+
}
476+
}
477+
478+
return std::nullopt;
479+
}
480+
439481
} // namespace backend_utils
440482
} // namespace openvino_ep
441483
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backend_utils.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,18 @@ CreateOVModel(std::string&& model,
6767
std::map<std::string, std::shared_ptr<ov::Node>>& const_outputs_map);
6868

6969
void CreateOVTensors(const std::string& device_name,
70-
SharedContext::SharedWeights::Metadata::Map& metadata_map,
71-
SharedContext::SharedWeights::WeightsFile& weights);
72-
void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map);
70+
Metadata::Map& metadata_map,
71+
std::filesystem::path& weights_filepath);
72+
void DestroyOVTensors(Metadata::Map& metadata_map);
7373

7474
void printPerformanceCounts(const std::vector<OVProfilingInfo>& performanceMap,
7575
std::ostream& stream, std::string deviceName);
7676

7777
void printPerformanceCounts(OVInferRequestPtr request, std::ostream& stream, std::string deviceName);
7878

79+
// Returns the location string from the first external initializer nodes found or nullopt if none found
80+
std::optional<std::string> GetExternalWeightFilename(const GraphViewer& graph);
81+
7982
} // namespace backend_utils
8083
} // namespace openvino_ep
8184
} // namespace onnxruntime

onnxruntime/core/providers/openvino/backends/basic_backend.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,12 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
125125
std::function<void(OVInferRequestPtr)> initializer = [](OVInferRequestPtr) {};
126126
auto metadata = shared_context_.shared_weights.metadata;
127127
if (session_context_.so_share_ep_contexts) {
128+
// When shared ep contexts is set external weight references are transformed to model inputs. This
129+
// creates an initializer to populate/bind input weight tensors to each inference request
128130
initializer = [&metadata](OVInferRequestPtr ir_ptr) {
129131
const auto input_count = ir_ptr->GetNumInputs();
130132
for (auto i = 0u; i < input_count; i++) {
131-
using Key = SharedContext::SharedWeights::Metadata::Key;
133+
using Key = Metadata::Key;
132134
const auto tensor_key = Key{ir_ptr->GetInputTensorName(i)};
133135
if (metadata.contains(tensor_key)) {
134136
auto& value = metadata.at(tensor_key);
@@ -137,6 +139,8 @@ BasicBackend::BasicBackend(std::unique_ptr<ONNX_NAMESPACE::ModelProto>& model_pr
137139
}
138140
};
139141
}
142+
143+
// Create inference request queue and initialize according to passed function
140144
inferRequestsQueue_ = std::unique_ptr<InferRequestsQueue>(new InferRequestsQueue(exe_network_, num_infer_req, std::move(initializer)));
141145
}
142146

onnxruntime/core/providers/openvino/contexts.h

Lines changed: 28 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,52 +18,42 @@ namespace openvino_ep {
1818

1919
namespace fs = std::filesystem;
2020

21+
struct Metadata {
22+
struct Key {
23+
std::string name;
24+
bool operator==(const Key&) const = default;
25+
};
26+
struct Hash {
27+
std::size_t operator()(const Key& key) const noexcept {
28+
return std::hash<std::string>()(key.name);
29+
}
30+
};
31+
struct Value {
32+
std::string location;
33+
unsigned int data_offset;
34+
unsigned int size;
35+
std::vector<size_t> dimensions;
36+
std::int32_t element_type;
37+
std::shared_ptr<ov::Tensor> tensor;
38+
};
39+
using Map = std::unordered_map<Key, Value, Hash>;
40+
friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata);
41+
friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata);
42+
};
43+
2144
class SharedContext : public WeakSingleton<SharedContext> {
2245
// Keep the core alive as long as the shared SharedContext are alive.
2346
std::shared_ptr<OVCore> OVCore_;
2447

2548
public:
2649
SharedContext() : OVCore_(OVCore::Get()) {}
2750
struct SharedWeights {
28-
struct Metadata {
29-
struct Key {
30-
std::string name;
31-
bool operator==(const Key&) const = default;
32-
};
33-
struct Hash {
34-
std::size_t operator()(const Key& key) const noexcept {
35-
return std::hash<std::string>()(key.name);
36-
}
37-
};
38-
struct Value {
39-
std::string location;
40-
unsigned int data_offset;
41-
unsigned int size;
42-
std::vector<size_t> dimensions;
43-
std::int32_t element_type;
44-
std::shared_ptr<ov::Tensor> tensor;
45-
};
46-
using Map = std::unordered_map<Key, Value, Hash>;
47-
friend std::ostream& operator<<(std::ostream& right, const Metadata::Map& metadata);
48-
friend std::istream& operator>>(std::istream& right, Metadata::Map& metadata);
49-
};
50-
51-
struct WeightsFile {
52-
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WeightsFile);
53-
WeightsFile() = delete;
54-
explicit WeightsFile(std::filesystem::path filename);
55-
56-
void load_weights(size_t file_offset, void* data, size_t size);
57-
58-
private:
59-
std::ifstream file_;
60-
size_t weights_size_;
61-
};
62-
63-
fs::path external_weight_filename;
64-
std::unique_ptr<WeightsFile> mapped_weights;
6551
Metadata::Map metadata;
6652
} shared_weights;
53+
54+
void clear() { // Deletes the data stored in the SharedContext
55+
shared_weights.metadata.clear();
56+
}
6757
};
6858

6959
using config_t = std::map<std::string, ov::AnyMap>;
@@ -102,6 +92,7 @@ struct ProviderInfo {
10292
bool so_context_embed_mode{false}; // ORT session option
10393
bool so_share_ep_contexts{false}; // ORT session option
10494
fs::path so_context_file_path{}; // ORT session option
95+
bool so_stop_share_ep_contexts{false}; // ORT session option
10596
const ConfigOptions* config_options{NULL};
10697
const std::unordered_set<std::string> valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision",
10798
"load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer",

onnxruntime/core/providers/openvino/openvino_execution_provider.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ OpenVINOExecutionProvider::~OpenVINOExecutionProvider() {
6565
backend_manager.ShutdownBackendManager();
6666
}
6767
backend_managers_.clear();
68+
shared_context_.reset();
6869
}
6970

7071
std::vector<std::unique_ptr<ComputeCapability>>
@@ -106,7 +107,12 @@ common::Status OpenVINOExecutionProvider::Compile(
106107
auto& metadata = shared_context_->shared_weights.metadata;
107108
if (session_context_.so_share_ep_contexts && metadata.empty()) {
108109
// Metadata is always read from model location, this could be a source or epctx model
109-
fs::path metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
110+
fs::path metadata_filename;
111+
if (session_context_.so_context_file_path.empty()) {
112+
metadata_filename = session_context_.onnx_model_path_name.parent_path() / "metadata.bin";
113+
} else {
114+
metadata_filename = session_context_.so_context_file_path.parent_path() / "metadata.bin";
115+
}
110116
std::ifstream file(metadata_filename, std::ios::binary);
111117
if (file) {
112118
file >> metadata;
@@ -191,6 +197,10 @@ common::Status OpenVINOExecutionProvider::Compile(
191197
}
192198
}
193199

200+
if (session_context_.so_stop_share_ep_contexts) {
201+
shared_context_->clear();
202+
}
203+
194204
return status;
195205
}
196206

onnxruntime/core/providers/openvino/openvino_provider_factory.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ void ParseConfigOptions(ProviderInfo& pi) {
2626
pi.so_context_embed_mode = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextEmbedMode, "0") == "1";
2727
pi.so_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1";
2828
pi.so_context_file_path = pi.config_options->GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
29+
pi.so_stop_share_ep_contexts = pi.config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1";
2930

3031
if (pi.so_share_ep_contexts) {
3132
ov::AnyMap map;

0 commit comments

Comments
 (0)