Skip to content

Commit 59ba4a2

Browse files
committed
support saving weight name map
1 parent ae5ae51 commit 59ba4a2

File tree

4 files changed

+107
-55
lines changed

4 files changed

+107
-55
lines changed

examples/dynamo/engine_caching_bert_example.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def compile_bert(iterations=3):
4040
"use_python_runtime": False,
4141
"enabled_precisions": {torch.float},
4242
"truncate_double": True,
43-
"debug": True,
43+
"debug": False,
4444
"min_block_size": 1,
4545
"make_refitable": True,
4646
"save_engine_cache": save_engine_cache,
@@ -57,7 +57,7 @@ def compile_bert(iterations=3):
5757
torch.cuda.synchronize()
5858
times.append(start.elapsed_time(end))
5959

60-
print("-----compile bert-----> compilation time:", times, "milliseconds")
60+
print("-----compile bert-----> compilation time:\n", times, "milliseconds")
6161

6262

6363
if __name__ == "__main__":

examples/dynamo/engine_caching_example.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,11 @@ def dynamo_path(iterations=3):
7272
torch.cuda.synchronize()
7373
times.append(start.elapsed_time(end))
7474

75-
print("-----dynamo_path-----> compilation time:", times, "milliseconds")
75+
print("-----dynamo_path-----> compilation time:\n", times, "milliseconds")
7676

7777

7878
# Custom Engine Cache
7979
class MyEngineCache(BaseEngineCache):
80-
8180
def __init__(
8281
self,
8382
engine_cache_size: int,
@@ -174,7 +173,7 @@ def compile_path(iterations=3):
174173
torch.cuda.synchronize()
175174
times.append(start.elapsed_time(end))
176175

177-
print("-----compile_path-----> compilation time:", times, "milliseconds")
176+
print("-----compile_path-----> compilation time:\n", times, "milliseconds")
178177

179178

180179
if __name__ == "__main__":

py/torch_tensorrt/dynamo/_engine_caching.py

+90-34
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
import ast
21
import copy
32
import logging
43
import os
4+
import pickle
55
import shutil
6+
import sys
67
from abc import ABC, abstractmethod
78
from typing import Any, Dict, List, Optional, Tuple, cast
89

@@ -50,6 +51,7 @@ def save(
5051
serialized_engine: bytes,
5152
input_names: List[str],
5253
output_names: List[str],
54+
weight_name_map: Optional[Dict[str, Any]] = None,
5355
) -> bool:
5456
"""Save the serialized engine to hard disk
5557
@@ -58,21 +60,24 @@ def save(
5860
serialized_engine (bytes): serialized TRT engine
5961
input_names (List[str]): input names of TRT engine
6062
output_names (List[str]): output names of TRT engine
63+
weight_name_map (Optional[Dict[str, Any]]): weight name map for refitting
6164
6265
Returns:
6366
bool: whether the serialized engine is saved successfully
6467
"""
6568
pass
6669

6770
@abstractmethod
68-
def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
71+
def load(
72+
self, hash: str
73+
) -> Tuple[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]]:
6974
"""Load the serialized engine from hard disk
7075
7176
Args:
7277
hash (str): hash value of the GraphModule
7378
7479
Returns:
75-
Sequence[Optional[bytes], List[str], List[str]]: serialized TRT engine, input names of TRT Engine, output names of TRT Engine
80+
Sequence[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
7681
"""
7782
pass
7883

