Add Python function docstrings
glenn-jocher committed May 31, 2024
1 parent 19a158a commit f1a42d7
Showing 14 changed files with 94 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -44,7 +44,7 @@ jobs:
v_local = tuple(map(int, pyproject_version.split('.')))
# Compare with version on PyPI
- v_pypi = (0, 0, 0) # tuple(map(int, check_latest_pypi_version('ultralytics-thop').split('.')))
+ v_pypi = tuple(map(int, check_latest_pypi_version('ultralytics-thop').split('.')))
print(f'Local version is {v_local}')
print(f'PyPI version is {v_pypi}')
d = [a - b for a, b in zip(v_local, v_pypi)] # diff
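This change re-enables the real PyPI lookup so the workflow compares the local version against the published one. As a standalone sketch of the gate this step implements (the publish condition itself is outside the hunk, so the one-patch-bump rule below is an assumption, and the check_latest_pypi_version helper is stubbed out):

def should_publish(local_version: str, pypi_version: str) -> bool:
    # Parse 'X.Y.Z' strings into comparable integer tuples.
    v_local = tuple(map(int, local_version.split(".")))
    v_pypi = tuple(map(int, pypi_version.split(".")))
    d = [a - b for a, b in zip(v_local, v_pypi)]  # per-field difference
    return d == [0, 0, 1]  # assumed rule: publish only on a single patch bump

assert should_publish("0.0.2", "0.0.1")
assert not should_publish("0.0.1", "0.0.1")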
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -25,7 +25,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ultralytics-thop"
version = "0.0.1" # Placeholder version, needs to be dynamically set
version = "0.0.2" # Placeholder version, needs to be dynamically set
description = "A tool to count the FLOPs of PyTorch model."
readme = "README.md"
requires-python = ">=3.8"
3 changes: 3 additions & 0 deletions tests/test_conv2d.py
@@ -8,6 +8,7 @@

class TestUtils:
def test_conv2d_no_bias(self):
"""Tests a 2D convolutional layer without bias using THOP profiling with predefined input dimensions and convolution parameters."""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
@@ -22,6 +23,7 @@ def test_conv2d_no_bias(self):
assert flops == 810000, f"{flops} v.s. {810000}"

def test_conv2d(self):
"""Tests a Conv2D layer with specific input dimensions, kernel size, stride, padding, dilation, and groups."""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
@@ -36,6 +38,7 @@ def test_conv2d(self):
assert flops == 810000, f"{flops} v.s. {810000}"

def test_conv2d_random(self):
"""Test Conv2D layer with random parameters and validate the computed FLOPs and parameters using 'profile'."""
for i in range(10):
out_c, kh, kw = torch.randint(1, 20, (3,)).tolist()
n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist() # torch.randint(1, 10, (4,)).tolist()
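The 810000 figure asserted in both fixed-parameter tests can be reproduced by hand with the standard convolution output-size formula; a quick sketch (THOP counts one multiply-accumulate per kernel tap per output element):

n, in_c, ih, iw = 1, 3, 32, 32
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
oh = (ih + 2 * p - d * (kh - 1) - 1) // s + 1  # 30
ow = (iw + 2 * p - d * (kw - 1) - 1) // s + 1  # 30
ops = n * out_c * oh * ow * (in_c // g) * kh * kw  # MACs per output element x output elements
assert ops == 810000  # matches the test assertion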
3 changes: 3 additions & 0 deletions tests/test_matmul.py
@@ -7,13 +7,15 @@

class TestUtils:
def test_matmul_case2(self):
"""Test matrix multiplication case asserting the FLOPs and parameters of a nn.Linear layer."""
n, in_c, out_c = 1, 100, 200
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == n * in_c * out_c

def test_matmul_case2(self):
"""Tests matrix multiplication to assert FLOPs and parameters of nn.Linear layer using random dimensions."""
for i in range(10):
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
@@ -22,6 +24,7 @@ def test_matmul_case2(self):
assert flops == n * in_c * out_c

def test_conv2d(self):
"""Tests the number of FLOPs and parameters for a randomly initialized nn.Linear layer using torch.profiler."""
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
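The linear-layer assertions follow the same cost model: one MAC per weight per batch row, with bias additions not counted. A minimal end-to-end check (assumes thop is installed):

import torch
import torch.nn as nn
from thop import profile

n, in_c, out_c = 1, 100, 200
flops, params = profile(nn.Linear(in_c, out_c), inputs=(torch.randn(n, in_c),))
assert flops == n * in_c * out_c       # 20,000 MACs, as the first test asserts
assert params == in_c * out_c + out_c  # 20,200 weights plus biases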
1 change: 1 addition & 0 deletions tests/test_relu.py
@@ -7,6 +7,7 @@

class TestUtils:
def test_relu(self):
"""Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP profiling."""
n, in_c, out_c = 1, 100, 200
data = torch.randn(n, in_c)
net = nn.ReLU()
2 changes: 2 additions & 0 deletions tests/test_utils.py
@@ -5,12 +5,14 @@

class TestUtils:
def test_clever_format_returns_formatted_number(self):
"""Tests that the clever_format function returns a formatted number string with a '1.00B' pattern."""
nums = 1
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
assert clever_nums == "1.00B"

def test_clever_format_returns_formatted_numbers(self):
"""Tests that the clever_format function correctly formats a list of numbers as strings with a '1.00B' pattern."""
nums = [1, 2]
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
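For context on the '1.00B' pattern these tests expect: clever_format scales a count into K/M/G/T units and falls back to a plain 'B' suffix for small values. A brief usage sketch (the tuple return for iterables is an assumption based on current thop behavior):

from thop import clever_format

print(clever_format(1, "%.2f"))        # '1.00B'  -> fallback suffix, as asserted above
print(clever_format(1500000, "%.2f"))  # '1.50M'
print(clever_format([1, 2], "%.2f"))   # ('1.00B', '2.00B')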
15 changes: 14 additions & 1 deletion thop/fx_profile.py
@@ -13,15 +13,17 @@


def count_clamp(input_shapes, output_shapes):
"""Ensures proper array sizes for tensors by clamping input and output shapes."""
return 0


def count_mul(input_shapes, output_shapes):
- # element-wise
+ """Returns the number of elements in the first output shape."""
return output_shapes[0].numel()


def count_matmul(input_shapes, output_shapes):
"""Calculates the total number of operations for a matrix multiplication given input and output shapes."""
in_shape = input_shapes[0]
out_shape = output_shapes[0]
in_features = in_shape[-1]
@@ -30,6 +32,7 @@ def count_matmul(input_shapes, output_shapes):


def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a linear layer given input and output shapes."""
mul_flops = count_matmul(input_shapes, output_shapes)
if "bias" in kwargs:
add_flops = output_shapes[0].numel()
@@ -40,6 +43,7 @@ def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a 2D convolutional layer given input and output shapes."""
inputs, weight, bias, stride, padding, dilation, groups = args
if len(input_shapes) == 2:
x_shape, k_shape = input_shapes
@@ -56,14 +60,17 @@ def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):


def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
"""Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
return count_matmul(input_shapes, output_shapes)


def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
"""Returns 0 for the given neural network module, input shapes, and output shapes."""
return 0


def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
"""Calculates total operations for a 2D convolutional neural network layer in a given neural network module."""
bias_op = 1 if module.bias is not None else 0
out_shape = output_shapes[0]

@@ -75,6 +82,7 @@ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):


def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
"""Calculate the total operations for a given nn.BatchNorm2d module based on its output shape."""
assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
y = output_shapes[0]
# y = (x - mean) / \sqrt{var + e} * weight + bias
@@ -116,10 +124,12 @@ def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):


def null_print(*args, **kwargs):
"""A no-op print function that takes any arguments without performing any actions."""
return


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
"""Profiles the given torch.nn Module to calculate total FLOPs for each operation and prints detailed node information if verbose."""
gm: torch.fx.GraphModule = symbolic_trace(mod)
g = gm.graph
ShapeProp(gm).propagate(input)
@@ -204,16 +214,19 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):

class MyOP(nn.Module):
def forward(self, input):
"""Performs forward pass on given input data."""
return input / 1

class MyModule(torch.nn.Module):
def __init__(self):
"""Initializes MyModule with two linear layers and a custom MyOP operator."""
super().__init__()
self.linear1 = torch.nn.Linear(5, 3)
self.linear2 = torch.nn.Linear(5, 3)
self.myop = MyOP()

def forward(self, x):
"""Applies two linear transformations to the input tensor, clamps the second, then combines and processes with MyOP operator."""
out1 = self.linear1(x)
out2 = self.linear2(x).clamp(min=0.0, max=1.0)
return self.myop(out1 + out2)
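The MyOP/MyModule pair above doubles as a smoke test for fx_profile; a minimal driver in the same module's __main__ context (the input width 5 matches Linear(5, 3), and the return value is assumed to be the accumulated FLOP total, as the docstring suggests):

import torch

net = MyModule()
data = torch.randn(4, 5)  # batch of 4 rows, 5 features to match Linear(5, 3)
total = fx_profile(net, data, verbose=True)  # prints per-node details when verbose
print(f"total flops: {total}")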
4 changes: 4 additions & 0 deletions thop/onnx_profile.py
@@ -9,9 +9,11 @@

class OnnxProfile:
def __init__(self):
"""Initialize the OnnxProfile class with necessary imports for ONNX profiling."""
pass

def calculate_params(self, model: onnx.ModelProto):
"""Calculate the total number of parameters in an ONNX model."""
onnx_weights = model.graph.initializer
params = 0

@@ -25,6 +27,7 @@ def calculate_params(self, model: onnx.ModelProto):
return params

def create_dict(self, weight, input, output):
"""Create and return a dictionary mapping weight, input, and output names to their respective dimensions."""
diction = {}
for w in weight:
dim = np.array(w.dims)
@@ -52,6 +55,7 @@ def create_dict(self, weight, input, output):
return diction

def nodes_counter(self, diction, node):
"""Count nodes of a specific type in an ONNX graph, returning the count and associated node operation details."""
if node.op_type not in onnx_operators:
print("Sorry, we haven't add ", node.op_type, "into dictionary.")
return 0, None, None
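A usage sketch for calculate_params (the signature comes from the hunk above; the export step and the expected count are illustrative assumptions):

import onnx
import torch

# Export a small layer to ONNX, then sum its initializer sizes.
torch.onnx.export(torch.nn.Linear(100, 200), torch.randn(1, 100), "linear.onnx")
model = onnx.load("linear.onnx")
print(OnnxProfile().calculate_params(model))  # 100 * 200 + 200 = 20200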
3 changes: 3 additions & 0 deletions thop/profile.py
@@ -66,6 +66,7 @@


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total operations and parameters."""
handler_collection = []
types_collection = set()
if custom_ops is None:
@@ -162,6 +163,7 @@ def profile(
verbose = True

def add_hooks(m: nn.Module):
"""Registers hooks to a neural network module to track total operations and parameters."""
m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

@@ -200,6 +202,7 @@ def add_hooks(m: nn.Module):
model(*inputs)

def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
"""Recursively counts the total operations and parameters of the given PyTorch module and its submodules."""
total_ops, total_params = module.total_ops.item(), 0
ret_dict = {}
for n, m in module.named_children():
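For reference, the canonical end-to-end pattern these hooks support, matching the project README (torchvision is only needed for the example model):

import torch
from torchvision.models import resnet50
from thop import clever_format, profile

model = resnet50()
dummy = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy,))
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)  # on the order of '4.1G' MACs and '25.6M' params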
10 changes: 9 additions & 1 deletion thop/rnn_hooks.py
@@ -4,7 +4,7 @@


def _count_rnn_cell(input_size, hidden_size, bias=True):
- # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
+ """Calculate the total operations for an RNN cell based on input size, hidden size, and bias configuration."""
total_ops = hidden_size * (input_size + hidden_size) + hidden_size
if bias:
total_ops += hidden_size * 2
@@ -13,6 +13,7 @@ def _count_rnn_cell(input_size, hidden_size, bias=True):


def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
"""Counts RNN cell operations based on input, hidden size, bias, and batch size."""
total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
@@ -22,6 +23,7 @@ def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):


def _count_gru_cell(input_size, hidden_size, bias=True):
"""Counts the total operations for a GRU cell based on input size, hidden size, and bias."""
total_ops = 0
# r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
# z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
@@ -45,6 +47,7 @@ def _count_gru_cell(input_size, hidden_size, bias=True):


def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
"""Calculates and updates the total operations for a GRU cell in a mini-batch during inference."""
total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
@@ -54,6 +57,7 @@ def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):


def _count_lstm_cell(input_size, hidden_size, bias=True):
"""Calculates the total operations for an LSTM cell during inference given input size, hidden size, and optional bias."""
total_ops = 0

# i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
@@ -76,6 +80,7 @@ def _count_lstm_cell(input_size, hidden_size, bias=True):


def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
"""Count the number of operations for a single LSTM cell in a given batch, updating the model's total operations count."""
total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
@@ -85,6 +90,7 @@ def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):


def count_rnn(m: nn.RNN, x, y):
"""Calculate and update the total number of operations for a single RNN cell in a given batch."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
@@ -122,6 +128,7 @@ def count_rnn(m: nn.RNN, x, y):


def count_gru(m: nn.GRU, x, y):
"""Calculate the total number of operations for a GRU layer in a neural network model."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
@@ -159,6 +166,7 @@ def count_gru(m: nn.GRU, x, y):


def count_lstm(m: nn.LSTM, x, y):
"""Calculate the total operations for LSTM layers in a network, accounting for input size, hidden size, bias, and bidirectionality."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
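Working the _count_rnn_cell arithmetic above by hand for one configuration (the count is per timestep and per batch element; the count_* wrappers then scale by batch size and, for full layers, sequence length and directions):

input_size, hidden_size, bias = 10, 20, True
ops = hidden_size * (input_size + hidden_size) + hidden_size  # matvec MACs plus elementwise tanh: 620
if bias:
    ops += hidden_size * 2  # b_ih and b_hh additions: +40
assert ops == 660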
2 changes: 2 additions & 0 deletions thop/utils.py
@@ -6,6 +6,7 @@


def colorful_print(fn_print, color=COLOR_RED):
"""A decorator to print text in the specified terminal color by wrapping the given print function."""
def actual_call(*args, **kwargs):
print(f"\033[{color}", end="")
fn_print(*args, **kwargs)
@@ -29,6 +30,7 @@ def actual_call(*args, **kwargs):


def clever_format(nums, format="%.2f"):
"""Formats numerical values into a more readable string with units (K, M, G, T) based on their magnitude."""
if not isinstance(nums, Iterable):
nums = [nums]
clever_nums = []
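A self-contained version of the decorator shown above (the color code and the reset line are assumptions; the hunk shows only the color-on portion, and thop defines its COLOR_RED constant elsewhere in utils.py):

COLOR_RED = "91m"  # hypothetical ANSI code; thop's own constant may differ

def colorful_print(fn_print, color=COLOR_RED):
    """Wrap fn_print so everything it emits is rendered in the given color."""
    def actual_call(*args, **kwargs):
        print(f"\033[{color}", end="")  # switch the terminal color on
        fn_print(*args, **kwargs)       # delegate to the wrapped print function
        print("\033[0m", end="")        # reset to the default color
    return actual_call

red_print = colorful_print(print)
red_print("rendered in red on an ANSI terminal")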