force using slow refit, add unit tests
zewenli98 committed Aug 28, 2024
1 parent 034c2ba commit 2562e2c
Showing 3 changed files with 213 additions and 22 deletions.
32 changes: 20 additions & 12 deletions examples/dynamo/engine_caching_example.py
@@ -10,7 +10,6 @@

np.random.seed(0)
torch.manual_seed(0)
-size = (100, 3, 224, 224)

model = models.resnet18(pretrained=True).eval().to("cuda")
enabled_precisions = {torch.float}
@@ -24,7 +23,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
os.remove(path)


-def dynamo_compile(iterations=3):
+def dynamo_compile(iterations=3):
times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
@@ -42,7 +41,7 @@ def dynamo_path(iterations=3):
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
for i in range(iterations):
inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
-remove_timing_cache() # remove timing cache for engine caching measurement
+remove_timing_cache() # remove timing cache just for engine caching measurement
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
@@ -63,11 +62,15 @@ def dynamo_path(iterations=3):
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
)
+# output = trt_gm(*inputs)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("-----dynamo_path-----> compilation time:\n", times, "milliseconds")
print("----------------dynamo_compile----------------")
print("disable engine caching, used:", times[0], "ms")
print("enable engine caching to cache engines, used:", times[1], "ms")
print("enable engine caching to reuse engines, used:", times[2], "ms")

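The hunks above elide the start.record() call, but the measurement pattern the example relies on is the standard CUDA-event timing idiom. A minimal sketch (time_ms is a hypothetical helper, not part of this commit):

import torch

def time_ms(fn):
    # CUDA events measure GPU time; synchronize before reading the
    # result so both recorded events have actually completed.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)  # milliseconds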

# Custom Engine Cache
@@ -84,11 +87,13 @@ def save(
blob: bytes,
prefix: str = "blob",
):
+if not os.path.exists(self.engine_cache_dir):
+    os.makedirs(self.engine_cache_dir, exist_ok=True)

path = os.path.join(
self.engine_cache_dir,
f"{prefix}_{hash}.bin",
)
-os.makedirs(path, exist_ok=True)
with open(path, "wb") as f:
f.write(blob)

@@ -101,7 +106,7 @@ def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
return None


-def compile_path(iterations=3):
+def torch_compile(iterations=3):
times = []
engine_cache = MyEngineCache("/tmp/your_dir")
start = torch.cuda.Event(enable_timing=True)
@@ -112,8 +117,8 @@ def compile_path(iterations=3):
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
for i in range(iterations):
-inputs = [torch.rand(size).to("cuda")]
-# remove timing cache and reset dynamo for engine caching measurement
+inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+# remove timing cache and reset dynamo just for engine caching measurement
remove_timing_cache()
torch._dynamo.reset()

@@ -129,7 +134,7 @@ def compile_path(iterations=3):
model,
backend="tensorrt",
options={
"use_python_runtime": use_python_runtime,
"use_python_runtime": True,
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
@@ -144,9 +149,12 @@ def compile_path(iterations=3):
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("-----compile_path-----> compilation time:\n", times, "milliseconds")
print("----------------torch_compile----------------")
print("disable engine caching, used:", times[0], "ms")
print("enable engine caching to cache engines, used:", times[1], "ms")
print("enable engine caching to reuse engines, used:", times[2], "ms")


if __name__ == "__main__":
-dynamo_path()
-# compile_path()
+dynamo_compile()
+torch_compile()
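After torch_compile() runs with caching enabled, the custom cache can be inspected on disk, since MyEngineCache.save names files f"{prefix}_{hash}.bin" with prefix defaulting to "blob". A small sketch, assuming the same /tmp/your_dir directory used above:

import os

cache_dir = "/tmp/your_dir"  # the directory handed to MyEngineCache above
if os.path.isdir(cache_dir):
    # expect one blob_<hash>.bin file per cached engine
    print(sorted(os.listdir(cache_dir)))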
50 changes: 40 additions & 10 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -502,20 +502,50 @@ def run(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

-from torch_tensorrt.dynamo._refit import (
-    _refit_single_trt_engine_with_gm,
-)

runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

-_refit_single_trt_engine_with_gm(
-    new_gm=self.module,
-    old_engine=engine,
-    input_list=self.input_specs,
-    settings=self.compilation_settings,
-    weight_name_map=weight_name_map,
+# TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers.
+# We use slow refit anyway. After it gets fixed, we can uncomment below and delete the slow refit.

+# from torch_tensorrt.dynamo._refit import (
+#     _refit_single_trt_engine_with_gm,
+# )
+#
+# _refit_single_trt_engine_with_gm(
+#     new_gm=self.module,
+#     old_engine=engine,
+#     input_list=self.input_specs,
+#     settings=self.compilation_settings,
+#     weight_name_map=weight_name_map,
+# )

+from torch_tensorrt.dynamo._refit import construct_refit_mapping

+refitted = set()
+refitter = trt.Refitter(engine, TRT_LOGGER)
+weight_list = refitter.get_all_weights()
+mapping = construct_refit_mapping(
+    self.module, self.input_specs, self.compilation_settings
+)
+trt_wt_location = trt.TensorLocation.HOST
+for layer_name in weight_list:
+    if layer_name not in mapping:
+        raise AssertionError(
+            f"{layer_name} is not found in weight mapping"
+        )
+    # Use Numpy to create weights
+    weight, datatype = mapping[layer_name]
+    trt_wt_tensor = trt.Weights(
+        datatype, weight.ctypes.data, weight.size
+    )
+    refitter.set_named_weights(
+        layer_name, trt_wt_tensor, trt_wt_location
+    )
+    refitted.add(layer_name)

+if len(refitted) != len(weight_list):
+    _LOGGER.warning("Not all weights have been refitted!!!")

serialized_engine = bytes(engine.serialize())
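Pulled out of the interpreter, the slow-refit path added above reads roughly as follows (a minimal sketch mirroring the diff; the closing refit_cuda_engine() call reflects the usual standalone TensorRT refit flow and is an assumption, not a line from this commit, which re-serializes the engine instead):

import tensorrt as trt
from torch_tensorrt.dynamo._refit import construct_refit_mapping

def slow_refit(engine, module, input_specs, settings, logger):
    # Rebuild the layer-name -> (numpy array, TRT dtype) mapping and
    # push every engine weight individually ("slow" refit).
    refitter = trt.Refitter(engine, logger)
    mapping = construct_refit_mapping(module, input_specs, settings)
    for name in refitter.get_all_weights():
        weight, dtype = mapping[name]
        wt = trt.Weights(dtype, weight.ctypes.data, weight.size)
        refitter.set_named_weights(name, wt, trt.TensorLocation.HOST)
    # Apply the staged weights; returns False if any weight is missing.
    assert refitter.refit_cuda_engine()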

153 changes: 153 additions & 0 deletions tests/py/dynamo/models/test_engine_cache.py
@@ -0,0 +1,153 @@
# type: ignore
import os
import shutil
import unittest
from typing import Optional

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()


class MyEngineCache(BaseEngineCache):
def __init__(
self,
engine_cache_dir: str,
) -> None:
self.engine_cache_dir = engine_cache_dir

def save(
self,
hash: str,
blob: bytes,
prefix: str = "blob",
):
if not os.path.exists(self.engine_cache_dir):
os.makedirs(self.engine_cache_dir, exist_ok=True)

path = os.path.join(
self.engine_cache_dir,
f"{prefix}_{hash}.bin",
)
with open(path, "wb") as f:
f.write(blob)

def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
if os.path.exists(path):
with open(path, "rb") as f:
blob = f.read()
return blob
return None
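Before the tests themselves, a quick sanity check of the save/load contract this class implements (a hedged sketch; the directory, hash, and payload are made up for illustration):

cache = MyEngineCache("/tmp/my_engine_cache")  # hypothetical directory
blob = b"serialized engine bytes"  # stand-in payload, not a real engine
cache.save("abc123", blob)  # writes /tmp/my_engine_cache/blob_abc123.bin
assert cache.load("abc123") == blob  # round-trips the exact bytes
assert cache.load("deadbeef") is None  # unknown hashes return None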


class TestEngineCache(TestCase):

def test_dynamo_compile(self):
model = models.resnet18(pretrained=True).eval().to("cuda")
example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
# Mark the dim0 of inputs as dynamic
batch = torch.export.Dim("batch", min=1, max=200)
exp_program = torch.export.export(
model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
)
engine_cache_dir = ENGINE_CACHE_DIR
if os.path.exists(engine_cache_dir):
shutil.rmtree(engine_cache_dir)
# The 1st iteration is to measure the compilation time without engine caching
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
results = []
for i in range(3):
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
else:
cache_built_engines = True
reuse_cached_engines = True

trt_gm = torch_trt.dynamo.compile(
exp_program,
tuple(inputs),
use_python_runtime=False,
enabled_precisions={torch.float},
debug=False,
min_block_size=1,
make_refitable=True,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
)
results.append(trt_gm(*inputs))

cos_sim = cosine_similarity(results[0], results[1])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_dynamo_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

cos_sim = cosine_similarity(results[1], results[2])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_dynamo_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

def test_torch_compile(self):
# Custom Engine Cache
model = models.resnet18(pretrained=True).eval().to("cuda")

engine_cache_dir = "/tmp/your_dir"
if os.path.exists(engine_cache_dir):
shutil.rmtree(engine_cache_dir)

engine_cache = MyEngineCache(engine_cache_dir)
# The 1st iteration is to measure the compilation time without engine caching
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
results = []
for i in range(3):
# remove timing cache and reset dynamo for engine caching measurement
if i == 0:
cache_built_engines = False
reuse_cached_engines = False
else:
cache_built_engines = True
reuse_cached_engines = True

compiled_model = torch.compile(
model,
backend="tensorrt",
options={
"use_python_runtime": True,
"enabled_precisions": {torch.float},
"debug": False,
"min_block_size": 1,
"make_refitable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache, # use custom engine cache
},
)
results.append(compiled_model(*inputs)) # trigger the compilation

cos_sim = cosine_similarity(results[0], results[1])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_torch_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

cos_sim = cosine_similarity(results[1], results[2])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_torch_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
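The new tests can be run in isolation (a sketch assuming pytest as the runner, which this commit does not state; the path is relative to the repo root):

import pytest

# run just the engine-cache tests, verbosely
pytest.main(["tests/py/dynamo/models/test_engine_cache.py", "-v"])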
