Skip to content

Commit 2be2e64

Browse files
authored
fix: distingush engines based on compilation settings in addition to … (#3155)
Signed-off-by: Naren Dasan <[email protected]>
1 parent 4d2a04a commit 2be2e64

20 files changed

+519
-131
lines changed

examples/dynamo/engine_caching_bert_example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def compile_bert(iterations=3):
5252
"truncate_double": True,
5353
"debug": False,
5454
"min_block_size": 1,
55-
"make_refitable": True,
55+
"make_refittable": True,
5656
"cache_built_engines": cache_built_engines,
5757
"reuse_cached_engines": reuse_cached_engines,
5858
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",

examples/dynamo/engine_caching_example.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
6363
# in a subsequent compilation, either as part of this session or a new session, the cache will
6464
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
6565
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
66-
# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
66+
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
6767

6868

6969
def torch_compile(iterations=3):
@@ -97,7 +97,7 @@ def torch_compile(iterations=3):
9797
"enabled_precisions": enabled_precisions,
9898
"debug": debug,
9999
"min_block_size": min_block_size,
100-
"make_refitable": True,
100+
"make_refittable": True,
101101
"cache_built_engines": cache_built_engines,
102102
"reuse_cached_engines": reuse_cached_engines,
103103
},
@@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
157157
enabled_precisions=enabled_precisions,
158158
debug=debug,
159159
min_block_size=min_block_size,
160-
make_refitable=True,
160+
make_refittable=True,
161161
cache_built_engines=cache_built_engines,
162162
reuse_cached_engines=reuse_cached_engines,
163163
engine_cache_size=1 << 30, # 1GB
@@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
268268
"enabled_precisions": enabled_precisions,
269269
"debug": debug,
270270
"min_block_size": min_block_size,
271-
"make_refitable": True,
271+
"make_refittable": True,
272272
"cache_built_engines": cache_built_engines,
273273
"reuse_cached_engines": reuse_cached_engines,
274274
"custom_engine_cache": engine_cache,

examples/dynamo/mutable_torchtrt_module_example.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
settings = {
3232
"use_python": False,
3333
"enabled_precisions": {torch.float32},
34-
"make_refitable": True,
34+
"make_refittable": True,
3535
}
3636

3737
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -80,7 +80,7 @@
8080
"use_python_runtime": True,
8181
"enabled_precisions": {torch.float16},
8282
"debug": True,
83-
"make_refitable": True,
83+
"make_refittable": True,
8484
}
8585

8686
model_id = "runwayml/stable-diffusion-v1-5"

examples/dynamo/refit_engine_example.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343

4444

4545
# %%
46-
# Make a Refitable Compilation Program
46+
# Make a refittable Compilation Program
4747
# ---------------------------------------
4848
#
4949
# The inital step is to compile a module and save it as with a normal. Note that there is an
50-
# additional parameter `make_refitable` that is set to `True`. This parameter is used to
50+
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
5151
# indicate that the engine being built should support weight refitting later. Engines built without
5252
# these setttings will not be able to be refit.
5353
#
@@ -69,7 +69,7 @@
6969
debug=debug,
7070
min_block_size=min_block_size,
7171
torch_executed_ops=torch_executed_ops,
72-
make_refitable=True,
72+
make_refittable=True,
7373
) # Output is a torch.fx.GraphModule
7474

7575
# Save the graph module as an exported program

py/torch_tensorrt/_Input.py

+28
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,34 @@ def __str__(self) -> str:
220220
def __repr__(self) -> str:
221221
return self.__str__()
222222

223+
@staticmethod
224+
def equivalent_spec(a: Input, b: Input) -> bool:
225+
if a.shape_mode != b.shape_mode:
226+
return False
227+
228+
if a.shape_mode == Input._ShapeMode.DYNAMIC:
229+
assert isinstance(a.shape, dict)
230+
assert isinstance(b.shape, dict)
231+
checks = [
232+
a.shape["min_shape"] == b.shape["min_shape"],
233+
a.shape["opt_shape"] == b.shape["opt_shape"],
234+
a.shape["max_shape"] == b.shape["max_shape"],
235+
a.dtype == b.dtype,
236+
a.format == b.format,
237+
a.low_tensor_domain_incl == b.low_tensor_domain_incl,
238+
a.high_tensor_domain_excl == b.high_tensor_domain_excl,
239+
]
240+
return all(checks)
241+
else:
242+
checks = [
243+
a.shape == b.shape,
244+
a.dtype == b.dtype,
245+
a.format == b.format,
246+
a.low_tensor_domain_incl == b.low_tensor_domain_incl,
247+
a.high_tensor_domain_excl == b.high_tensor_domain_excl,
248+
]
249+
return all(checks)
250+
223251
@staticmethod
224252
def _supported_input_size_type(input_size: Any) -> bool:
225253
if isinstance(input_size, torch.Size):

