 import copy
+import io
 import logging
 import os
 import pickle
+import pickletools
 import shutil
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

 import torch
-from torch._inductor.codecache import FxGraphCachePickler
+from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo._settings import (
+    _SETTINGS_TO_BE_ENGINE_INVARIANT,
+    CompilationSettings,
+)

 _LOGGER: logging.Logger = logging.getLogger(__name__)

+UnpackedCacheHit = Tuple[
+    bytes,
+    List[str],
+    List[str],
+    Sequence[Input],
+    CompilationSettings,
+    Optional[Dict[str, Any]],
+]
+

 class BaseEngineCache(ABC):
@@ -24,7 +40,11 @@ def __init__(
         pass

     @staticmethod
-    def get_hash(gm: torch.fx.GraphModule) -> str:
+    def get_hash(
+        gm: torch.fx.GraphModule,
+        input_specs: Sequence[Input],
+        settings: CompilationSettings,
+    ) -> str:
         """Get the hash value of the GraphModule

         Args:
@@ -39,7 +59,23 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
         for name, param in new_gm.named_parameters():
             param.data.zero_()

-        hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+        graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+
+        input_spec_strs = [str(i) for i in input_specs]
+        with io.BytesIO() as stream:
+            input_specs_data = pickle.dumps(input_spec_strs)
+            input_specs_data = pickletools.optimize(input_specs_data)
+            input_specs_hash = sha256_hash(input_specs_data)
+
+        invariant_engine_specs = [
+            str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
+        ]
+        with io.BytesIO() as stream:
+            engine_specs_data = pickle.dumps(invariant_engine_specs)
+            engine_specs_data = pickletools.optimize(engine_specs_data)
+            engine_specs_hash = sha256_hash(engine_specs_data)
+
+        hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash

         return hash_val

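Note: with this hunk the cache key covers the input specs and the engine-invariant settings, not just the graph. A minimal usage sketch (illustration only; the toy module and specs below are made up, the other names come from this file's imports):

    import torch
    import torch_tensorrt
    from torch_tensorrt.dynamo._settings import CompilationSettings

    class MatMul(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.matmul(x, y)

    gm = torch.fx.symbolic_trace(MatMul())
    input_specs = [torch_tensorrt.Input(shape=(2, 3)), torch_tensorrt.Input(shape=(3, 2))]
    settings = CompilationSettings()

    # The key concatenates three digests: the weight-zeroed graph, the
    # stringified input specs, and the engine-invariant settings fields,
    # so changing any one of the three selects a different cache entry.
    key = BaseEngineCache.get_hash(gm, input_specs, settings)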
@@ -48,6 +84,8 @@ def pack(
         serialized_engine: bytes,
         input_names: List[str],
         output_names: List[str],
+        input_specs: Sequence[Input],
+        compilation_settings: CompilationSettings,
         weight_name_map: Optional[Dict[Any, Any]],
     ) -> bytes:
         """Pack serialized engine, input names, output names, and weight map into a single blob
@@ -56,40 +94,83 @@ def pack(
             serialized_engine (bytes): serialized TRT engine
             input_names (List[str]): input names of TRT engine
             output_names (List[str]): output names of TRT engine
+            input_specs (Sequence[Input]): input specs of TRT engine
+            compilation_settings (CompilationSettings): compilation settings of TRT engine
             weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting

         Returns:
             bytes: packed blob
         """
+
+        settings = copy.deepcopy(compilation_settings)
         return pickle.dumps(
             {
                 "serialized_engine": bytes(serialized_engine),
                 "input_names": input_names,
                 "output_names": output_names,
+                "input_specs": input_specs,
+                "compilation_settings": settings,
                 "weight_name_map": weight_name_map,
             }
         )

     @staticmethod
-    def unpack(
-        packed_obj: bytes,
-    ) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
+    def unpack(packed_obj: bytes) -> UnpackedCacheHit:
         """Unpack packed blob into serialized engine, input names, output names, and weight map

         Args:
             packed_obj (bytes): packed blob

         Returns:
-            Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
+            Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, compilation settings, weight name map
         """
         unpacked = pickle.loads(packed_obj)
         return (
             unpacked["serialized_engine"],
             unpacked["input_names"],
             unpacked["output_names"],
+            unpacked["input_specs"],
+            unpacked["compilation_settings"],
             unpacked["weight_name_map"],
         )

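Since pack and unpack are inverse static methods, a quick round trip shows the new fields travelling with the blob (sketch only; the engine bytes are a placeholder, and input_specs/settings are from the sketch above):

    blob = BaseEngineCache.pack(
        serialized_engine=b"<engine bytes>",  # placeholder, not a real TRT engine
        input_names=["x", "y"],
        output_names=["output0"],
        input_specs=input_specs,
        compilation_settings=settings,        # deep-copied inside pack
        weight_name_map=None,
    )
    engine, in_names, out_names, specs, cached_settings, wmap = BaseEngineCache.unpack(blob)
    assert engine == b"<engine bytes>" and in_names == ["x", "y"]

The deepcopy in pack decouples the cached settings from the caller's live object, so later mutations of the original settings cannot leak into the stored blob.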
+    def insert(
+        self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
+    ) -> None:
+        """
+        Insert a cache entry into the engine cache.
+
+        Args:
+            hash (str): The hash value of the GraphModule.
+            entry (UnpackedCacheHit): The cache entry to be inserted.
+            *args: Variable length argument list passed to ``save``.
+            **kwargs: Arbitrary keyword arguments passed to ``save``.
+
+        Returns:
+            None
+        """
+        packed_cache_info = BaseEngineCache.pack(*entry)
+        return self.save(hash, packed_cache_info, *args, **kwargs)
+
+    def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
+        """
+        Check if a cache entry exists for the given hash.
+
+        Args:
+            hash (str): The hash value of the GraphModule.
+            *args: Variable length argument list passed to ``load``.
+            **kwargs: Arbitrary keyword arguments passed to ``load``.
+
+        Returns:
+            Optional[UnpackedCacheHit]: The unpacked cache entry if found, None otherwise.
+        """
+        packed_cache_info = self.load(hash, *args, **kwargs)
+
+        if packed_cache_info:
+            return BaseEngineCache.unpack(packed_cache_info)
+        else:
+            return None
+
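insert and check are the entry points callers actually use: insert is pack followed by save, check is load followed by unpack. A sketch against the DiskEngineCache shown further down; its constructor arguments (a cache directory and a byte budget) are assumed from context, not confirmed by this diff:

    cache = DiskEngineCache("/tmp/torch_trt_engine_cache", 1 << 30)  # assumed ctor

    entry = (b"<engine bytes>", ["x", "y"], ["output0"], input_specs, settings, None)
    cache.insert(key, entry)  # pack(*entry), then save(hash, blob)

    hit = cache.check(key)    # load(hash), then unpack(blob); None on a miss
    if hit is not None:
        engine_bytes, in_names, out_names, specs, cached_settings, wmap = hit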
     @abstractmethod
     def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         """Store blob in cache
@@ -203,11 +284,7 @@ def LRU() -> None:
         else:
             LRU()

-    def save(
-        self,
-        hash: str,
-        blob: bytes,
-    ) -> None:
+    def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         blob_size = len(blob)
         if blob_size > self.total_engine_cache_size:
             _LOGGER.warning(
@@ -244,7 +321,7 @@ def save(
                     f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}."
                 )

-    def load(self, hash: str) -> Optional[bytes]:
+    def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
         directory = os.path.join(self.engine_cache_dir, hash)
         if os.path.exists(directory):
             blob_path = os.path.join(directory, "blob.bin")
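Because save and load are the only abstract methods, alternative backends stay small: hashing, pack/unpack, insert, and check all come from the base class. A minimal in-memory sketch, not part of this PR:

    class RAMEngineCache(BaseEngineCache):
        """Illustrative in-memory backend; hypothetical, not in this diff."""

        def __init__(self) -> None:
            self.store: Dict[str, bytes] = {}

        def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
            self.store[hash] = blob

        def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
            return self.store.get(hash)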