
Commit

Fixed batchnorm bug
cehongwang committed Sep 20, 2024
1 parent e4e4d31 commit 414d972
Showing 2 changed files with 31 additions and 12 deletions.
14 changes: 10 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
 )
 
 import numpy as np
+import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
 
-import tensorrt as trt
 from packaging import version
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -472,12 +472,18 @@ def _save_weight_mapping(self) -> None:
             # Retrieve each weight name(s) in state_dict
             if layer_type == "CONSTANT":
                 if "embedding" in suffix:
-                    sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+                    sd_weight_name = f"{sd_weight_name}.weight"
                 elif "weight" in suffix or "mm_other" in suffix:
                     # Linear layer weight
-                    sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
+                    sd_weight_name = f"{sd_weight_name}.weight"
+                elif "running_mean" in suffix:
+                    # Batch norm running mean
+                    sd_weight_name = f"{sd_weight_name}.running_mean"
+                elif "running_var" in suffix:
+                    # Batch norm running variance
+                    sd_weight_name = f"{sd_weight_name}.running_var"
                 else:
-                    sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
+                    sd_weight_name = f"{sd_weight_name}.bias"
             elif layer_type == "SCALE":
                 # Batch norm needs all weights to calculate scale and shift
                 sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]
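Note (editor's addition, not part of the commit): the CONSTANT branch above maps a TensorRT constant-layer suffix back to a state_dict key. Before this change, batch norm running statistics matched no branch and fell through to the final else, which resolves to a bias-style key that does not exist in the module's state_dict. Below is a minimal, self-contained sketch of the new mapping logic; the helper name map_constant_suffix and the toy model are hypothetical, for illustration only.

import torch

def map_constant_suffix(sd_weight_name: str, suffix: str) -> str:
    # Mirror the CONSTANT-layer branch in the hunk above (illustrative only).
    if "embedding" in suffix:
        return f"{sd_weight_name}.weight"
    elif "weight" in suffix or "mm_other" in suffix:
        return f"{sd_weight_name}.weight"
    elif "running_mean" in suffix:
        return f"{sd_weight_name}.running_mean"
    elif "running_var" in suffix:
        return f"{sd_weight_name}.running_var"
    else:
        return f"{sd_weight_name}.bias"

# With the new branches, every batch norm constant resolves to a real state_dict key.
model = torch.nn.Sequential(torch.nn.BatchNorm2d(8))
sd = model.state_dict()  # keys: "0.weight", "0.bias", "0.running_mean", "0.running_var", ...
for suffix in ("weight", "bias", "running_mean", "running_var"):
    assert map_constant_suffix("0", suffix) in sd
print("all batch norm constants map to state_dict keys")
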
29 changes: 21 additions & 8 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -50,14 +50,27 @@ def batch_norm(
     # Save the original output shape for later use
     output_shape = input.shape
 
-    if weight is None:
-        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
-    if bias is None:
-        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
-    if running_mean is None:
-        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
-    if running_var is None:
-        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+    # We name the weight here according to the state_dict name
+    weight = (
+        get_trt_tensor(ctx, 1.0, f"{name}_weight")
+        if weight is None
+        else get_trt_tensor(ctx, weight, f"{name}_weight")
+    )
+    bias = (
+        get_trt_tensor(ctx, 1.0, f"{name}_bias")
+        if bias is None
+        else get_trt_tensor(ctx, bias, f"{name}_bias")
+    )
+    running_mean = (
+        get_trt_tensor(ctx, 1.0, f"{name}_running_mean")
+        if running_mean is None
+        else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
+    )
+    running_var = (
+        get_trt_tensor(ctx, 1.0, f"{name}_running_var")
+        if running_var is None
+        else get_trt_tensor(ctx, running_var, f"{name}_running_var")
+    )
 
     # eps_tensor for numerical stability
     eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
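Note (editor's addition, not part of the commit): the interpreter diff above remarks that batch norm needs all of its constants to compute a scale and shift, and the converter now names each constant after its state_dict entry so it can be found by name later. A short sketch, using assumed toy shapes and values, of the per-channel scale/shift that inference-mode batch norm reduces to, checked against torch.nn.functional.batch_norm:

import torch

# Hypothetical per-channel parameters for a 4-channel input.
x = torch.randn(2, 4, 8, 8)
weight = torch.randn(4)
bias = torch.randn(4)
running_mean = torch.randn(4)
running_var = torch.rand(4) + 0.5
eps = 1e-5

# In eval mode, batch norm is a per-channel affine transform: y = x * scale + shift
scale = weight / torch.sqrt(running_var + eps)
shift = bias - running_mean * scale
folded = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

ref = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)
assert torch.allclose(folded, ref, atol=1e-5)
print("per-channel scale/shift matches F.batch_norm in eval mode")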
