Skip to content

Commit

Permalink
fix: distinguish engines based on compilation settings in addition to …
Browse files Browse the repository at this point in the history
…graph structure

Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Sep 11, 2024
1 parent 8154408 commit e2ca04c
Show file tree
Hide file tree
Showing 18 changed files with 465 additions and 120 deletions.
2 changes: 1 addition & 1 deletion examples/dynamo/engine_caching_bert_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def compile_bert(iterations=3):
"truncate_double": True,
"debug": False,
"min_block_size": 1,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
Expand Down
8 changes: 4 additions & 4 deletions examples/dynamo/engine_caching_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# in a subsequent compilation, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.


def torch_compile(iterations=3):
Expand Down Expand Up @@ -97,7 +97,7 @@ def torch_compile(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
Expand Down Expand Up @@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
make_refittable=True,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
Expand Down Expand Up @@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache,
Expand Down
4 changes: 2 additions & 2 deletions examples/dynamo/mutable_torchtrt_module_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"make_refitable": True,
"make_refittable": True,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -80,7 +80,7 @@
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"make_refitable": True,
"make_refittable": True,
}

model_id = "runwayml/stable-diffusion-v1-5"
Expand Down
6 changes: 3 additions & 3 deletions examples/dynamo/refit_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@


# %%
# Make a Refitable Compilation Program
# Make a Refittable Compilation Program
# ---------------------------------------
#
# The initial step is to compile a module and save it as you normally would. Note that there is an
# additional parameter `make_refitable` that is set to `True`. This parameter is used to
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# these settings will not be able to be refit.
#
Expand All @@ -69,7 +69,7 @@
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
make_refittable=True,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
Expand Down
22 changes: 11 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def compile(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refitable: bool = _defaults.MAKE_REFITABLE,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
Expand Down Expand Up @@ -180,14 +180,14 @@ def compile(

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
if make_refitable:
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
if make_refittable:
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
else:
make_refitable = kwargs["refit"]
make_refittable = kwargs["refit"]

engine_capability = EngineCapability._from(engine_capability)

Expand Down Expand Up @@ -238,8 +238,8 @@ def compile(
engine_cache = None
if cache_built_engines or reuse_cached_engines:
assert (
make_refitable
), "Engine caching requires make_refitable to be set to True"
make_refittable
), "Engine caching requires make_refittable to be set to True"
engine_cache = (
custom_engine_cache
if custom_engine_cache is not None
Expand Down Expand Up @@ -270,7 +270,7 @@ def compile(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refitable": make_refitable,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
Expand Down Expand Up @@ -513,7 +513,7 @@ def convert_exported_program_to_serialized_trt_engine(
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
disable_tf32: bool = _defaults.DISABLE_TF32,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
make_refitable: bool = _defaults.MAKE_REFITABLE,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
Expand Down Expand Up @@ -600,7 +600,7 @@ def convert_exported_program_to_serialized_trt_engine(
)
if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -646,7 +646,7 @@ def convert_exported_program_to_serialized_trt_engine(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refitable": make_refitable,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"num_avg_timing_iters": num_avg_timing_iters,
"dla_sram_size": dla_sram_size,
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
MAKE_REFITABLE = False
MAKE_REFITTABLE = False
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
Expand Down
99 changes: 92 additions & 7 deletions py/torch_tensorrt/dynamo/_engine_cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import copy
import io
import logging
import os
import pickle
import pickletools
import shutil
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, cast

import torch
from torch._inductor.codecache import FxGraphCachePickler
from sympy.polys.matrices.dense import Sequence
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._settings import (
_SETTINGS_TO_BE_ENGINE_INVARIANT,
CompilationSettings,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)

UnpackedCacheHit = Tuple[
bytes,
List[str],
List[str],
Sequence[Input],
CompilationSettings,
Optional[Dict[str, Any]],
]


class BaseEngineCache(ABC):

Expand All @@ -24,7 +41,11 @@ def __init__(
pass

@staticmethod
def get_hash(gm: torch.fx.GraphModule) -> str:
def get_hash(
gm: torch.fx.GraphModule,
input_specs: Sequence[Input],
settings: CompilationSettings,
) -> str:
"""Get the hash value of the GraphModule
Args:
Expand All @@ -39,7 +60,24 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
for name, param in new_gm.named_parameters():
param.data.zero_()

hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))

input_spec_strs = [str(i) for i in input_specs]
with io.BytesIO() as stream:
input_specs_data = pickle.dumps(input_spec_strs)
input_specs_data = pickletools.optimize(input_specs_data)
input_specs_hash = sha256_hash(input_specs_data)

invariant_engine_specs = [
str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
]
with io.BytesIO() as stream:
engine_specs_data = pickle.dumps(invariant_engine_specs)
engine_specs_data = pickletools.optimize(engine_specs_data)
engine_specs_hash = sha256_hash(engine_specs_data)

# TODO(Evan): Plain concatenation is only a first-pass way of combining the three hashes; please iterate on this
hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash

return hash_val

Expand All @@ -48,6 +86,8 @@ def pack(
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
input_specs: Tuple[Input],
compilation_settings: CompilationSettings,
weight_name_map: Optional[Dict[Any, Any]],
) -> bytes:
"""Pack serialized engine, input names, output names, input specs, compilation settings, and weight map into a single blob
Expand All @@ -61,35 +101,80 @@ def pack(
Returns:
bytes: packed blob
"""

settings = copy.deepcopy(compilation_settings)
settings.torch_executed_ops = {
f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops
}

return pickle.dumps(
{
"serialized_engine": bytes(serialized_engine),
"input_names": input_names,
"output_names": output_names,
"input_specs": input_specs,
"compilation_settings": settings,
"weight_name_map": weight_name_map,
}
)

@staticmethod
def unpack(
packed_obj: bytes,
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
def unpack(packed_obj: bytes) -> UnpackedCacheHit:
"""Unpack packed blob into serialized engine, input names, output names, input specs, compilation settings, and weight map
Args:
packed_obj (bytes): packed blob
Returns:
Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map
"""
unpacked = pickle.loads(packed_obj)
return (
unpacked["serialized_engine"],
unpacked["input_names"],
unpacked["output_names"],
unpacked["input_specs"],
unpacked["compilation_settings"],
unpacked["weight_name_map"],
)

def insert(
    self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
) -> None:
    """Pack a cache entry into a single blob and store it under ``hash``.

    Args:
        hash (str): Key under which the entry is stored.
        entry (UnpackedCacheHit): Fields of the cache entry; they are
            forwarded positionally to ``BaseEngineCache.pack``, so their
            order must match that function's parameters.
        *args: Extra positional arguments forwarded to ``save``.
        **kwargs: Extra keyword arguments forwarded to ``save``.

    Returns:
        None
    """
    # Serialization is delegated to ``pack``; persistence is left to the
    # concrete ``save`` implementation.
    return self.save(hash, BaseEngineCache.pack(*entry), *args, **kwargs)

def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
    """Look up ``hash`` in the cache and return the unpacked entry, if any.

    Args:
        hash (str): Key to look up.
        *args: Extra positional arguments forwarded to ``load``.
        **kwargs: Extra keyword arguments forwarded to ``load``.

    Returns:
        Optional[UnpackedCacheHit]: The result of ``BaseEngineCache.unpack``
        on the loaded blob, or None when ``load`` returns a falsy value.
    """
    blob = self.load(hash, *args, **kwargs)

    # Any falsy result from ``load`` (e.g. None or empty bytes) is a miss.
    if not blob:
        return None

    return BaseEngineCache.unpack(blob)

@abstractmethod
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
"""Store blob in cache
Expand Down
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/_refit.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def refit_module_weights(
)

# Get the settings and check the setting to be uniform
settings: CompilationSettings = None
settings: Optional[CompilationSettings] = None
if inline_module:

# Obtain the settings
Expand All @@ -254,7 +254,7 @@ def refit_module_weights(
]
assert (
encoded_metadata != ""
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refittable=True"
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
# Handle torch modules
compiled_submodules_map = dict(compiled_submodules)
Expand All @@ -269,8 +269,10 @@ def refit_module_weights(
continue
settings = submodule.settings

assert settings is not None

assert (
settings.make_refitable
settings.make_refittable
), "Refitting is not enabled. Please recompile the engine with refit=True."

if settings.debug:
Expand Down Expand Up @@ -396,7 +398,7 @@ def refit_module_weights(
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
engine = compiled_submodule.engine
elif isinstance(compiled_submodule, TorchTensorRTModule):
engine_info = compiled_submodule.engine.__getstate__()[0]
engine_info = compiled_submodule.engine.__getstate__()[0] # type: ignore[index]
engine = get_engine_from_encoded_engine(
engine_info[ENGINE_IDX], runtime
)
Expand Down
Loading

0 comments on commit e2ca04c

Please sign in to comment.