Refit bug fix (#3097)
Co-authored-by: Dheeraj Peri <[email protected]>
cehongwang and peri044 authored Aug 28, 2024
1 parent 6180836 commit 60ec67b
Showing 14 changed files with 231 additions and 172 deletions.
examples/dynamo/mutable_torchtrt_module_example.py (4 changes: 2 additions & 2 deletions)
@@ -34,7 +34,7 @@
"make_refitable": True,
}

-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
mutable_module(*inputs)
@@ -45,7 +45,7 @@

# %%
# Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())


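For readers skimming the diff, here is a minimal end-to-end sketch of the corrected example; the `inputs` and `settings` values are assumptions reconstructed from the surrounding example file, not part of this hunk:

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Assumed inputs/settings; this hunk only shows the model construction lines.
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
settings = {"enabled_precisions": {torch.float}, "make_refitable": True}

model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
mutable_module(*inputs)  # compilation happens on the first call

# Loading a different state_dict triggers refit; adding a submodule
# would trigger re-compilation instead (per the example's comments).
model2 = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())
mutable_module(*inputs)  # the refit is applied for this call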
examples/dynamo/refit_engine_example.py (4 changes: 2 additions & 2 deletions)
@@ -39,7 +39,7 @@
# Compile the module for the first time and save it.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
@@ -68,7 +68,7 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Create and compile the updated model
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))


py/torch_tensorrt/dynamo/_compiler.py (6 changes: 3 additions & 3 deletions)
@@ -56,9 +56,9 @@ def compile(
disable_tf32: bool = _defaults.DISABLE_TF32,
assume_dynamic_shape_support: bool = _defaults.ASSUME_DYNAMIC_SHAPE_SUPPORT,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
-    enabled_precisions: (
-        Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
-    ) = _defaults.ENABLED_PRECISIONS,
+    enabled_precisions: Union[
+        Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
+    ] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refitable: bool = _defaults.MAKE_REFITABLE,
debug: bool = _defaults.DEBUG,
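A plausible motivation for this signature change (our assumption; the commit does not state one): the PEP 604 `X | Y` spelling is evaluated eagerly in function signatures and only works on Python 3.10+, while `typing.Union` expresses the same type on older interpreters. A small self-contained illustration:

from typing import Set, Tuple, Union

import torch

# `Union[...]` imports cleanly on Python < 3.10; the alias name below is
# illustrative, not taken from the codebase.
EnabledPrecisions = Union[Set[torch.dtype], Tuple[torch.dtype, ...]]

def compile_stub(enabled_precisions: EnabledPrecisions = (torch.float,)) -> None:
    # Stand-in for the real compile(); only the annotation pattern matters here.
    print([str(p) for p in enabled_precisions])

compile_stub({torch.float, torch.half})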
py/torch_tensorrt/dynamo/_refit.py (47 changes: 30 additions & 17 deletions)
@@ -34,7 +34,7 @@
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.utils import (
-    check_output,
+    check_module_output,
get_torch_inputs,
set_log_level,
to_torch_device,
@@ -115,19 +115,8 @@ def construct_refit_mapping_from_weight_name_map(
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
-            # Batch Norm Layer
-            params = {}
-            for w in sd_weight_name:
-                params[w.split(".")[-1]] = state_dict[w]
-            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
-            shift = params["bias"] - params["running_mean"] * scale
-            # Set scale to scale or shift to shift
-            engine_weight_map[engine_weight_name] = eval(
-                engine_weight_name.split(" ")[-1].lower()
-            )

-        elif sd_weight_name not in state_dict:
+        if sd_weight_name not in state_dict:
# If weights is not in sd, we can leave it unchanged
continue
else:
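For reference, the deleted branch folded batch-norm statistics into TensorRT's fused SCALE/SHIFT weights; a standalone restatement of that arithmetic (function name ours, not the library's):

import torch

def fold_batchnorm(weight, bias, running_mean, running_var, eps=1e-7):
    # scale = gamma / sqrt(var + eps); shift = beta - mean * scale,
    # the same folding the removed lines computed via eval().
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

# Example with a 4-channel batch norm's parameters
gamma, beta = torch.ones(4), torch.zeros(4)
mean, var = torch.zeros(4), torch.ones(4)
print(fold_batchnorm(gamma, beta, mean, var))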
@@ -157,16 +146,25 @@ def _refit_single_trt_engine_with_gm(
"""

refitted = set()

+    torch_device = list(new_gm.state_dict().values())[0].device.type
refitter = trt.Refitter(old_engine, TRT_LOGGER)
weight_list = refitter.get_all_weights()

if weight_name_map:
# Get the refitting mapping
-        trt_wt_location = trt.TensorLocation.DEVICE
+        trt_wt_location = (
+            trt.TensorLocation.DEVICE
+            if torch_device == "cuda"
+            else trt.TensorLocation.HOST
+        )
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict()
)

+        # Debug Use
+        # correct = construct_refit_mapping(new_gm, input_list, settings)
+        # comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct}

for layer_name in weight_list:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
@@ -235,7 +233,7 @@ def refit_module_weights(
compiled_module = copy.deepcopy(compiled_module)
elif inline_module:
raise AssertionError(
"Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
"Exported program does not support modifying in place. Please set in_place to false and use the returned graph module."
)

# Get the settings and check the setting to be uniform
@@ -283,6 +281,7 @@ def refit_module_weights(
arg_inputs = [arg_inputs]
torch_inputs = get_torch_inputs(arg_inputs, device)

+    torch_kwarg_inputs: Any = {}
if kwarg_inputs:
torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
runtime = trt.Runtime(TRT_LOGGER)
@@ -436,6 +435,7 @@ def refit_module_weights(
settings=settings,
weight_name_map=weight_name_map,
)

except AssertionError as e:
# If fast_refit is used and failed, we fall back to regular refit
logger.warning(e)
@@ -463,14 +463,27 @@ def refit_module_weights(
setattr(compiled_module, f"{name}_engine", refitted_engine)

if verify_output and arg_inputs is not None:
-        if check_output(
+        if check_module_output(
new_module=new_gm,
refitted_module=compiled_module,
arg_inputs=torch_inputs,
kwarg_inputs=torch_kwarg_inputs,
):
logger.info("Refitting Succeed!")
else:
+            if weight_name_map:
+                logger.warning(
+                    "Refitting with weight_name_map yielded incorrect result! The outputs do not match."
+                )
+                return refit_module_weights(
+                    compiled_module,
+                    new_weight_module,
+                    arg_inputs,
+                    kwarg_inputs,
+                    verify_output,
+                    use_weight_map_cache=False,
+                    in_place=in_place,
+                )
logger.error("Refitting Failed! The outputs do not match.")
else:
logger.info("Refitting Completed! Output verification skipped.")
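Putting this file's changes in context, a minimal workflow sketch assembled from the two example scripts above (a sketch assuming a CUDA-enabled Torch-TensorRT install, not a verbatim excerpt):

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(exp_program, tuple(inputs), make_refitable=True)

# Re-export the model with different weights, then refit instead of recompiling.
# With this patch, a verification mismatch under the cached weight map falls
# back to a full refit (use_weight_map_cache=False) rather than failing.
model2 = models.resnet18(pretrained=False).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))
new_trt_gm = torch_trt.dynamo.refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
    verify_output=True,
)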
py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py (85 changes: 64 additions & 21 deletions)
@@ -1,9 +1,21 @@
+import gc
import io
import logging
import os
import warnings
from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

import numpy as np
import torch
@@ -26,7 +38,7 @@
get_node_name,
get_trt_tensor,
)
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

@@ -327,6 +339,39 @@ def _construct_trt_network_def(self) -> None:
f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}"
)

+    @staticmethod
+    def find_weight(
+        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
+    ) -> str:
+        """
+        We need to build map from engine weight name to state_dict weight name.
+        The purpose of this function is to find the corresponding weight name in module state_dict.
+        weight_name: the target weight name we want to search for
+        np_map: the map from weight name to np values in INetworkDefinition
+        state_dict: state of the graph module
+        """
+        network_weight = np_map[weight_name]
+        network_weight = torch.from_numpy(np_map[weight_name]).cuda()
+        for sd_w_name, sd_weight in state_dict.items():
+            if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
+                del state_dict[sd_w_name]
+                return sd_w_name
+        return ""
+
+    @staticmethod
+    def check_weight_equal(
+        sd_weight: torch.tensor, network_weight: Union[torch.Tensor, np.ndarray]
+    ) -> Any:
+        if not isinstance(network_weight, torch.Tensor):
+            network_weight = torch.from_numpy(network_weight).cuda()
+        try:
+            return sd_weight.shape == network_weight.shape and torch.all(
+                torch.abs(sd_weight - network_weight) < 0.01
+            )
+        except Exception:
+            return torch.all(sd_weight == network_weight)

def _save_weight_mapping(self) -> None:
"""
Construct the weight name mapping from engine weight name to state_dict weight name.
@@ -336,23 +381,6 @@ def _save_weight_mapping(self) -> None:
2. Value mapping that, for each weight in INetworkDefinition search for identical weight in state_dict
"""

-        def find_weight(
-            weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
-        ) -> str:
-            network_weight = np_map[weight_name]
-            for sd_w_name, sd_weight in sd.items():
-                if check_weight_equal(sd_weight, network_weight):
-                    return sd_w_name
-            return ""
-
-        def check_weight_equal(
-            sd_weight: torch.tensor, network_weight: np.ndarray
-        ) -> Any:
-            sd_weight = sd_weight.reshape(-1).cpu().numpy()
-            return sd_weight.size == network_weight.size and np.allclose(
-                sd_weight, network_weight, 1e-1, 1e-1
-            )

MODULE_MAP = {
"SCALE": (
trt.IScaleLayer,
@@ -398,8 +426,19 @@ def check_weight_equal(
)
}
"""
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
sd = self.module.state_dict()
+        torch_device = to_torch_device(self.compilation_settings.device)
+        gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
+        if not gm_is_on_cuda:
+            # If the model original position is on CPU, move it GPU
+            sd = {
+                k: v.reshape(-1).to(torch_device)
+                for k, v in self.module.state_dict().items()
+            }
+        else:
+            sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
np_map = {}
net = self.ctx.net
Expand Down Expand Up @@ -448,10 +487,10 @@ def check_weight_equal(
if "SCALE" in engine_weight_name:
# There is no direct connection in batch_norm layer. So skip it
pass
-            elif sd_weight_name not in sd or not check_weight_equal(
+            elif sd_weight_name not in sd or not TRTInterpreter.check_weight_equal(
sd[sd_weight_name], np_map[engine_weight_name]
):
-                weight_name_map[engine_weight_name] = find_weight(
+                weight_name_map[engine_weight_name] = TRTInterpreter.find_weight(
engine_weight_name, np_map, sd
)

@@ -462,6 +501,10 @@

self.weight_name_map = weight_name_map

+        del np_map, sd
+        gc.collect()
+        torch.cuda.empty_cache()

def run(
self,
strict_type_constraints: bool = False,
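A small, hypothetical exercise of the two new static helpers; the engine weight name "conv1 KERNEL" is illustrative only, and a CUDA device is assumed since find_weight moves the network weight to the GPU before comparing:

import torch
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter

# A flattened state_dict entry and its numpy twin, shaped the way
# _save_weight_mapping prepares them (values reshaped to 1-D).
flat = torch.randn(8 * 3 * 3 * 3).to("cuda")
state_dict = {"conv1.weight": flat}
np_map = {"conv1 KERNEL": flat.cpu().numpy()}  # illustrative engine weight name

# Pass a copy, because find_weight deletes matched entries from the dict.
matched = TRTInterpreter.find_weight("conv1 KERNEL", np_map, dict(state_dict))
print(matched)  # -> "conv1.weight": identical values within the 0.01 tolerance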
py/torch_tensorrt/dynamo/conversion/_conversion.py (13 changes: 8 additions & 5 deletions)
@@ -130,14 +130,13 @@ def convert_module(
from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm
from torch_tensorrt.logging import TRT_LOGGER

-    runtime = trt.Runtime(TRT_LOGGER)
-    refit_test_engine = runtime.deserialize_cuda_engine(
-        interpreter_result.serialized_engine
-    )
weight_name_map: Any = None
# Do the test refit with cached map if make_refitable is enabled
if settings.make_refitable:
-        weight_name_map = interpreter_result.weight_name_map
+        runtime = trt.Runtime(TRT_LOGGER)
+        refit_test_engine = runtime.deserialize_cuda_engine(
+            interpreter_result.serialized_engine
+        )
try:
_refit_single_trt_engine_with_gm(
new_gm=module,
Expand All @@ -146,9 +145,13 @@ def convert_module(
settings=settings,
weight_name_map=interpreter_result.weight_name_map,
)
+            weight_name_map = interpreter_result.weight_name_map
except AssertionError:
logger.warning("Fast refit test failed. Removing the weight map caching.")

+        del refit_test_engine
+        torch.cuda.empty_cache()

rt_cls = PythonTorchTensorRTModule

if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
Expand Down
(diffs for the remaining 8 changed files not shown)
