@@ -20,22 +20,7 @@ using Exception = ov::Exception;
20
20
namespace onnxruntime {
21
21
namespace openvino_ep {
22
22
23
- SharedContext::SharedWeights::WeightsFile::WeightsFile (std::filesystem::path filename) : file_(filename, std::ios::in | std::ios::binary) {
24
- try {
25
- file_.exceptions (std::ifstream::failbit | std::ifstream::badbit);
26
- weights_size_ = file_.seekg (0 , std::ios::end).tellg ();
27
- } catch (std::ifstream::failure& e) {
28
- ORT_THROW (" Error: Failed to open weight file at " , filename.string (), " " , e.what ());
29
- }
30
- }
31
-
32
- void SharedContext::SharedWeights::WeightsFile::load_weights (size_t file_offset, void * data, size_t size) {
33
- ORT_ENFORCE (file_offset < weights_size_ && size <= weights_size_ && (file_offset <= weights_size_ - size), " Error: File offset is out of bounds." );
34
- file_.seekg (file_offset);
35
- file_.read (reinterpret_cast <char *>(data), size);
36
- }
37
-
38
- std::ostream& operator <<(std::ostream& stream, const SharedContext::SharedWeights::Metadata::Map& metadata) {
23
+ std::ostream& operator <<(std::ostream& stream, const Metadata::Map& metadata) {
39
24
try {
40
25
stream << metadata.size ();
41
26
@@ -69,14 +54,14 @@ std::ostream& operator<<(std::ostream& stream, const SharedContext::SharedWeight
69
54
return stream;
70
55
}
71
56
72
- std::istream& operator >>(std::istream& stream, SharedContext::SharedWeights:: Metadata::Map& metadata) {
57
+ std::istream& operator >>(std::istream& stream, Metadata::Map& metadata) {
73
58
size_t map_size{0 };
74
59
try {
75
60
stream >> map_size;
76
61
77
62
while (!stream.eof ()) {
78
- SharedContext::SharedWeights:: Metadata::Key key;
79
- SharedContext::SharedWeights:: Metadata::Value value;
63
+ Metadata::Key key;
64
+ Metadata::Value value;
80
65
stream >> key.name ;
81
66
stream >> value.location ;
82
67
stream >> value.data_offset ;
@@ -399,8 +384,19 @@ ov::element::Type GetOpenVINOElementType(ONNX_NAMESPACE::TensorProto_DataType dt
399
384
400
385
// Function to handle tensor creation from external data
401
386
void CreateOVTensors (const std::string& device_name,
402
- SharedContext::SharedWeights::Metadata::Map& metadata_map,
403
- SharedContext::SharedWeights::WeightsFile& weights) {
387
+ Metadata::Map& metadata_map,
388
+ std::filesystem::path& weights_filepath) {
389
+ // File is guaranteed to exist at this point
390
+ std::ifstream file (weights_filepath, std::ios::in | std::ios::binary);
391
+ file.exceptions (std::ifstream::failbit | std::ifstream::badbit);
392
+ size_t weights_size = std::filesystem::file_size (weights_filepath);
393
+
394
+ const auto load_weights = [&file, weights_size](size_t file_offset, void * data, size_t size) {
395
+ ORT_ENFORCE (file_offset < weights_size && size <= weights_size && (file_offset <= weights_size - size), " Error: File offset is out of bounds." );
396
+ file.seekg (file_offset);
397
+ file.read (reinterpret_cast <char *>(data), size);
398
+ };
399
+
404
400
for (auto & [key, value] : metadata_map) {
405
401
if (value.tensor ) continue ;
406
402
@@ -416,18 +412,18 @@ void CreateOVTensors(const std::string& device_name,
416
412
auto && remote_tensor = npu_context.create_l0_host_tensor (ov_elementType, value.dimensions , ov::intel_npu::TensorType::INPUT);
417
413
418
414
// Copy data to remote tensor
419
- weights. load_weights (value.data_offset , remote_tensor.get (), value.size );
415
+ load_weights (value.data_offset , remote_tensor.get (), value.size );
420
416
value.tensor = std::make_shared<ov::Tensor>(remote_tensor);
421
417
} else {
422
418
// Use vanilla tensors
423
419
value.tensor = std::make_shared<ov::Tensor>(ov_elementType, value.dimensions );
424
- weights. load_weights (value.data_offset , value.tensor ->data (), value.size );
420
+ load_weights (value.data_offset , value.tensor ->data (), value.size );
425
421
}
426
422
ORT_ENFORCE (value.tensor ->get_byte_size () == value.size , " Unexpected tensor size mismatch" );
427
423
}
428
424
}
429
425
430
- void DestroyOVTensors (SharedContext::SharedWeights:: Metadata::Map& metadata_map) {
426
+ void DestroyOVTensors (Metadata::Map& metadata_map) {
431
427
for (auto & [key, value] : metadata_map) {
432
428
if (value.tensor ) {
433
429
value.tensor .reset ();
0 commit comments