diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py index d99a641..fb864b6 100644 --- a/thop/vision/basic_hooks.py +++ b/thop/vision/basic_hooks.py @@ -64,7 +64,7 @@ def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y): x = x[0] # bn is by default fused in inference flops = calculate_norm(x.numel()) - if m.affine: + if (getattr(m, 'affine', False) or getattr(m, 'elementwise_affine', False)): flops *= 2 m.total_ops += flops