44#include < sstream>
55#include < fstream>
66#include < utility>
7- #include < string>
87
98#include < filesystem>
109#include < stdexcept>
@@ -21,7 +20,22 @@ using Exception = ov::Exception;
2120namespace onnxruntime {
2221namespace openvino_ep {
2322
24- std::ostream& operator <<(std::ostream& stream, const Metadata::Map& metadata) {
23+ SharedContext::SharedWeights::WeightsFile::WeightsFile (std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) {
24+ try {
25+ file_.exceptions (std::ifstream::failbit | std::ifstream::badbit);
26+ weights_size_ = file_.seekg (0 , std::ios::end).tellg ();
27+ } catch (std::ifstream::failure& e) {
28+ ORT_THROW (" Error: Failed to open weight file at " , filename.string (), " " , e.what ());
29+ }
30+ }
31+
32+ void SharedContext::SharedWeights::WeightsFile::load_weights (size_t file_offset, void * data, size_t size) {
33+ ORT_ENFORCE (file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), " Error: File offset is out of bounds." );
34+ file_.seekg (file_offset);
35+ file_.read (reinterpret_cast <char *>(data), size);
36+ }
37+
38+ std::ostream& operator <<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
2539 try {
2640 stream << metadata.size ();
2741
@@ -55,14 +69,14 @@ std::ostream& operator<<(std::ostream& stream, const Metadata::Map& metadata) {
5569 return stream;
5670}
5771
58- std::istream& operator >>(std::istream& stream, Metadata::Map& metadata) {
72+ std::istream& operator >>(std::istream& stream, SharedContext::SharedWeights:: Metadata::Map& metadata) {
5973 size_t map_size{0 };
6074 try {
6175 stream >> map_size;
6276
6377 while (!stream.eof ()) {
64- Metadata::Key key;
65- Metadata::Value value;
78+ SharedContext::SharedWeights:: Metadata::Key key;
79+ SharedContext::SharedWeights:: Metadata::Value value;
6680 stream >> key.name ;
6781 stream >> value.location ;
6882 stream >> value.data_offset ;
@@ -385,19 +399,8 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt
385399
386400// Function to handle tensor creation from external data
387401void CreateOVTensors (const std::string& device_name,
388- Metadata::Map& metadata_map,
389- std::filesystem::path& weights_filepath) {
390- // File is guaranteed to exist at this point
391- std::ifstream file (weights_filepath, std::ios::in | std::ios::binary);
392- file.exceptions (std::ifstream::failbit | std::ifstream::badbit);
393- size_t weights_size = std::filesystem::file_size (weights_filepath);
394-
395- const auto load_weights = [&file, weights_size](size_t file_offset, void * data, size_t size) {
396- ORT_ENFORCE (file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), " Error: File offset is out of bounds." );
397- file.seekg (file_offset);
398- file.read (reinterpret_cast <char *>(data), size);
399- };
400-
402+ SharedContext::SharedWeights::Metadata::Map& metadata_map,
403+ SharedContext::SharedWeights::WeightsFile& weights) {
401404 for (auto & [key, value] : metadata_map) {
402405 if (value.tensor ) continue ;
403406
@@ -413,18 +416,18 @@ void CreateOVTensors(const std::string& device_name,
413416 auto && remote_tensor = npu_context.create_l0_host_tensor (ov_elementType, value.dimensions , ov::intel_npu::TensorType::INPUT);
414417
415418 // Copy data to remote tensor
416- load_weights (value.data_offset , remote_tensor.get (), value.size );
419+ weights. load_weights (value.data_offset , remote_tensor.get (), value.size );
417420 value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
418421 } else {
419422 // Use vanilla tensors
420423 value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions );
421- load_weights (value.data_offset , value.tensor ->data (), value.size );
424+ weights. load_weights (value.data_offset , value.tensor ->data (), value.size );
422425 }
423426 ORT_ENFORCE (value.tensor ->get_byte_size () == value.size , " Unexpected tensor size mismatch" );
424427 }
425428}
426429
427- void DestroyOVTensors (Metadata::Map& metadata_map) {
430+ void DestroyOVTensors (SharedContext::SharedWeights:: Metadata::Map& metadata_map) {
428431 for (auto & [key, value] : metadata_map) {
429432 if (value.tensor ) {
430433 value.tensor .reset ();
@@ -433,51 +436,6 @@ void DestroyOVTensors(Metadata::Map& metadata_map) {
433436 metadata_map.clear ();
434437}
435438
436- std::optional<std::string> GetExternalWeightFilename (const GraphViewer& graph) {
437- auto get_external_location = [](const ONNX_NAMESPACE::TensorProto& proto) -> std::optional<std::string> {
438- using mutable_proto_t = ONNX_NAMESPACE::TensorProto*;
439- auto & mutable_proto = *const_cast <mutable_proto_t >(&proto);
440- auto * entry_protos = mutable_proto.mutable_external_data ();
441-
442- if (proto.has_data_location () && proto.data_location () == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) {
443- for (int i = 0 ; i < entry_protos->size (); i++) {
444- auto & string_entry_proto{entry_protos->at (i)};
445- const auto & pb_key{*(string_entry_proto.mutable_key ())};
446- const auto & pb_value{*(string_entry_proto.mutable_value ())};
447- if (pb_key == " location" ) {
448- return std::make_optional<std::string>(pb_value);
449- }
450- }
451- }
452-
453- return std::nullopt ;
454- };
455-
456- // Handle constant initializers
457- auto & initializers = graph.GetAllInitializedTensors ();
458- for (const auto & it : initializers) {
459- if (auto result = get_external_location (*it.second )) {
460- return result;
461- }
462- }
463-
464- // Handle outer-scope constant initializers
465- for (auto & node_idx : graph.GetNodesInTopologicalOrder ()) {
466- const auto & node = graph.GetNode (node_idx);
467- for (const auto & input : node->InputDefs ()) {
468- if (graph.IsConstantInitializer (input->Name (), true )) {
469- const auto & initializer_tensor = *graph.GetConstantInitializer (input->Name (), true );
470-
471- if (auto result = get_external_location (initializer_tensor)) {
472- return result;
473- }
474- }
475- }
476- }
477-
478- return std::nullopt ;
479- }
480-
481439} // namespace backend_utils
482440} // namespace openvino_ep
483441} // namespace onnxruntime
0 commit comments