
Refactor code (#18)
Co-authored-by: UltralyticsAssistant <[email protected]>
glenn-jocher and UltralyticsAssistant committed Jun 1, 2024
1 parent f194f5f commit 296d04b
Showing 15 changed files with 58 additions and 553 deletions.
benchmark/README.md (9 changes: 5 additions & 4 deletions)

@@ -11,16 +11,17 @@ However, the application in real world is far more complex. Let's consider a mat
 ```python
 for i in range(m):
     for j in range(n):
-        C[i][j] += A[i][j] * B[j] # one mul-add
+        C[i][j] += A[i][j] * B[j]  # one mul-add
 ```

 It would be `mn` `MACs` and `2mn` `FLOPs`. But such an implementation is slow, and parallelization is necessary to run faster:

 ```python
 for i in range(m):
-    parallelfor j in range(n):
-        d[j] = A[i][j] * B[j] # one mul
-    C[i][j] = sum(d) # n adds
+    parallelfor
+    j in range(n):
+        d[j] = A[i][j] * B[j]  # one mul
+    C[i][j] = sum(d)  # n adds
 ```

 Then the number of `MACs` is no longer `mn`.

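For reference, a minimal sketch (not part of this commit) of the counting rule the README describes: one multiply-accumulate (MAC) per `(i, j)` pair, and one MAC is two FLOPs. The sizes are illustrative:

```python
m, n = 64, 128
macs = m * n        # mn MACs: one fused multiply-add per (i, j) pair
flops = 2 * m * n   # 2mn FLOPs: each MAC is one multiply plus one add
print(macs, flops)  # 8192 16384
```
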
tests/test_conv2d.py (8 changes: 3 additions & 5 deletions)

@@ -1,7 +1,5 @@
-import pytest
 import torch
 import torch.nn as nn
-from jinja2 import StrictUndefined

 from thop import profile

@@ -22,7 +20,7 @@ def test_conv2d_no_bias(self):
         _, _, oh, ow = out.shape

         flops, params = profile(net, inputs=(data,))
-        assert flops == 810000, f"{flops} v.s. {810000}"
+        assert flops == 810000, f"{flops} v.s. 810000"

     def test_conv2d(self):
         """Tests a Conv2D layer with specific input dimensions, kernel size, stride, padding, dilation, and groups."""
@@ -37,11 +35,11 @@ def test_conv2d(self):
         _, _, oh, ow = out.shape

         flops, params = profile(net, inputs=(data,))
-        assert flops == 810000, f"{flops} v.s. {810000}"
+        assert flops == 810000, f"{flops} v.s. 810000"

     def test_conv2d_random(self):
         """Test Conv2D layer with random parameters and validate the computed FLOPs and parameters using 'profile'."""
-        for i in range(10):
+        for _ in range(10):
             out_c, kh, kw = torch.randint(1, 20, (3,)).tolist()
             n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist()  # torch.randint(1, 10, (4,)).tolist()
             ih += kh

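These tests double as the reference usage of the profiling API: `profile(net, inputs=(data,))` returns `(flops, params)`. A minimal standalone sketch along the same lines (the layer and input sizes here are illustrative, not the elided ones from the test bodies):

```python
import torch
import torch.nn as nn

from thop import profile

net = nn.Conv2d(3, 8, kernel_size=3, bias=False)
data = torch.randn(1, 3, 32, 32)  # NCHW input
flops, params = profile(net, inputs=(data,))
print(flops, params)  # operation and parameter counts reported by thop
```
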
tests/test_matmul.py (1 change: 0 additions & 1 deletion)

@@ -1,4 +1,3 @@
-import pytest
 import torch
 import torch.nn as nn

tests/test_relu.py (1 change: 0 additions & 1 deletion)

@@ -1,4 +1,3 @@
-import pytest
 import torch
 import torch.nn as nn

tests/test_utils.py (2 changes: 0 additions & 2 deletions)

@@ -1,5 +1,3 @@
-import pytest
-
 from thop import utils

thop/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1,4 +1,4 @@
-__version__ = "0.2.4"
+__version__ = "0.2.5"

 import torch

thop/fx_profile.py (8 changes: 3 additions & 5 deletions)

@@ -85,9 +85,7 @@ def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
     """Calculate the total operations for a given nn.BatchNorm2d module based on its output shape."""
     assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
     y = output_shapes[0]
-    # y = (x - mean) / \sqrt{var + e} * weight + bias
-    total_ops = 2 * y.numel()
-    return total_ops
+    return 2 * y.numel()


 zero_ops = (
@@ -120,7 +118,7 @@
 from torch.fx import symbolic_trace
 from torch.fx.passes.shape_prop import ShapeProp

-from .utils import prGreen, prRed, prYellow
+from .utils import prRed, prYellow


 def null_print(*args, **kwargs):
@@ -193,7 +191,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
             prRed(f"{key} is missing")
         print("module type:", type(m))
         if isinstance(m, zero_ops):
-            print(f"weight_shape: None")
+            print("weight_shape: None")
         else:
             print(type(m))
             print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}")

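The deleted comment recorded the reasoning behind this count: at inference, `y = (x - mean) / sqrt(var + e) * weight + bias` folds into one multiply and one add per output element, hence `2 * y.numel()`. A quick sanity sketch of that arithmetic (shapes are illustrative, not from the source):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(8).eval()
y = bn(torch.randn(1, 8, 16, 16))
# (x - mean) / sqrt(var + eps) * weight + bias reduces to x * scale + shift,
# i.e. one multiply and one add per output element:
total_ops = 2 * y.numel()
print(total_ops)  # 4096 ops for a 1x8x16x16 output
```
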
thop/onnx_profile.py (82 changes: 0 additions & 82 deletions)

This file was deleted.

thop/profile.py (10 changes: 4 additions & 6 deletions)

@@ -1,9 +1,7 @@
 from thop.rnn_hooks import *
 from thop.vision.basic_hooks import *

-# logger = logging.getLogger(__name__)
-# logger.setLevel(logging.INFO)
-from .utils import prGreen, prRed, prYellow
+from .utils import prRed

 default_dtype = torch.float64

@@ -68,7 +66,7 @@ def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=
         verbose = True

     def add_hooks(m):
-        if len(list(m.children())) > 0:
+        if list(m.children()):
             return

         if hasattr(m, "total_ops") or hasattr(m, "total_params"):
@@ -114,7 +112,7 @@ def add_hooks(m):
     total_ops = 0
     total_params = 0
     for m in model.modules():
-        if len(list(m.children())) > 0:  # skip for non-leaf module
+        if list(m.children()):  # skip for non-leaf module
             continue
         total_ops += m.total_ops
         total_params += m.total_params
@@ -129,7 +127,7 @@ def add_hooks(m):

     # remove temporal buffers
     for n, m in model.named_modules():
-        if len(list(m.children())) > 0:
+        if list(m.children()):
             continue
         if "total_ops" in m._buffers:
             m._buffers.pop("total_ops")

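The recurring change here swaps `len(list(m.children())) > 0` for the equivalent truthiness test `list(m.children())`: an empty child list means a leaf module, which is where the hooks and counters live. A small sketch of that leaf test (the model is illustrative):

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
for m in model.modules():
    if list(m.children()):  # non-leaf container module: skip
        continue
    print(type(m).__name__)  # leaf modules only: Conv2d, ReLU
```
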
thop/rnn_hooks.py (69 changes: 33 additions & 36 deletions)

@@ -103,26 +103,25 @@ def count_rnn(m: nn.RNN, x, y):
     if isinstance(x[0], PackedSequence):
         batch_size = torch.max(x[0].batch_sizes)
         num_steps = x[0].batch_sizes.size(0)
+    elif m.batch_first:
+        batch_size = x[0].size(0)
+        num_steps = x[0].size(1)
     else:
-        if m.batch_first:
-            batch_size = x[0].size(0)
-            num_steps = x[0].size(1)
-        else:
-            batch_size = x[0].size(1)
-            num_steps = x[0].size(0)
+        batch_size = x[0].size(1)
+        num_steps = x[0].size(0)

     total_ops = 0
     if m.bidirectional:
         total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2
     else:
         total_ops += _count_rnn_cell(input_size, hidden_size, bias)

-    for i in range(num_layers - 1):
-        if m.bidirectional:
-            total_ops += _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2
-        else:
-            total_ops += _count_rnn_cell(hidden_size, hidden_size, bias)
-
+    for _ in range(num_layers - 1):
+        total_ops += (
+            _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2
+            if m.bidirectional
+            else _count_rnn_cell(hidden_size, hidden_size, bias)
+        )
     # time unroll
     total_ops *= num_steps
     # batch_size

@@ -141,26 +140,25 @@ def count_gru(m: nn.GRU, x, y):
     if isinstance(x[0], PackedSequence):
         batch_size = torch.max(x[0].batch_sizes)
         num_steps = x[0].batch_sizes.size(0)
+    elif m.batch_first:
+        batch_size = x[0].size(0)
+        num_steps = x[0].size(1)
     else:
-        if m.batch_first:
-            batch_size = x[0].size(0)
-            num_steps = x[0].size(1)
-        else:
-            batch_size = x[0].size(1)
-            num_steps = x[0].size(0)
+        batch_size = x[0].size(1)
+        num_steps = x[0].size(0)

     total_ops = 0
     if m.bidirectional:
         total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2
     else:
         total_ops += _count_gru_cell(input_size, hidden_size, bias)

-    for i in range(num_layers - 1):
-        if m.bidirectional:
-            total_ops += _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2
-        else:
-            total_ops += _count_gru_cell(hidden_size, hidden_size, bias)
-
+    for _ in range(num_layers - 1):
+        total_ops += (
+            _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2
+            if m.bidirectional
+            else _count_gru_cell(hidden_size, hidden_size, bias)
+        )
     # time unroll
     total_ops *= num_steps
     # batch_size

@@ -181,26 +179,25 @@ def count_lstm(m: nn.LSTM, x, y):
     if isinstance(x[0], PackedSequence):
         batch_size = torch.max(x[0].batch_sizes)
         num_steps = x[0].batch_sizes.size(0)
+    elif m.batch_first:
+        batch_size = x[0].size(0)
+        num_steps = x[0].size(1)
     else:
-        if m.batch_first:
-            batch_size = x[0].size(0)
-            num_steps = x[0].size(1)
-        else:
-            batch_size = x[0].size(1)
-            num_steps = x[0].size(0)
+        batch_size = x[0].size(1)
+        num_steps = x[0].size(0)

     total_ops = 0
     if m.bidirectional:
         total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2
     else:
         total_ops += _count_lstm_cell(input_size, hidden_size, bias)

-    for i in range(num_layers - 1):
-        if m.bidirectional:
-            total_ops += _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2
-        else:
-            total_ops += _count_lstm_cell(hidden_size, hidden_size, bias)
-
+    for _ in range(num_layers - 1):
+        total_ops += (
+            _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2
+            if m.bidirectional
+            else _count_lstm_cell(hidden_size, hidden_size, bias)
+        )
     # time unroll
     total_ops *= num_steps
     # batch_size

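All three hooks share the same counting scheme: the first layer costs one cell evaluation on `input_size`, each later layer consumes the hidden state (doubled width when bidirectional), bidirectionality doubles every term, and the result is scaled by `num_steps` and `batch_size`. A standalone sketch of that scheme, with `cell_ops` as a hypothetical stand-in for thop's private `_count_*_cell` helpers:

```python
def rnn_total_ops(cell_ops, input_size, hidden_size, num_layers,
                  num_steps, batch_size, bidirectional=False):
    """Mirror the counting loops above; cell_ops(in_size, hid_size) is a
    caller-supplied per-timestep cell cost."""
    d = 2 if bidirectional else 1
    total = cell_ops(input_size, hidden_size) * d  # first layer
    for _ in range(num_layers - 1):
        # later layers read the (possibly concatenated) hidden state
        total += cell_ops(hidden_size * d, hidden_size) * d
    return total * num_steps * batch_size  # time unroll, then batch
```
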
thop/utils.py (13 changes: 1 addition & 12 deletions)

@@ -20,15 +20,6 @@ def actual_call(*args, **kwargs):
 prGreen = colorful_print(print, color=COLOR_GREEN)
 prYellow = colorful_print(print, color=COLOR_YELLOW)

-# def prRed(skk):
-#     print("\033[91m{}\033[00m".format(skk))
-
-# def prGreen(skk):
-#     print("\033[92m{}\033[00m".format(skk))
-
-# def prYellow(skk):
-#     print("\033[93m{}\033[00m".format(skk))
-

 def clever_format(nums, format="%.2f"):
     """Formats numerical values into a more readable string with units (K, M, G, T) based on their magnitude."""
@@ -48,9 +39,7 @@ def clever_format(nums, format="%.2f"):
         else:
             clever_nums.append(format % num + "B")

-    clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
-
-    return clever_nums
+    return clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)


 if __name__ == "__main__":

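The early return leaves `clever_format` behavior unchanged: one value comes back as a bare string, several as a tuple. A usage sketch (the numbers are illustrative, and the expected output assumes the usual 1e3/1e6/1e9/1e12 unit thresholds):

```python
from thop.utils import clever_format

flops, params = clever_format([4145365504, 25557032], "%.3f")
print(flops, params)  # expected: 4.145G 25.557M
```
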
thop/vision/basic_hooks.py (4 changes: 1 addition & 3 deletions)

@@ -1,7 +1,5 @@
-import argparse
-
 import logging

 import torch
 import torch.nn as nn
 from torch.nn.modules.conv import _ConvNd

@@ -139,7 +137,7 @@ def count_upsample(m, x, y):
         "bilinear",
         "bicubic",
     ):  # "trilinear"
-        logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode)
+        logging.warning(f"mode {m.mode} is not implemented yet, take it a zero op")
         m.total_ops += 0
     else:
         x = x[0]
