1
- import ast
2
1
import copy
3
2
import logging
4
3
import os
4
+ import pickle
5
5
import shutil
6
+ import sys
6
7
from abc import ABC , abstractmethod
7
8
from typing import Any , Dict , List , Optional , Tuple , cast
8
9
@@ -50,6 +51,7 @@ def save(
50
51
serialized_engine : bytes ,
51
52
input_names : List [str ],
52
53
output_names : List [str ],
54
+ weight_name_map : Optional [Dict [str , Any ]] = None ,
53
55
) -> bool :
54
56
"""Save the serialized engine to hard disk
55
57
@@ -58,21 +60,24 @@ def save(
58
60
serialized_engine (bytes): serialized TRT engine
59
61
input_names (List[str]): input names of TRT engine
60
62
output_names (List[str]): output names of TRT engine
63
+ weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting
61
64
62
65
Returns:
63
66
bool: whether the serialized engine is saved successfully
64
67
"""
65
68
pass
66
69
67
70
@abstractmethod
68
- def load (self , hash : str ) -> Tuple [Optional [bytes ], List [str ], List [str ]]:
71
+ def load (
72
+ self , hash : str
73
+ ) -> Tuple [Optional [bytes ], List [str ], List [str ], Optional [Dict [str , Any ]]]:
69
74
"""Load the serialized engine from hard disk
70
75
71
76
Args:
72
77
hash (str): hash value of the GraphModule
73
78
74
79
Returns:
75
- Sequence[Optional[bytes], List[str], List[str]] : serialized TRT engine, input names of TRT Engine , output names of TRT Engine
80
+ Sequence[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]] : serialized engine, input names, output names, weight name map
76
81
"""
77
82
pass
78
83
@@ -89,16 +94,16 @@ def __init__(
89
94
self .engine_cache_dir = engine_cache_dir
90
95
self .hash2size_map : Dict [str , int ] = {}
91
96
92
- def has_available_cache_size (self , serialized_engine : bytes ) -> bool :
97
+ def has_available_cache_size (self , needed_size : int ) -> bool :
93
98
"""Check if the cache has available space for saving the serialized engine
94
99
95
100
Args:
96
- serialized_engine (bytes ): serialized TRT engine
101
+ needed_size (int ): needed size for serialized TRT engine and/or weight_name_map
97
102
98
103
Returns:
99
104
bool: whether the cache has available size for the serialized engine
100
105
"""
101
- return int ( serialized_engine . nbytes ) <= self .available_engine_cache_size
106
+ return needed_size <= self .available_engine_cache_size
102
107
103
108
def clear_cache (self , needed_min_size : int ) -> bool :
104
109
"""Clear the cache to make sure at least `needed_min_size` bytes are available, if possible
@@ -154,36 +159,75 @@ def save(
154
159
serialized_engine : bytes ,
155
160
input_names : List [str ],
156
161
output_names : List [str ],
162
+ weight_name_map : Optional [Dict [str , Any ]] = None ,
157
163
) -> bool :
158
164
serialized_engine_size = int (serialized_engine .nbytes )
165
+ if weight_name_map is not None :
166
+ serialized_engine_size += sum (
167
+ sys .getsizeof (v ) for v in weight_name_map .values ()
168
+ )
159
169
if serialized_engine_size > self .total_engine_cache_size :
160
170
_LOGGER .warning (
161
171
f"The serialized engine cannot be saved because the size of the engine { serialized_engine_size } is larger than the total cache size { self .total_engine_cache_size } ."
162
172
)
163
173
return False
164
174
165
- # Check if there is enough available cache size for the serialized engine
166
- if not self .has_available_cache_size (serialized_engine ):
175
+ # Check if there is enough available cache size for the serialized engine and/or weight_name_map
176
+ if not self .has_available_cache_size (serialized_engine_size ):
167
177
self .clear_cache (serialized_engine_size )
168
178
169
179
# Save the serialized engine to the cache directory
170
- if self .has_available_cache_size (serialized_engine ):
171
- path = os .path .join (
172
- self .engine_cache_dir ,
173
- f"{ hash } /engine--{ input_names } --{ output_names } .trt" ,
180
+ if self .has_available_cache_size (serialized_engine_size ):
181
+ self .hash2size_map [hash ] = serialized_engine_size
182
+ self .available_engine_cache_size -= serialized_engine_size
183
+ directory = os .path .join (self .engine_cache_dir , hash )
184
+
185
+ engine_path = os .path .join (
186
+ directory ,
187
+ "engine.trt" ,
188
+ )
189
+ io_names_path = os .path .join (
190
+ directory ,
191
+ "io_names.pkl" ,
174
192
)
175
193
try :
176
- os .makedirs (os .path .dirname (path ), exist_ok = True )
177
- with open (path , "wb" ) as f :
194
+ os .makedirs (os .path .dirname (engine_path ), exist_ok = True )
195
+ with open (engine_path , "wb" ) as f :
178
196
f .write (serialized_engine )
179
- self .hash2size_map [hash ] = serialized_engine_size
180
- self .available_engine_cache_size -= serialized_engine_size
181
- _LOGGER .info (f"A TRT engine was cached to { path } " )
182
-
197
+ os .makedirs (os .path .dirname (io_names_path ), exist_ok = True )
198
+ with open (io_names_path , "wb" ) as f :
199
+ pickle .dump (
200
+ {"input_names" : input_names , "output_names" : output_names }, f
201
+ )
202
+ _LOGGER .info (f"The TRT engine was saved to { engine_path } " )
183
203
except Exception as e :
184
- _LOGGER .warning (f"Failed to save the TRT engine to { path } : { e } " )
204
+ del self .hash2size_map [hash ]
205
+ self .available_engine_cache_size += serialized_engine_size
206
+ shutil .rmtree (directory )
207
+ _LOGGER .warning (f"Failed to save the TRT engine to { engine_path } : { e } " )
185
208
return False
186
209
210
+ if weight_name_map is not None :
211
+ weight_name_map_path = os .path .join (
212
+ directory ,
213
+ "weight_name_map.pkl" ,
214
+ )
215
+ try :
216
+ os .makedirs (os .path .dirname (weight_name_map_path ), exist_ok = True )
217
+ with open (weight_name_map_path , "wb" ) as f :
218
+ pickle .dump (weight_name_map , f )
219
+ _LOGGER .info (
220
+ f"The weight_name_map was saved to { weight_name_map_path } "
221
+ )
222
+ except Exception as e :
223
+ del self .hash2size_map [hash ]
224
+ self .available_engine_cache_size += serialized_engine_size
225
+ shutil .rmtree (directory )
226
+ _LOGGER .warning (
227
+ f"Failed to save the weight_name_map to { weight_name_map_path } : { e } "
228
+ )
229
+ return False
230
+
187
231
return True
188
232
189
233
else :
@@ -192,21 +236,33 @@ def save(
192
236
)
193
237
return False
194
238
195
- def load (self , hash : str ) -> Tuple [Optional [bytes ], List [str ], List [str ]]:
239
+ def load (
240
+ self , hash : str
241
+ ) -> Tuple [Optional [bytes ], List [str ], List [str ], Optional [Dict [str , Any ]]]:
196
242
directory = os .path .join (self .engine_cache_dir , hash )
197
243
if os .path .exists (directory ):
198
- engine_list = os .listdir (directory )
199
- assert (
200
- len (engine_list ) == 1
201
- ), f"There are more than one engine { engine_list } under { directory } ."
202
- path = os .path .join (directory , engine_list [0 ])
203
- input_names_str , output_names_str = (
204
- engine_list [0 ].split (".trt" )[0 ].split ("--" )[1 :]
205
- )
206
- input_names = ast .literal_eval (input_names_str )
207
- output_names = ast .literal_eval (output_names_str )
208
- with open (path , "rb" ) as f :
209
- serialized_engine = f .read ()
210
- return serialized_engine , input_names , output_names
244
+ # load engine
245
+ serialized_engine = None
246
+ engine_path = os .path .join (directory , "engine.trt" )
247
+ if os .path .exists (engine_path ):
248
+ with open (engine_path , "rb" ) as f :
249
+ serialized_engine = f .read ()
250
+
251
+ input_names = []
252
+ output_names = []
253
+ io_names_path = os .path .join (directory , "io_names.pkl" )
254
+ if os .path .exists (io_names_path ):
255
+ with open (io_names_path , "rb" ) as f :
256
+ io_names = pickle .load (f )
257
+ input_names = io_names ["input_names" ]
258
+ output_names = io_names ["output_names" ]
259
+
260
+ # load weight_name_map
261
+ weight_name_map = None
262
+ weight_name_map_path = os .path .join (directory , "weight_name_map.pkl" )
263
+ if os .path .exists (weight_name_map_path ):
264
+ with open (weight_name_map_path , "rb" ) as f :
265
+ weight_name_map = pickle .load (f )
266
+ return serialized_engine , input_names , output_names , weight_name_map
211
267
else :
212
- return None , [], []
268
+ return None , [], [], {}
0 commit comments