Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Python function docstrings #9

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ jobs:
v_local = tuple(map(int, pyproject_version.split('.')))

# Compare with version on PyPI
v_pypi = (0, 0, 0) # tuple(map(int, check_latest_pypi_version('ultralytics-thop').split('.')))
v_pypi = tuple(map(int, check_latest_pypi_version('ultralytics-thop').split('.')))
print(f'Local version is {v_local}')
print(f'PyPI version is {v_pypi}')
d = [a - b for a, b in zip(v_local, v_pypi)] # diff
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "ultralytics-thop"
version = "0.0.1" # Placeholder version, needs to be dynamically set
version = "0.0.2" # Placeholder version, needs to be dynamically set
description = "A tool to count the FLOPs of PyTorch model."
readme = "README.md"
requires-python = ">=3.8"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

class TestUtils:
def test_conv2d_no_bias(self):
"""Tests a 2D convolutional layer without bias using THOP profiling with predefined input dimensions and
convolution parameters.
"""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
Expand All @@ -22,6 +25,7 @@ def test_conv2d_no_bias(self):
assert flops == 810000, f"{flops} v.s. {810000}"

def test_conv2d(self):
"""Tests a Conv2D layer with specific input dimensions, kernel size, stride, padding, dilation, and groups."""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
Expand All @@ -36,6 +40,7 @@ def test_conv2d(self):
assert flops == 810000, f"{flops} v.s. {810000}"

def test_conv2d_random(self):
"""Test Conv2D layer with random parameters and validate the computed FLOPs and parameters using 'profile'."""
for i in range(10):
out_c, kh, kw = torch.randint(1, 20, (3,)).tolist()
n, in_c, ih, iw = torch.randint(1, 20, (4,)).tolist() # torch.randint(1, 10, (4,)).tolist()
Expand Down
3 changes: 3 additions & 0 deletions tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

class TestUtils:
def test_matmul_case2(self):
"""Test matrix multiplication case asserting the FLOPs and parameters of a nn.Linear layer."""
n, in_c, out_c = 1, 100, 200
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == n * in_c * out_c

def test_matmul_case2(self):
"""Tests matrix multiplication to assert FLOPs and parameters of nn.Linear layer using random dimensions."""
for i in range(10):
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
Expand All @@ -22,6 +24,7 @@ def test_matmul_case2(self):
assert flops == n * in_c * out_c

def test_conv2d(self):
"""Tests the number of FLOPs and parameters for a randomly initialized nn.Linear layer using torch.profiler."""
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
Expand Down
3 changes: 3 additions & 0 deletions tests/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

class TestUtils:
def test_relu(self):
"""Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP
profiling.
"""
n, in_c, out_c = 1, 100, 200
data = torch.randn(n, in_c)
net = nn.ReLU()
Expand Down
4 changes: 4 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@

class TestUtils:
def test_clever_format_returns_formatted_number(self):
"""Tests that the clever_format function returns a formatted number string with a '1.00B' pattern."""
nums = 1
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
assert clever_nums == "1.00B"

def test_clever_format_returns_formatted_numbers(self):
"""Tests that the clever_format function correctly formats a list of numbers as strings with a '1.00B'
pattern.
"""
nums = [1, 2]
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
Expand Down
19 changes: 18 additions & 1 deletion thop/fx_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,17 @@


def count_clamp(input_shapes, output_shapes):
"""Ensures proper array sizes for tensors by clamping input and output shapes."""
return 0


def count_mul(input_shapes, output_shapes):
# element-wise
"""Returns the number of elements in the first output shape."""
return output_shapes[0].numel()


def count_matmul(input_shapes, output_shapes):
"""Calculates the total number of operations for a matrix multiplication given input and output shapes."""
in_shape = input_shapes[0]
out_shape = output_shapes[0]
in_features = in_shape[-1]
Expand All @@ -30,6 +32,7 @@ def count_matmul(input_shapes, output_shapes):


def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a linear layer given input and output shapes."""
mul_flops = count_matmul(input_shapes, output_shapes)
if "bias" in kwargs:
add_flops = output_shapes[0].numel()
Expand All @@ -40,6 +43,7 @@ def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a 2D convolutional layer given input and output shapes."""
inputs, weight, bias, stride, padding, dilation, groups = args
if len(input_shapes) == 2:
x_shape, k_shape = input_shapes
Expand All @@ -56,14 +60,17 @@ def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):


def count_nn_linear(module: nn.Module, input_shapes, output_shapes):
"""Counts the FLOPs for a fully connected (linear) layer in a neural network module."""
return count_matmul(input_shapes, output_shapes)


def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
"""Returns 0 for the given neural network module, input shapes, and output shapes."""
return 0


def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
"""Calculates total operations for a 2D convolutional neural network layer in a given neural network module."""
bias_op = 1 if module.bias is not None else 0
out_shape = output_shapes[0]

Expand All @@ -75,6 +82,7 @@ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):


def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
"""Calculate the total operations for a given nn.BatchNorm2d module based on its output shape."""
assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
y = output_shapes[0]
# y = (x - mean) / \sqrt{var + e} * weight + bias
Expand Down Expand Up @@ -116,10 +124,14 @@ def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):


def null_print(*args, **kwargs):
"""A no-op print function that takes any arguments without performing any actions."""
return


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
"""Profiles the given torch.nn Module to calculate total FLOPs for each operation and prints detailed node
information if verbose.
"""
gm: torch.fx.GraphModule = symbolic_trace(mod)
g = gm.graph
ShapeProp(gm).propagate(input)
Expand Down Expand Up @@ -204,16 +216,21 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):

