44#include < sstream>
55#include < fstream>
66#include < utility>
7+ #include < string>
78
89#include < filesystem>
910#include < stdexcept>
@@ -20,22 +21,7 @@ using Exception = ov::Exception;
2021namespace onnxruntime {
2122namespace openvino_ep {
2223
23- SharedContext::SharedWeights::WeightsFile::WeightsFile (std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) {
24- try {
25- file_.exceptions (std::ifstream::failbit | std::ifstream::badbit);
26- weights_size_ = file_.seekg (0 , std::ios::end).tellg ();
27- } catch (std::ifstream::failure& e) {
28- ORT_THROW (" Error: Failed to open weight file at " , filename.string (), " " , e.what ());
29- }
30- }
31-
32- void SharedContext::SharedWeights::WeightsFile::load_weights (size_t file_offset, void * data, size_t size) {
33- ORT_ENFORCE (file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), " Error: File offset is out of bounds." );
34- file_.seekg (file_offset);
35- file_.read (reinterpret_cast <char *>(data), size);
36- }
37-
38- std::ostream& operator <<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
24+ std::ostream& operator <<(std::ostream& stream, const Metadata::Map& metadata) {
3925 try {
4026 stream << metadata.size ();
4127
@@ -69,14 +55,14 @@ std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeight
6955 return stream;
7056}
7157
72- std::istream& operator >>(std::istream& stream, SharedContext::SharedWeights:: Metadata::Map& metadata) {
58+ std::istream& operator >>(std::istream& stream, Metadata::Map& metadata) {
7359 size_t map_size{0 };
7460 try {
7561 stream >> map_size;
7662
7763 while (!stream.eof ()) {
78- SharedContext::SharedWeights:: Metadata::Key key;
79- SharedContext::SharedWeights:: Metadata::Value value;
64+ Metadata::Key key;
65+ Metadata::Value value;
8066 stream >> key.name ;
8167 stream >> value.location ;
8268 stream >> value.data_offset ;
@@ -399,8 +385,19 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt
399385
400386// Function to handle tensor creation from external data
401387void CreateOVTensors (const std::string& device_name,
402- SharedContext::SharedWeights::Metadata::Map& metadata_map,
403- SharedContext::SharedWeights::WeightsFile& weights) {
388+ Metadata::Map& metadata_map,
389+ std::filesystem::path& weights_filepath) {
390+ // File is guaranteed to exist at this point
391+ std::ifstream file (weights_filepath, std::ios::in | std::ios::binary);
392+ file.exceptions (std::ifstream::failbit | std::ifstream::badbit);
393+ size_t weights_size = std::filesystem::file_size (weights_filepath);
394+
395+ const auto load_weights = [&file, weights_size](size_t file_offset, void * data, size_t size) {
396+ ORT_ENFORCE (file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), " Error: File offset is out of bounds." );
397+ file.seekg (file_offset);
398+ file.read (reinterpret_cast <char *>(data), size);
399+ };
400+
404401 for (auto & [key, value] : metadata_map) {
405402 if (value.tensor ) continue ;
406403
@@ -416,18 +413,18 @@ void CreateOVTensors(const std::string& device_name,
416413 auto && remote_tensor = npu_context.create_l0_host_tensor (ov_elementType, value.dimensions , ov::intel_npu::TensorType::INPUT);
417414
418415 // Copy data to remote tensor
419- weights. load_weights (value.data_offset , remote_tensor.get (), value.size );
416+ load_weights (value.data_offset , remote_tensor.get (), value.size );
420417 value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
421418 } else {
422419 // Use vanilla tensors
423420 value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions );
424- weights. load_weights (value.data_offset , value.tensor ->data (), value.size );
421+ load_weights (value.data_offset , value.tensor ->data (), value.size );
425422 }
426423 ORT_ENFORCE (value.tensor ->get_byte_size () == value.size , " Unexpected tensor size mismatch" );
427424 }
428425}
429426
430- void DestroyOVTensors (SharedContext::SharedWeights:: Metadata::Map& metadata_map) {
427+ void DestroyOVTensors (Metadata::Map& metadata_map) {
431428 for (auto & [key, value] : metadata_map) {
432429 if (value.tensor ) {
433430 value.tensor .reset ();
@@ -436,6 +433,51 @@ void DestroyOVTensors(SharedContext::SharedWeights::Metadata::Map& metadata_map)
436433 metadata_map.clear ();
437434}
438435
436+ std::optional<std::string> GetExternalWeightFilename (const GraphViewer& graph) {
437+ auto get_external_location = [](const ONNX_NAMESPACE::TensorProto& proto) -> std::optional<std::string> {
438+ using mutable_proto_t = ONNX_NAMESPACE::TensorProto*;
439+ auto & mutable_proto = *const_cast <mutable_proto_t >(&proto);
440+ auto * entry_protos = mutable_proto.mutable_external_data ();
441+
442+ if (proto.has_data_location () && proto.data_location () == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
443+ for (int i = 0 ; i < entry_protos->size (); i++) {
444+ auto & string_entry_proto{entry_protos->at (i)};
445+ const auto & pb_key{*(string_entry_proto.mutable_key ())};
446+ const auto & pb_value{*(string_entry_proto.mutable_value ())};
447+ if (pb_key == " location" ) {
448+ return std::make_optional<std::string>(pb_value);
449+ }
450+ }
451+ }
452+
453+ return std::nullopt ;
454+ };
455+
456+ // Handle constant initializers
457+ auto & initializers = graph.GetAllInitializedTensors ();
458+ for (const auto & it : initializers) {
459+ if (auto result = get_external_location (*it.second )) {
460+ return result;
461+ }
462+ }
463+
464+ // Handle outer-scope constant initializers
465+ for (auto & node_idx : graph.GetNodesInTopologicalOrder ()) {
466+ const auto & node = graph.GetNode (node_idx);
467+ for (const auto & input : node->InputDefs ()) {
468+ if (graph.IsConstantInitializer (input->Name (), true )) {
469+ const auto & initializer_tensor = *graph.GetConstantInitializer (input->Name (), true );
470+
471+ if (auto result = get_external_location (initializer_tensor)) {
472+ return result;
473+ }
474+ }
475+ }
476+ }
477+
478+ return std::nullopt ;
479+ }
480+
439481} // namespace backend_utils
440482} // namespace openvino_ep
441483} // namespace onnxruntime
0 commit comments