Skip to content

Commit

Permalink
fix: distingush engines based on compilation settings in addition to …
Browse files Browse the repository at this point in the history
…graph structure

Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Sep 11, 2024
1 parent 2518db5 commit 53c42b0
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 16 deletions.
13 changes: 5 additions & 8 deletions py/torch_tensorrt/dynamo/_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def get_hash(
engine_specs_data = pickletools.optimize(engine_specs_data)
engine_specs_hash = sha256_hash(engine_specs_data)

# TODO: Super first idea I had hash combination solution @Evan please iterate on this
hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash

return hash_val
Expand All @@ -95,6 +94,8 @@ def pack(
serialized_engine (bytes): serialized TRT engine
input_names (List[str]): input names of TRT engine
output_names (List[str]): output names of TRT engine
input_specs (Sequence[Input]): input specs of TRT engine
compilation_settings (CompilationSettings): compilation settings of TRT engine
weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
Returns:
Expand All @@ -121,7 +122,7 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit:
packed_obj (bytes): packed blob
Returns:
Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, CompilationSettings, weight name map
Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map
"""
unpacked = pickle.loads(packed_obj)
return (
Expand Down Expand Up @@ -283,11 +284,7 @@ def LRU() -> None:
else:
LRU()

def save(
self,
hash: str,
blob: bytes,
) -> None:
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
blob_size = len(blob)
if blob_size > self.total_engine_cache_size:
_LOGGER.warning(
Expand Down Expand Up @@ -324,7 +321,7 @@ def save(
f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}."
)

def load(self, hash: str) -> Optional[bytes]:
def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
directory = os.path.join(self.engine_cache_dir, hash)
if os.path.exists(directory):
blob_path = os.path.join(directory, "blob.bin")
Expand Down
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def run(
)
assert (
setting_compatiblity
), f"Attempted to refit a prebuilt engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"

for i, e in enumerate(
[
Expand All @@ -567,7 +567,7 @@ def run(
):
assert (
e
), f"Found that cached engine was built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"

_LOGGER.info(
"Found the cached engine that corresponds to this graph. It is directly loaded."
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def parse_dynamo_kwargs(

# If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided,
# then create a default disk engine cache
#

engine_cache = None
if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
assert kwargs.get(
Expand Down
22 changes: 17 additions & 5 deletions tests/py/dynamo/models/test_engine_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
Expand Down Expand Up @@ -160,9 +160,9 @@ def test_engine_settings_is_not_equal(self):
)
input_specs2 = (
torch_trt.Input(
min_shape=(1, 3, 300, 300),
opt_shape=(100, 3, 300, 300),
max_shape=(200, 3, 300, 300),
min_shape=(1, 3, 224, 224),
opt_shape=(100, 3, 224, 224),
max_shape=(200, 3, 224, 224),
),
)
settings2 = CompilationSettings(
Expand Down Expand Up @@ -192,6 +192,10 @@ def test_dynamo_compile_with_default_disk_engine_cache(self):
if os.path.exists(engine_cache_dir):
shutil.rmtree(engine_cache_dir)

def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)

# The 1st iteration is to measure the compilation time without engine caching
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
Expand All @@ -202,6 +206,8 @@ def test_dynamo_compile_with_default_disk_engine_cache(self):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for i in range(3):
remove_timing_cache()
torch._dynamo.reset()
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
Expand Down Expand Up @@ -351,6 +357,10 @@ def test_torch_compile_with_default_disk_engine_cache(self):
if os.path.exists(engine_cache_dir):
shutil.rmtree(engine_cache_dir)

def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)

# The 1st iteration is to measure the compilation time without engine caching
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
Expand All @@ -361,7 +371,9 @@ def test_torch_compile_with_default_disk_engine_cache(self):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for i in range(3):
# remove timing cache and reset dynamo for engine caching messurement
# remove timing cache and reset dynamo for engine caching measurement
remove_timing_cache()
torch._dynamo.reset()
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
Expand Down

0 comments on commit 53c42b0

Please sign in to comment.