class MyOP(nn.Module):
def forward(self, input):
"""Performs forward pass on given input data."""
return input / 1

class MyModule(torch.nn.Module):
def __init__(self):
"""Initializes MyModule with two linear layers and a custom MyOP operator."""
super().__init__()
self.linear1 = torch.nn.Linear(5, 3)
self.linear2 = torch.nn.Linear(5, 3)
self.myop = MyOP()

def forward(self, x):
"""Applies two linear transformations to the input tensor, clamps the second, then combines and processes
with MyOP operator.
"""
out1 = self.linear1(x)
out2 = self.linear2(x).clamp(min=0.0, max=1.0)
return self.myop(out1 + out2)
Expand Down
6 changes: 6 additions & 0 deletions thop/onnx_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@

class OnnxProfile:
def __init__(self):
"""Initialize the OnnxProfile class with necessary imports for ONNX profiling."""
pass

def calculate_params(self, model: onnx.ModelProto):
"""Calculate the total number of parameters in an ONNX model."""
onnx_weights = model.graph.initializer
params = 0

Expand All @@ -25,6 +27,7 @@ def calculate_params(self, model: onnx.ModelProto):
return params

def create_dict(self, weight, input, output):
"""Create and return a dictionary mapping weight, input, and output names to their respective dimensions."""
diction = {}
for w in weight:
dim = np.array(w.dims)
Expand Down Expand Up @@ -52,6 +55,9 @@ def create_dict(self, weight, input, output):
return diction

def nodes_counter(self, diction, node):
"""Count nodes of a specific type in an ONNX graph, returning the count and associated node operation
details.
"""
if node.op_type not in onnx_operators:
print("Sorry, we haven't add ", node.op_type, "into dictionary.")
return 0, None, None
Expand Down
5 changes: 5 additions & 0 deletions thop/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total
operations and parameters.
"""
handler_collection = []
types_collection = set()
if custom_ops is None:
Expand Down Expand Up @@ -162,6 +165,7 @@ def profile(
verbose = True

def add_hooks(m: nn.Module):
"""Registers hooks to a neural network module to track total operations and parameters."""
m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))

Expand Down Expand Up @@ -200,6 +204,7 @@ def add_hooks(m: nn.Module):
model(*inputs)

def dfs_count(module: nn.Module, prefix="\t") -> (int, int):
"""Recursively counts the total operations and parameters of the given PyTorch module and its submodules."""
total_ops, total_params = module.total_ops.item(), 0
ret_dict = {}
for n, m in module.named_children():
Expand Down
16 changes: 15 additions & 1 deletion thop/rnn_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def _count_rnn_cell(input_size, hidden_size, bias=True):
# h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh})
"""Calculate the total operations for an RNN cell based on input size, hidden size, and bias configuration."""
total_ops = hidden_size * (input_size + hidden_size) + hidden_size
if bias:
total_ops += hidden_size * 2
Expand All @@ -13,6 +13,7 @@ def _count_rnn_cell(input_size, hidden_size, bias=True):


def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
"""Counts RNN cell operations based on input, hidden size, bias, and batch size."""
total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
Expand All @@ -22,6 +23,7 @@ def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):


def _count_gru_cell(input_size, hidden_size, bias=True):
"""Counts the total operations for a GRU cell based on input size, hidden size, and bias."""
total_ops = 0
# r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
# z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
Expand All @@ -45,6 +47,7 @@ def _count_gru_cell(input_size, hidden_size, bias=True):


def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):
"""Calculates and updates the total operations for a GRU cell in a mini-batch during inference."""
total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
Expand All @@ -54,6 +57,9 @@ def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):


def _count_lstm_cell(input_size, hidden_size, bias=True):
"""Calculates the total operations for an LSTM cell during inference given input size, hidden size, and optional
bias.
"""
total_ops = 0

# i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
Expand All @@ -76,6 +82,9 @@ def _count_lstm_cell(input_size, hidden_size, bias=True):


def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
"""Count the number of operations for a single LSTM cell in a given batch, updating the model's total operations
count.
"""
total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
Expand All @@ -85,6 +94,7 @@ def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):


def count_rnn(m: nn.RNN, x, y):
"""Calculate and update the total number of operations for a single RNN cell in a given batch."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
Expand Down Expand Up @@ -122,6 +132,7 @@ def count_rnn(m: nn.RNN, x, y):


def count_gru(m: nn.GRU, x, y):
"""Calculate the total number of operations for a GRU layer in a neural network model."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
Expand Down Expand Up @@ -159,6 +170,9 @@ def count_gru(m: nn.GRU, x, y):


def count_lstm(m: nn.LSTM, x, y):
"""Calculate the total operations for LSTM layers in a network, accounting for input size, hidden size, bias, and
bidirectionality.
"""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
Expand Down
3 changes: 3 additions & 0 deletions thop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@


def colorful_print(fn_print, color=COLOR_RED):
"""A decorator to print text in the specified terminal color by wrapping the given print function."""

def actual_call(*args, **kwargs):
print(f"\033[{color}", end="")
fn_print(*args, **kwargs)
Expand All @@ -29,6 +31,7 @@ def actual_call(*args, **kwargs):


def clever_format(nums, format="%.2f"):
"""Formats numerical values into a more readable string with units (K, M, G, T) based on their magnitude."""
if not isinstance(nums, Iterable):
nums = [nums]
clever_nums = []
Expand Down
Loading