From 958f7a4f5140a6e05ee184672b89e3c501b12140 Mon Sep 17 00:00:00 2001 From: Yaoyao Ding Date: Wed, 3 Jun 2020 17:40:08 +0800 Subject: [PATCH] avoid adding nn.SyncBatchNorm when torch version < 1.1.0 --- thop/profile.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/thop/profile.py b/thop/profile.py index 8247d05..4b98364 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -35,7 +35,6 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) nn.BatchNorm1d: count_bn, nn.BatchNorm2d: count_bn, nn.BatchNorm3d: count_bn, - nn.SyncBatchNorm: count_bn, nn.ReLU: zero_ops, nn.ReLU6: zero_ops, @@ -70,6 +69,10 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk)) nn.LSTM: count_lstm, } +if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): + register_hooks.update({ + nn.SyncBatchNorm: count_bn + }) def profile_origin(model, inputs, custom_ops=None, verbose=True): handler_collection = []