@@ -89,16 +94,16 @@ def __init__(
8994
self.engine_cache_dir = engine_cache_dir
9095
self.hash2size_map: Dict[str, int] = {}
9196

92-
def has_available_cache_size(self, serialized_engine: bytes) -> bool:
97+
def has_available_cache_size(self, needed_size: int) -> bool:
9398
"""Check if the cache has available space for saving the serialized engine
9499
95100
Args:
96-
serialized_engine (bytes): serialized TRT engine
101+
needed_size (int): needed size for serialized TRT engine and/or weight_name_map
97102
98103
Returns:
99104
bool: whether the cache has available size for the serialized engine
100105
"""
101-
return int(serialized_engine.nbytes) <= self.available_engine_cache_size
106+
return needed_size <= self.available_engine_cache_size
102107

103108
def clear_cache(self, needed_min_size: int) -> bool:
104109
"""Clear the cache to make sure at least `needed_min_size` bytes are available, if possible
@@ -154,36 +159,75 @@ def save(
154159
serialized_engine: bytes,
155160
input_names: List[str],
156161
output_names: List[str],
162+
weight_name_map: Optional[Dict[str, Any]] = None,
157163
) -> bool:
158164
serialized_engine_size = int(serialized_engine.nbytes)
165+
if weight_name_map is not None:
166+
serialized_engine_size += sum(
167+
sys.getsizeof(v) for v in weight_name_map.values()
168+
)
159169
if serialized_engine_size > self.total_engine_cache_size:
160170
_LOGGER.warning(
161171
f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}."
162172
)
163173
return False
164174

165-
# Check if there is enough available cache size for the serialized engine
166-
if not self.has_available_cache_size(serialized_engine):
175+
# Check if there is enough available cache size for the serialized engine and/or weight_name_map
176+
if not self.has_available_cache_size(serialized_engine_size):
167177
self.clear_cache(serialized_engine_size)
168178

169179
# Save the serialized engine to the cache directory
170-
if self.has_available_cache_size(serialized_engine):
171-
path = os.path.join(
172-
self.engine_cache_dir,
173-
f"{hash}/engine--{input_names}--{output_names}.trt",
180+
if self.has_available_cache_size(serialized_engine_size):
181+
self.hash2size_map[hash] = serialized_engine_size
182+
self.available_engine_cache_size -= serialized_engine_size
183+
directory = os.path.join(self.engine_cache_dir, hash)
184+
185+
engine_path = os.path.join(
186+
directory,
187+
"engine.trt",
188+
)
189+
io_names_path = os.path.join(
190+
directory,
191+
"io_names.pkl",
174192
)
175193
try:
176-
os.makedirs(os.path.dirname(path), exist_ok=True)
177-
with open(path, "wb") as f:
194+
os.makedirs(os.path.dirname(engine_path), exist_ok=True)
195+
with open(engine_path, "wb") as f:
178196
f.write(serialized_engine)
179-
self.hash2size_map[hash] = serialized_engine_size
180-
self.available_engine_cache_size -= serialized_engine_size
181-
_LOGGER.info(f"A TRT engine was cached to {path}")
182-
197+
os.makedirs(os.path.dirname(io_names_path), exist_ok=True)
198+
with open(io_names_path, "wb") as f:
199+
pickle.dump(
200+
{"input_names": input_names, "output_names": output_names}, f
201+
)
202+
_LOGGER.info(f"The TRT engine was saved to {engine_path}")
183203
except Exception as e:
184-
_LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
204+
del self.hash2size_map[hash]
205+
self.available_engine_cache_size += serialized_engine_size
206+
shutil.rmtree(directory)
207+
_LOGGER.warning(f"Failed to save the TRT engine to {engine_path}: {e}")
185208
return False
186209

210+
if weight_name_map is not None:
211+
weight_name_map_path = os.path.join(
212+
directory,
213+
"weight_name_map.pkl",
214+
)
215+
try:
216+
os.makedirs(os.path.dirname(weight_name_map_path), exist_ok=True)
217+
with open(weight_name_map_path, "wb") as f:
218+
pickle.dump(weight_name_map, f)
219+
_LOGGER.info(
220+
f"The weight_name_map was saved to {weight_name_map_path}"
221+
)
222+
except Exception as e:
223+
del self.hash2size_map[hash]
224+
self.available_engine_cache_size += serialized_engine_size
225+
shutil.rmtree(directory)
226+
_LOGGER.warning(
227+
f"Failed to save the weight_name_map to {weight_name_map_path}: {e}"
228+
)
229+
return False
230+
187231
return True
188232

189233
else:
@@ -192,21 +236,33 @@ def save(
192236
)
193237
return False
194238

195-
def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
239+
def load(
240+
self, hash: str
241+
) -> Tuple[Optional[bytes], List[str], List[str], Optional[Dict[str, Any]]]:
196242
directory = os.path.join(self.engine_cache_dir, hash)
197243
if os.path.exists(directory):
198-
engine_list = os.listdir(directory)
199-
assert (
200-
len(engine_list) == 1
201-
), f"There are more than one engine {engine_list} under {directory}."
202-
path = os.path.join(directory, engine_list[0])
203-
input_names_str, output_names_str = (
204-
engine_list[0].split(".trt")[0].split("--")[1:]
205-
)
206-
input_names = ast.literal_eval(input_names_str)
207-
output_names = ast.literal_eval(output_names_str)
208-
with open(path, "rb") as f:
209-
serialized_engine = f.read()
210-
return serialized_engine, input_names, output_names
244+
# load engine
245+
serialized_engine = None
246+
engine_path = os.path.join(directory, "engine.trt")
247+
if os.path.exists(engine_path):
248+
with open(engine_path, "rb") as f:
249+
serialized_engine = f.read()
250+
251+
input_names = []
252+
output_names = []
253+
io_names_path = os.path.join(directory, "io_names.pkl")
254+
if os.path.exists(io_names_path):
255+
with open(io_names_path, "rb") as f:
256+
io_names = pickle.load(f)
257+
input_names = io_names["input_names"]
258+
output_names = io_names["output_names"]
259+
260+
# load weight_name_map
261+
weight_name_map = None
262+
weight_name_map_path = os.path.join(directory, "weight_name_map.pkl")
263+
if os.path.exists(weight_name_map_path):
264+
with open(weight_name_map_path, "rb") as f:
265+
weight_name_map = pickle.load(f)
266+
return serialized_engine, input_names, output_names, weight_name_map
211267
else:
212-
return None, [], []
268+
return None, [], [], {}

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+13-16
Original file line numberDiff line numberDiff line change
@@ -482,28 +482,21 @@ def run(
482482

483483
if self.compilation_settings.load_engine_cache:
484484
# query the cached TRT engine
485-
serialized_engine, input_names, output_names = engine_cache.load(hash_val)
485+
serialized_engine, input_names, output_names, weight_name_map = (
486+
engine_cache.load(hash_val)
487+
)
486488
if serialized_engine is not None:
487489
self._input_names = input_names
488490
self._output_names = output_names
491+
self.weight_name_map = weight_name_map
489492
_LOGGER.info(
490493
"Hit the cached TRT engine. It is loaded for skipping recompilation."
491494
)
492-
493-
# refit the engine
494-
from torch_tensorrt.dynamo._refit import (
495-
_refit_single_trt_engine_with_gm,
496-
)
497-
498-
runtime = trt.Runtime(TRT_LOGGER)
499-
engine = runtime.deserialize_cuda_engine(serialized_engine)
500-
_refit_single_trt_engine_with_gm(
501-
self.module, engine, self.input_specs, self.compilation_settings
502-
)
503-
_LOGGER.info("Refitting Succeed!")
504-
505495
return TRTInterpreterResult(
506-
serialized_engine, self._input_names, self._output_names
496+
serialized_engine,
497+
self._input_names,
498+
self._output_names,
499+
self.weight_name_map,
507500
)
508501

509502
self._construct_trt_network_def()
@@ -537,7 +530,11 @@ def run(
537530
)
538531
if self.compilation_settings.save_engine_cache:
539532
engine_cache.save(
540-
hash_val, serialized_engine, self._input_names, self._output_names
533+
hash_val,
534+
serialized_engine,
535+
self._input_names,
536+
self._output_names,
537+
self.weight_name_map,
541538
)
542539

543540
with io.BytesIO() as engine_bytes:

0 commit comments

Comments
 (0)