Commit
Merge pull request Lyken17#188 from Lyken17/revert-187-master
Revert "FLOPs typo"
Lyken17 authored Aug 17, 2022
2 parents 9e74475 + bcc777a commit 54e60ed
Showing 4 changed files with 6 additions and 42 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -6,7 +6,7 @@

OR

-`pip install --upgrade git+https://github.com/lvmingzhe/pytorch-OpCounter.git`
+`pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git`

## How to use
* Basic usage
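For context: the `pip` line above is the README's install command, and the "Basic usage" section the diff truncates amounts to a single `thop.profile` call. A minimal sketch of that usage (the model choice and input shape here are illustrative, not part of the diff):

```python
import torch
from torchvision.models import resnet50
from thop import profile

model = resnet50()
dummy_input = torch.randn(1, 3, 224, 224)  # one 224x224 RGB image
macs, params = profile(model, inputs=(dummy_input,))
print(f"{macs / 1e9:.2f} GMACs, {params / 1e6:.2f} M params")
```

`profile` returns raw counts; dividing by 1e9 and 1e6 gives the G/M units that the benchmark tables below print.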
2 changes: 1 addition & 1 deletion benchmark/evaluate_famous_models.py
@@ -10,7 +10,7 @@
and callable(models.__dict__[name])
)

print("%s | %s | %s" % ("Model", "Params(M)", "MACs(G)"))
print("%s | %s | %s" % ("Model", "Params(M)", "FLOPs(G)"))
print("---|---|---")

device = "cpu"
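The changed line restores the column header to FLOPs(G); the loop that fills the table is elided by the diff. A plausible reconstruction of how each row is produced, assuming the usual `thop.profile` pattern (the model subset and input shape are assumptions):

```python
import torch
from torchvision import models
from thop import profile

device = "cpu"
for name in ("resnet18", "alexnet"):  # illustrative subset
    model = models.__dict__[name]().to(device)
    dummy = torch.randn(1, 3, 224, 224).to(device)
    total_ops, total_params = profile(model, inputs=(dummy,), verbose=False)
    # Params(M) scales by 1e6; the (G) column scales by 1e9.
    print("%s | %.2f | %.2f" % (name, total_params / 1e6, total_ops / 1e9))
```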
2 changes: 1 addition & 1 deletion benchmark/evaluate_rnn_models.py
@@ -29,7 +29,7 @@
),
}

print("{} | {} | {}".format("Model", "Params(M)", "MACs(G)"))
print("{} | {} | {}".format("Model", "Params(M)", "FLOPs(G)"))
print("---|---|---")

for name, model in models.items():
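Same header change in the RNN benchmark. For recurrent models the only profiling difference is the input: `thop.profile` is fed a sequence tensor instead of an image. A hedged sketch (the layer sizes and sequence length are assumptions; the script's real `models` dict is elided above):

```python
import torch
import torch.nn as nn
from thop import profile

model = nn.LSTM(input_size=128, hidden_size=128, num_layers=2)
seq = torch.randn(32, 1, 128)  # (seq_len, batch, input_size)
total_ops, total_params = profile(model, inputs=(seq,), verbose=False)
print("{} | {:.2f} | {:.2f}".format("LSTM", total_params / 1e6, total_ops / 1e9))
```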
42 changes: 3 additions & 39 deletions thop/vision/basic_hooks.py
@@ -12,21 +12,11 @@ def count_parameters(m, x, y):
    total_params = 0
    for p in m.parameters():
        total_params += torch.DoubleTensor([p.numel()])
-    # m.total_params[0] = calculate_parameters(m.parameters())
-    try:
-        if m.total_params:
-            m.total_params[0] = calculate_parameters(m.parameters())
-    except:
-        logging.warning('no m.total_params[0]')
+    m.total_params[0] = calculate_parameters(m.parameters())


def zero_ops(m, x, y):
-    try:
-        if m.total_ops:
-            m.total_ops += calculate_zero_ops()
-    except:
-        logging.warning('no m.total_ops zero_ops')
-        # m.total_ops += calculate_zero_ops()
+    m.total_ops += calculate_zero_ops()


def count_convNd(m: _ConvNd, x, y: torch.Tensor):
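What the revert restores here: the hooks write directly to `m.total_ops` / `m.total_params`, which thop attaches to every profiled module before the forward pass runs, so the defensive `try/except` from #187 only silenced real registration bugs. A simplified standalone sketch of that mechanism (not thop's exact registration code):

```python
import torch
import torch.nn as nn

def zero_ops(m, x, y):
    # Ops-free modules (e.g. ReLU) still accumulate a zero so totals sum cleanly.
    m.total_ops += torch.DoubleTensor([0])

relu = nn.ReLU()
relu.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
relu.register_forward_hook(zero_ops)
relu(torch.randn(4))
print(relu.total_ops)  # tensor([0.])
```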
@@ -67,17 +57,6 @@ def count_convNd_ver2(m: _ConvNd, x, y: torch.Tensor):
    m.total_ops += calculate_conv(m.bias.nelement(), m.weight.nelement(), output_size)


-# def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
-#     # TODO: add test cases
-#     # https://github.com/Lyken17/pytorch-OpCounter/issues/124
-#     # y = (x - mean) / sqrt(eps + var) * weight + bias
-#     x = x[0]
-#     # bn is by default fused in inference
-#     flops = calculate_norm(x.numel())
-#     if m.affine:
-#         flops *= 2
-#     m.total_ops += flops

def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
    # TODO: add test cases
    # https://github.com/Lyken17/pytorch-OpCounter/issues/124
@@ -87,22 +66,7 @@ def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
    flops = calculate_norm(x.numel())
    if m.affine:
        flops *= 2
-    try:
-        if m.total_ops:
-            m.total_ops += flops
-    except:
-        logging.warning('no m.total_ops')
-
-# def count_normalization(m: nn.modules.instancenorm._InstanceNorm, x, y):
-#     # TODO: add test cases
-#     # https://github.com/Lyken17/pytorch-OpCounter/issues/124
-#     # y = (x - mean) / sqrt(eps + var) * weight + bias
-#     x = x[0]
-#     # bn is by default fused in inference
-#     flops = calculate_norm(x.numel())
-#     if m.affine:
-#         flops *= 2
-#     m.total_ops += flops
+    m.total_ops += flops


# def count_layer_norm(m, x, y):
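The restored `count_normalization` charges `calculate_norm(x.numel())` for the fused `(x - mean) / sqrt(eps + var)` step and doubles it when the `affine` scale-and-shift is present. A standalone restatement of that arithmetic, assuming two ops per element for the fused normalization (thop's exact constant inside `calculate_norm` may differ):

```python
import torch
import torch.nn as nn

def norm_flops(x: torch.Tensor, affine: bool) -> int:
    flops = 2 * x.numel()  # assumed: subtract-mean + divide-by-std per element
    if affine:
        flops *= 2         # mirrors the hook's `flops *= 2` for weight + bias
    return flops

bn = nn.BatchNorm2d(8)
x = torch.randn(1, 8, 4, 4)      # 128 elements
print(norm_flops(x, bn.affine))  # 512 under the assumption above
```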
