Skip to content

Commit

Permalink
avoid adding nn.SyncBatchNorm when torch version < 1.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoyaoding committed Jun 3, 2020
1 parent 0fece23 commit 958f7a4
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion thop/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk))
nn.BatchNorm1d: count_bn,
nn.BatchNorm2d: count_bn,
nn.BatchNorm3d: count_bn,
nn.SyncBatchNorm: count_bn,

nn.ReLU: zero_ops,
nn.ReLU6: zero_ops,
Expand Down Expand Up @@ -70,6 +69,10 @@ def prYellow(skk): print("\033[93m{}\033[00m".format(skk))
nn.LSTM: count_lstm,
}

if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"):
register_hooks.update({
nn.SyncBatchNorm: count_bn
})

def profile_origin(model, inputs, custom_ops=None, verbose=True):
handler_collection = []
Expand Down

0 comments on commit 958f7a4

Please sign in to comment.