diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index e46b6863a3..5c6943bf5b 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -149,6 +149,8 @@ def native_group_norm(
     eps: float,
     return_mean_rstd: bool = True,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    # TODO: Ask TRT team about the usage of INormalization Layer with num_groups and update the implementation
+    # with INormalization Layer
     assert (
         len(input.shape) >= 3
     ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
@@ -187,28 +189,105 @@ def native_group_norm(
         shape,
     )
 
+    if weight is None:
+        weight = to_numpy(1.0)
+
+    if bias is None:
+        bias = to_numpy(0.0)
+
     weight = get_trt_tensor(ctx, weight, f"{name}_weight")
     bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    if tuple(reshaped_input.shape) != tuple(weight.shape):
-        weight = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_weight",
-            weight,
-            reshaped_input.shape,
-        )
-    if tuple(reshaped_input.shape) != tuple(bias.shape):
-        bias = impl.slice.expand(
-            ctx, target, source_ir, f"{name}_expand_bias", bias, reshaped_input.shape
-        )
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
     dims = list(range(1, len(input.shape)))
-    axes = get_axes_for_reduce_op(dims)
-    group_norm = ctx.net.add_normalization(reshaped_input, weight, bias, axes)
-    group_norm.epsilon = eps
-    group_norm.compute_precision = input.dtype
-    set_layer_name(group_norm, target, f"{name}_group_norm", source_ir)
-    output = group_norm.get_output(0)
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean",
+        reshaped_input,
+        dims,
+        True,
+    )
+
+    mean_trt = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_expand_mean_trt",
+        mean_trt,
+        reshaped_input.shape,
+    )
+
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
+
+    # variance
+    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
+    pow_var = impl.elementwise.pow(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_trt,
+    )
+
+    var_trt = impl.reduce.mean(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mean_var",
+        pow_var,
+        dims,
+        True,
+    )
+
+    var_trt = impl.slice.expand(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_expand_var_trt",
+        var_trt,
+        reshaped_input.shape,
+    )
+
+    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
+    add_trt = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_trt,
+    )
+
+    sqrt_trt = impl.unary.sqrt(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # y = (X - E[X]) / sqrt(var + eps)
+    output = impl.elementwise.div(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
+    )
 
     shape = list(output.shape)
     for i, s in enumerate(shape):
@@ -222,6 +301,40 @@ def native_group_norm(
     reshaped_output = impl.shuffle.reshape(
         ctx, target, source_ir, f"{name}_reshape_output", output, shape
     )
+    reshaped_gamma = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_gamma",
+        weight,
+        weight_bias_shape,
+    )
+
+    reshaped_output = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_mul_gamma",
+        reshaped_output,
+        reshaped_gamma,
+    )
+
+    reshaped_bias = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_beta",
+        bias,
+        weight_bias_shape,
+    )
+    reshaped_output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_add_beta",
+        reshaped_output,
+        reshaped_bias,
+    )
     if return_mean_rstd:
         # return fake mean and rstd for now
         return reshaped_output, None, None
diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py
index 6ac782cdc7..cf5dd48f0a 100644
--- a/tests/py/dynamo/conversion/test_group_norm_aten.py
+++ b/tests/py/dynamo/conversion/test_group_norm_aten.py
@@ -31,8 +31,8 @@ def forward(self, x):
             return torch.ops.aten.group_norm.default(
                 x,
                 2,
-                torch.ones((6,)),
-                torch.zeros((6,)),
+                torch.randn((6,)),
+                torch.randn((6,)),
                 1e-05,
                 True,
             )
@@ -50,8 +50,8 @@ def forward(self, x):
             return torch.ops.aten.group_norm.default(
                 x,
                 2,
-                torch.ones((6,)),
-                torch.zeros((6,)),
+                torch.randn((6,)),
+                torch.randn((6,)),
                 1e-05,
                 True,
             )
@@ -112,6 +112,27 @@ def forward(self, x):
             inputs,
         )
 
+    def test_groupnorm_sd(self):
+        class GroupNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.native_group_norm.default(
+                    x,
+                    torch.randn((320,)).half(),
+                    torch.randn((320,)).half(),
+                    2,
+                    320,
+                    4096,
+                    32,
+                    1e-05,
+                )[0]
+
+        inputs = [torch.randn(2, 320, 64, 64).half()]
+        with torch.no_grad():
+            self.run_test(
+                GroupNorm(),
+                inputs,
+            )
+
     @parameterized.expand(
         [
             (5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)),
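For reference, the patch replaces the INormalization-based lowering with an explicit decomposition into reduce/elementwise TRT layers: y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, with gamma and beta broadcast per channel via weight_bias_shape. Below is a minimal PyTorch sketch of that math, for illustration only; the function name, the grouping/reshape strategy, and the tolerance are assumptions and do not mirror the converter's internal layout.

import torch

def group_norm_reference(x, weight, bias, num_groups, eps=1e-5):
    # Hypothetical reference, not part of the patch.
    # Normalize each group of channels over its channel and spatial extent.
    N, C = x.shape[0], x.shape[1]
    grouped = x.reshape(N, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)                  # E[X]
    var = ((grouped - mean) ** 2).mean(dim=-1, keepdim=True)   # Var[X]
    normalized = ((grouped - mean) / torch.sqrt(var + eps)).reshape_as(x)
    # Per-channel affine parameters, broadcast as (1, C, 1, ..., 1).
    param_shape = (1, C) + (1,) * (x.dim() - 2)
    return normalized * weight.reshape(param_shape) + bias.reshape(param_shape)

# Sanity check against the built-in op.
x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
ref = torch.nn.functional.group_norm(x, 2, w, b, eps=1e-5)
assert torch.allclose(group_norm_reference(x, w, b, 2), ref, atol=1e-5)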