py/torch_tensorrt/dynamo/_compiler.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def compile(
6060
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
6161
] = _defaults.ENABLED_PRECISIONS,
6262
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
63-
make_refitable: bool = _defaults.MAKE_REFITABLE,
63+
make_refittable: bool = _defaults.MAKE_REFITTABLE,
6464
debug: bool = _defaults.DEBUG,
6565
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
6666
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -180,14 +180,14 @@ def compile(
180180

181181
if "refit" in kwargs.keys():
182182
warnings.warn(
183-
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
183+
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
184184
DeprecationWarning,
185185
stacklevel=2,
186186
)
187-
if make_refitable:
188-
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
187+
if make_refittable:
188+
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
189189
else:
190-
make_refitable = kwargs["refit"]
190+
make_refittable = kwargs["refit"]
191191

192192
engine_capability = EngineCapability._from(engine_capability)
193193

@@ -238,8 +238,8 @@ def compile(
238238
engine_cache = None
239239
if cache_built_engines or reuse_cached_engines:
240240
assert (
241-
make_refitable
242-
), "Engine caching requires make_refitable to be set to True"
241+
make_refittable
242+
), "Engine caching requires make_refittable to be set to True"
243243
engine_cache = (
244244
custom_engine_cache
245245
if custom_engine_cache is not None
@@ -270,7 +270,7 @@ def compile(
270270
"require_full_compilation": require_full_compilation,
271271
"disable_tf32": disable_tf32,
272272
"sparse_weights": sparse_weights,
273-
"make_refitable": make_refitable,
273+
"make_refittable": make_refittable,
274274
"engine_capability": engine_capability,
275275
"dla_sram_size": dla_sram_size,
276276
"dla_local_dram_size": dla_local_dram_size,
@@ -513,7 +513,7 @@ def convert_exported_program_to_serialized_trt_engine(
513513
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
514514
disable_tf32: bool = _defaults.DISABLE_TF32,
515515
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
516-
make_refitable: bool = _defaults.MAKE_REFITABLE,
516+
make_refittable: bool = _defaults.MAKE_REFITTABLE,
517517
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
518518
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
519519
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -600,7 +600,7 @@ def convert_exported_program_to_serialized_trt_engine(
600600
)
601601
if "refit" in kwargs.keys():
602602
warnings.warn(
603-
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
603+
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
604604
DeprecationWarning,
605605
stacklevel=2,
606606
)
@@ -646,7 +646,7 @@ def convert_exported_program_to_serialized_trt_engine(
646646
"require_full_compilation": require_full_compilation,
647647
"disable_tf32": disable_tf32,
648648
"sparse_weights": sparse_weights,
649-
"make_refitable": make_refitable,
649+
"make_refittable": make_refittable,
650650
"engine_capability": engine_capability,
651651
"num_avg_timing_iters": num_avg_timing_iters,
652652
"dla_sram_size": dla_sram_size,

py/torch_tensorrt/dynamo/_defaults.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
USE_PYTHON_RUNTIME = False
2727
USE_FAST_PARTITIONER = True
2828
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
29-
MAKE_REFITABLE = False
29+
MAKE_REFITTABLE = False
3030
REQUIRE_FULL_COMPILATION = False
3131
DRYRUN = False
3232
HARDWARE_COMPATIBLE = False

py/torch_tensorrt/dynamo/_engine_cache.py

+91-14
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,33 @@
11
import copy
2+
import io
23
import logging
34
import os
45
import pickle
6+
import pickletools
57
import shutil
68
from abc import ABC, abstractmethod
7-
from typing import Any, Dict, List, Optional, Tuple, cast
9+
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
810

911
import torch
10-
from torch._inductor.codecache import FxGraphCachePickler
12+
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
1113
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
14+
from torch_tensorrt._Input import Input
15+
from torch_tensorrt.dynamo._settings import (
16+
_SETTINGS_TO_BE_ENGINE_INVARIANT,
17+
CompilationSettings,
18+
)
1219

1320
_LOGGER: logging.Logger = logging.getLogger(__name__)
1421

22+
UnpackedCacheHit = Tuple[
23+
bytes,
24+
List[str],
25+
List[str],
26+
Sequence[Input],
27+
CompilationSettings,
28+
Optional[Dict[str, Any]],
29+
]
30+
1531

1632
class BaseEngineCache(ABC):
1733

@@ -24,7 +40,11 @@ def __init__(
2440
pass
2541

2642
@staticmethod
27-
def get_hash(gm: torch.fx.GraphModule) -> str:
43+
def get_hash(
44+
gm: torch.fx.GraphModule,
45+
input_specs: Sequence[Input],
46+
settings: CompilationSettings,
47+
) -> str:
2848
"""Get the hash value of the GraphModule
2949
3050
Args:
@@ -39,7 +59,23 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
3959
for name, param in new_gm.named_parameters():
4060
param.data.zero_()
4161

42-
hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
62+
graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
63+
64+
input_spec_strs = [str(i) for i in input_specs]
65+
with io.BytesIO() as stream:
66+
input_specs_data = pickle.dumps(input_spec_strs)
67+
input_specs_data = pickletools.optimize(input_specs_data)
68+
input_specs_hash = sha256_hash(input_specs_data)
69+
70+
invariant_engine_specs = [
71+
str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
72+
]
73+
with io.BytesIO() as stream:
74+
engine_specs_data = pickle.dumps(invariant_engine_specs)
75+
engine_specs_data = pickletools.optimize(engine_specs_data)
76+
engine_specs_hash = sha256_hash(engine_specs_data)
77+
78+
hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash
4379

4480
return hash_val
4581

@@ -48,6 +84,8 @@ def pack(
4884
serialized_engine: bytes,
4985
input_names: List[str],
5086
output_names: List[str],
87+
input_specs: Sequence[Input],
88+
compilation_settings: CompilationSettings,
5189
weight_name_map: Optional[Dict[Any, Any]],
5290
) -> bytes:
5391
"""Pack serialized engine, input names, output names, and weight map into a single blob
@@ -56,40 +94,83 @@ def pack(
5694
serialized_engine (bytes): serialized TRT engine
5795
input_names (List[str]): input names of TRT engine
5896
output_names (List[str]): output names of TRT engine
97+
input_specs (Sequence[Input]): input specs of TRT engine
98+
compilation_settings (CompilationSettings): compilation settings of TRT engine
5999
weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
60100
61101
Returns:
62102
bytes: packed blob
63103
"""
104+
105+
settings = copy.deepcopy(compilation_settings)
64106
return pickle.dumps(
65107
{
66108
"serialized_engine": bytes(serialized_engine),
67109
"input_names": input_names,
68110
"output_names": output_names,
111+
"input_specs": input_specs,
112+
"compilation_settings": settings,
69113
"weight_name_map": weight_name_map,
70114
}
71115
)
72116

73117
@staticmethod
74-
def unpack(
75-
packed_obj: bytes,
76-
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
118+
def unpack(packed_obj: bytes) -> UnpackedCacheHit:
77119
"""Unpack packed blob into serialized engine, input names, output names, and weight map
78120
79121
Args:
80122
packed_obj (bytes): packed blob
81123
82124
Returns:
83-
Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
125+
Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map
84126
"""
85127
unpacked = pickle.loads(packed_obj)
86128
return (
87129
unpacked["serialized_engine"],
88130
unpacked["input_names"],
89131
unpacked["output_names"],
132+
unpacked["input_specs"],
133+
unpacked["compilation_settings"],
90134
unpacked["weight_name_map"],
91135
)
92136

137+
def insert(
138+
self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
139+
) -> None:
140+
"""
141+
Insert a cache entry into the engine cache.
142+
143+
Args:
144+
hash (str): The hash value of the GraphModule.
145+
entry (Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]): The cache entry to be inserted.
146+
*args: Variable length argument list passed to ``save``.
147+
**kwargs: Arbitrary keyword arguments passed to ``save``.
148+
149+
Returns:
150+
None
151+
"""
152+
packed_cache_info = BaseEngineCache.pack(*entry)
153+
return self.save(hash, packed_cache_info, *args, **kwargs)
154+
155+
def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
156+
"""
157+
Check if a cache entry exists for the given hash.
158+
159+
Args:
160+
hash (str): The hash value of the GraphModule.
161+
*args: Variable length argument list passed to ``load``.
162+
**kwargs: Arbitrary keyword arguments passed to ``load``.
163+
164+
Returns:
165+
Optional[Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
166+
"""
167+
packed_cache_info = self.load(hash, *args, **kwargs)
168+
169+
if packed_cache_info:
170+
return BaseEngineCache.unpack(packed_cache_info)
171+
else:
172+
return None
173+
93174
@abstractmethod
94175
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
95176
"""Store blob in cache
@@ -203,11 +284,7 @@ def LRU() -> None:
203284
else:
204285
LRU()
205286

206-
def save(
207-
self,
208-
hash: str,
209-
blob: bytes,
210-
) -> None:
287+
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
211288
blob_size = len(blob)
212289
if blob_size > self.total_engine_cache_size:
213290
_LOGGER.warning(
@@ -244,7 +321,7 @@ def save(
244321
f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}."
245322
)
246323

247-
def load(self, hash: str) -> Optional[bytes]:
324+
def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
248325
directory = os.path.join(self.engine_cache_dir, hash)
249326
if os.path.exists(directory):
250327
blob_path = os.path.join(directory, "blob.bin")

0 commit comments

Comments
 (0)