Group norm bug fix (#3014)
cehongwang authored Jul 26, 2024
1 parent abf3370 commit 6ac2ec8
Showing 2 changed files with 157 additions and 23 deletions.
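
Reading the diff below: the old converter path expanded the per-channel weight/bias to the grouped tensor's shape and handed them to TRT's INormalization layer; the new path computes the group statistics explicitly (mean, variance, then (X - E[X]) / sqrt(var + eps)) and applies gamma/beta per channel after reshaping back. A minimal NumPy sketch of the math the new converter code implements — names and shapes here are illustrative, not taken from the diff:

import numpy as np

def group_norm_reference(x, gamma, beta, num_groups, eps=1e-5):
    """Reference group norm mirroring the diff's steps:
    E[X], X - E[X], var, y = (X - E[X]) / sqrt(var + eps), then gamma/beta."""
    n, c = x.shape[:2]
    spatial = x.shape[2:]
    grouped = x.reshape(n, num_groups, -1)          # stats are per (sample, group)
    mean = grouped.mean(axis=-1, keepdims=True)     # E[X]
    var = ((grouped - mean) ** 2).mean(axis=-1, keepdims=True)  # biased variance
    y = (grouped - mean) / np.sqrt(var + eps)       # (X - E[X]) / sqrt(var + eps)
    y = y.reshape(n, c, *spatial)
    affine = (1, c) + (1,) * len(spatial)           # cf. weight_bias_shape in the diff
    return y * gamma.reshape(affine) + beta.reshape(affine)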
151 changes: 132 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -149,6 +149,8 @@ def native_group_norm(
eps: float,
return_mean_rstd: bool = True,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
# TODO: Ask TRT team about the usage of INormalization Layer with num_groups and update the implementation
# with INormalization Layer
assert (
len(input.shape) >= 3
), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -187,28 +189,105 @@ def native_group_norm(
shape,
)

if weight is None:
weight = to_numpy(1.0)

if bias is None:
bias = to_numpy(0.0)

weight = get_trt_tensor(ctx, weight, f"{name}_weight")
bias = get_trt_tensor(ctx, bias, f"{name}_bias")
if tuple(reshaped_input.shape) != tuple(weight.shape):
weight = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_weight",
weight,
reshaped_input.shape,
)
if tuple(reshaped_input.shape) != tuple(bias.shape):
bias = impl.slice.expand(
ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
)
weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)

dims = list(range(1, len(input.shape)))
axes = get_axes_for_reduce_op(dims)
group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
group_norm.epsilon = eps
group_norm.compute_precision = input.dtype
set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
output = group_norm.get_output(0)

# E[X]
mean_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean",
reshaped_input,
dims,
True,
)

mean_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_mean_trt",
mean_trt,
reshaped_input.shape,
)

# X - E[X]
sub_trt = impl.elementwise.sub(
ctx,
target,
source_ir,
f"{name}_sub",
reshaped_input,
mean_trt,
)

# variance
pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
pow_var = impl.elementwise.pow(
ctx,
target,
source_ir,
f"{name}_pow",
sub_trt,
pow_trt,
)

var_trt = impl.reduce.mean(
ctx,
target,
source_ir,
f"{name}_mean_var",
pow_var,
dims,
True,
)

var_trt = impl.slice.expand(
ctx,
target,
source_ir,
f"{name}_expand_var_trt",
var_trt,
reshaped_input.shape,
)

eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
add_trt = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add",
var_trt,
eps_trt,
)

sqrt_trt = impl.unary.sqrt(
ctx,
target,
source_ir,
f"{name}_sqrt",
add_trt,
)

# y = (X - E[X]) / sqrt((var + eps))
output = impl.elementwise.div(
ctx,
target,
source_ir,
f"{name}_div",
sub_trt,
sqrt_trt,
)

shape = list(output.shape)
for i, s in enumerate(shape):
@@ -222,6 +301,40 @@ def native_group_norm(
reshaped_output = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_output", output, shape
)
reshaped_gamma = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_gamma",
weight,
weight_bias_shape,
)

reshaped_output = impl.elementwise.mul(
ctx,
target,
source_ir,
f"{name}_mul_gamma",
reshaped_output,
reshaped_gamma,
)

reshaped_bias = impl.shuffle.reshape(
ctx,
target,
source_ir,
f"{name}_reshape_beta",
bias,
weight_bias_shape,
)
reshaped_output = impl.elementwise.add(
ctx,
target,
source_ir,
f"{name}_add_beta",
reshaped_output,
reshaped_bias,
)
if return_mean_rstd:
# return fake mean and rstd for now
return reshaped_output, None, None
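
The key behavioral change above, as I read it: weight and bias are no longer expanded to the grouped tensor's shape before normalization; instead the normalized result is reshaped back to (N, C, ...) and gamma/beta are broadcast per channel via weight_bias_shape = (1, C) + (1,) * (rank - 2). A small PyTorch sketch of that per-channel broadcast, with hypothetical sizes:

import torch
import torch.nn.functional as F

# Hypothetical sizes: N=2, C=6, G=2 groups, 4x4 spatial.
x = torch.randn(2, 6, 4, 4)
gamma, beta = torch.randn(6), torch.randn(6)

normed = F.group_norm(x, num_groups=2)            # statistics only, no affine
weight_bias_shape = (1, 6, 1, 1)                  # (1, C) + (1,) * (rank - 2)
y = normed * gamma.reshape(weight_bias_shape) + beta.reshape(weight_bias_shape)

# Matches the fused eager op, confirming gamma/beta are applied per channel:
assert torch.allclose(y, F.group_norm(x, 2, gamma, beta), atol=1e-6)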
29 changes: 25 additions & 4 deletions tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -31,8 +31,8 @@ def forward(self, x):
return torch.ops.aten.group_norm.default(
x,
2,
torch.ones((6,)),
torch.zeros((6,)),
torch.randn((6,)),
torch.randn((6,)),
1e-05,
True,
)
@@ -50,8 +50,8 @@ def forward(self, x):
return torch.ops.aten.group_norm.default(
x,
2,
torch.ones((6,)),
torch.zeros((6,)),
torch.randn((6,)),
torch.randn((6,)),
1e-05,
True,
)
@@ -112,6 +112,27 @@ def forward(self, x):
inputs,
)

def test_groupnorm_sd(self):
class GroupNorm(torch.nn.Module):
def forward(self, x):
return torch.ops.aten.native_group_norm.default(
x,
torch.randn((320,)).half(),
torch.randn((320,)).half(),
2,
320,
4096,
32,
1e-05,
)[0]

inputs = [torch.randn(2, 320, 64, 64).half()]
with torch.no_grad():
self.run_test(
GroupNorm(),
inputs,
)

@parameterized.expand(
[
(5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)),
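
The new test_groupnorm_sd case runs a Stable-Diffusion-sized input (2x320x64x64) in fp16 through native_group_norm directly. For reference, the positional arguments of aten.native_group_norm are (input, weight, bias, N, C, HxW, group, eps), and it returns (output, mean, rstd); the test's call should be equivalent in eager mode to the sketch below (float32 here to keep it portable; the test itself runs half precision):

import torch

x = torch.randn(2, 320, 64, 64)
w, b = torch.randn(320), torch.randn(320)

# (input, weight, bias, N, C, HxW, group, eps) -> (output, mean, rstd)
out = torch.ops.aten.native_group_norm.default(x, w, b, 2, 320, 4096, 32, 1e-05)[0]

# Same result via the high-level functional API:
ref = torch.nn.functional.group_norm(x, 32, w, b, 1e-05)
assert torch.allclose(out, ref, atol=1e-6)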
