Refactor code for speed and clarity

glenn-jocher committed Jun 30, 2024
1 parent eafdca3 commit c8ac820
Showing 12 changed files with 56 additions and 83 deletions.
3 changes: 1 addition & 2 deletions benchmark/evaluate_famous_models.py
@@ -1,7 +1,6 @@
import torch
-from torchvision import models
-
from thop.profile import profile
+from torchvision import models

model_names = sorted(
name
1 change: 0 additions & 1 deletion benchmark/evaluate_rnn_models.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
-
from thop.profile import profile

input_size = 160
7 changes: 2 additions & 5 deletions tests/test_conv2d.py
@@ -1,14 +1,11 @@
import torch
import torch.nn as nn
-
from thop import profile


class TestUtils:
def test_conv2d_no_bias(self):
"""Tests a 2D convolutional layer without bias using THOP profiling with predefined input dimensions and
convolution parameters.
"""
"""Tests a 2D Conv layer without bias using THOP profiling on predefined input dimensions and parameters."""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
@@ -23,7 +20,7 @@ def test_conv2d_no_bias(self):
assert flops == 810000, f"{flops} v.s. 810000"

def test_conv2d(self):
"""Tests a Conv2D layer with specific input dimensions, kernel size, stride, padding, dilation, and groups."""
"""Tests Conv2D layer with bias, profiling FLOPs and params for specific input dimensions and layer configs."""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
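For reference, the 810000 expected by test_conv2d_no_bias follows directly from the convolution arithmetic. A minimal standalone sketch of that check, assuming thop is installed (configuration values copied from the test):

```python
import torch
import torch.nn as nn
from thop import profile

# Same configuration as test_conv2d_no_bias above.
n, in_c, ih, iw = 1, 3, 32, 32
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1

net = nn.Conv2d(in_c, out_c, (kh, kw), stride=s, padding=p, dilation=d, groups=g, bias=False)
flops, params = profile(net, inputs=(torch.randn(n, in_c, ih, iw),))

# Output spatial size: (32 + 2*1 - 5) / 1 + 1 = 30, so y is 1 x 12 x 30 x 30.
# One MAC per kernel element per output element:
# (12 * 30 * 30) * (3 * 5 * 5) = 10800 * 75 = 810000.
assert flops == 810000
```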
7 changes: 3 additions & 4 deletions tests/test_matmul.py
@@ -1,20 +1,19 @@
import torch
import torch.nn as nn
-
from thop import profile


class TestUtils:
def test_matmul_case2(self):
"""Test matrix multiplication case asserting the FLOPs and parameters of a nn.Linear layer."""
"""Test matrix multiplication case by profiling FLOPs and parameters of a PyTorch nn.Linear layer."""
n, in_c, out_c = 1, 100, 200
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
assert flops == n * in_c * out_c

def test_matmul_case2(self):
"""Tests matrix multiplication to assert FLOPs and parameters of nn.Linear layer using random dimensions."""
"""Tests matrix multiplication to profile FLOPs and parameters of nn.Linear layer using random dimensions."""
for _ in range(10):
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
@@ -23,7 +22,7 @@ def test_matmul_case2(self):
assert flops == n * in_c * out_c

def test_conv2d(self):
"""Tests the number of FLOPs and parameters for a randomly initialized nn.Linear layer using torch.profiler."""
"""Tests FLOPs and parameters for a nn.Linear layer with random dimensions using torch.profiler."""
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
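The invariant these matmul tests assert is that THOP counts one MAC per input feature per output element for nn.Linear. A small standalone sketch of that rule, assuming thop is installed (the batch size 4 here is illustrative):

```python
import torch
import torch.nn as nn
from thop import profile

n, in_c, out_c = 4, 100, 200
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))

assert flops == n * in_c * out_c       # 4 * 100 * 200 = 80000 MACs
assert params == in_c * out_c + out_c  # weight matrix plus bias vector
```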
5 changes: 1 addition & 4 deletions tests/test_relu.py
@@ -1,14 +1,11 @@
import torch
import torch.nn as nn
-
from thop import profile


class TestUtils:
def test_relu(self):
"""Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP
profiling.
"""
"""Tests ReLU activation ensuring zero FLOPs and displays parameter count using THOP profiling."""
n, in_c, _out_c = 1, 100, 200
net = nn.ReLU()
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
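A quick standalone check mirroring test_relu, assuming thop is installed; activations are handled by THOP's zero-op hook, so both counts come back zero:

```python
import torch
import torch.nn as nn
from thop import profile

net = nn.ReLU()
flops, params = profile(net, inputs=(torch.randn(1, 100),))
assert flops == 0 and params == 0  # no MACs, no learnable parameters
```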
6 changes: 2 additions & 4 deletions tests/test_utils.py
@@ -3,16 +3,14 @@

class TestUtils:
def test_clever_format_returns_formatted_number(self):
"""Tests that the clever_format function returns a formatted number string with a '1.00B' pattern."""
"""Tests that clever_format returns a string like '1.00B' for the given number and format pattern."""
nums = 1
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
assert clever_nums == "1.00B"

def test_clever_format_returns_formatted_numbers(self):
"""Tests that the clever_format function correctly formats a list of numbers as strings with a '1.00B'
pattern.
"""
"""Verifies clever_format formats a list of numbers to strings, e.g., '[1, 2]' to '[1.00B, 2.00B]'."""
nums = [1, 2]
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
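For context, a hedged usage sketch of the clever_format helper these tests exercise, following the pattern in the THOP README (the exact container type returned for list inputs may vary between versions):

```python
from thop import clever_format

print(clever_format(1, "%.2f"))          # "1.00B" -- the base suffix asserted above
print(clever_format(2_500_000, "%.2f"))  # "2.50M"
print(clever_format([1, 2], "%.2f"))     # formats each element: "1.00B", "2.00B"
```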
18 changes: 8 additions & 10 deletions thop/fx_profile.py
@@ -13,7 +13,7 @@


def count_clamp(input_shapes, output_shapes):
"""Ensures proper array sizes for tensors by clamping input and output shapes."""
"""Ensures tensor array sizes are appropriate by clamping specified input and output shapes."""
return 0


@@ -23,7 +23,7 @@ def count_mul(input_shapes, output_shapes):


def count_matmul(input_shapes, output_shapes):
"""Calculates the total number of operations for a matrix multiplication given input and output shapes."""
"""Calculates matrix multiplication ops based on input and output tensor shapes for performance profiling."""
in_shape = input_shapes[0]
out_shape = output_shapes[0]
in_features = in_shape[-1]
@@ -32,7 +32,7 @@ def count_matmul(input_shapes, output_shapes):


def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a linear layer given input and output shapes."""
"""Calculates the total FLOPs for a linear layer, including bias operations if specified."""
flops = count_matmul(input_shapes, output_shapes)
if "bias" in kwargs:
flops += output_shapes[0].numel()
@@ -43,7 +43,7 @@ def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a 2D convolutional layer given input and output shapes."""
"""Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using `calculate_conv`."""
inputs, weight, bias, stride, padding, dilation, groups = args
if len(input_shapes) == 2:
x_shape, k_shape = input_shapes
@@ -65,12 +65,12 @@ def count_nn_linear(module: nn.Module, input_shapes, output_shapes):


def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs):
"""Returns 0 for the given neural network module, input shapes, and output shapes."""
"""Returns 0 for a neural network module, input shapes, and output shapes in PyTorch."""
return 0


def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):
"""Calculates total operations for a 2D convolutional neural network layer in a given neural network module."""
"""Calculates FLOPs for a 2D Conv2D layer in an nn.Module using input and output shapes."""
bias_op = 1 if module.bias is not None else 0
out_shape = output_shapes[0]

@@ -82,7 +82,7 @@ def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes):


def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes):
"""Calculate the total operations for a given nn.BatchNorm2d module based on its output shape."""
"""Calculate FLOPs for an nn.BatchNorm2d layer based on the given output shape."""
assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output"
y = output_shapes[0]
return 2 * y.numel()
@@ -127,9 +127,7 @@ def null_print(*args, **kwargs):


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
"""Profiles the given torch.nn Module to calculate total FLOPs for each operation and prints detailed node
information if verbose.
"""
"""Profiles nn.Module for total FLOPs per operation and prints detailed nodes if verbose."""
gm: torch.fx.GraphModule = symbolic_trace(mod)
ShapeProp(gm).propagate(input)

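The count_matmul hunk is truncated after reading in_features from the last input dimension; THOP's rule multiplies that by the number of output elements. A self-contained sketch of the rule, with the final multiply inferred rather than shown in the hunk:

```python
import torch

def count_matmul(input_shapes, output_shapes):
    # in_features from the last input dim; one MAC per feature per output element
    in_features = input_shapes[0][-1]
    num_elements = output_shapes[0].numel()
    return in_features * num_elements

# (8, 100) @ (100, 200) -> (8, 200): 100 MACs for each of the 1600 outputs.
assert count_matmul([torch.Size([8, 100])], [torch.Size([8, 200])]) == 160000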
6 changes: 2 additions & 4 deletions thop/profile.py
@@ -55,9 +55,7 @@


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total
operations and parameters.
"""
"""Profiles a PyTorch model's operations and parameters, applying either custom or default hooks."""
handler_collection = []
types_collection = set()
if custom_ops is None:
@@ -145,7 +143,7 @@ def profile(
ret_layer_info=False,
report_missing=False,
):
"""Profiles a PyTorch model, returning total operations and parameters, with optional layer-wise details."""
"""Profiles a PyTorch model, returning total operations, parameters, and optionally layer-wise details."""
handler_collection = {}
types_collection = set()
if custom_ops is None:
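For orientation, the canonical end-to-end usage of profile together with clever_format, following the THOP README:

```python
import torch
from torchvision.models import resnet18
from thop import clever_format, profile

model = resnet18()
dummy = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(dummy,), verbose=False)
macs, params = clever_format([macs, params], "%.3f")
print(macs, params)  # roughly 1.8G MACs and 11.7M params for resnet18
```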
22 changes: 8 additions & 14 deletions thop/rnn_hooks.py
@@ -4,7 +4,7 @@


def _count_rnn_cell(input_size, hidden_size, bias=True):
"""Calculate the total operations for an RNN cell based on input size, hidden size, and bias configuration."""
"""Calculate the total operations for an RNN cell given input size, hidden size, and optional bias."""
total_ops = hidden_size * (input_size + hidden_size) + hidden_size
if bias:
total_ops += hidden_size * 2
@@ -13,7 +13,7 @@ def _count_rnn_cell(input_size, hidden_size, bias=True):


def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):
"""Counts RNN cell operations based on input, hidden size, bias, and batch size."""
"""Counts the total RNN cell operations based on input tensor, hidden size, bias, and batch size."""
total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
@@ -23,7 +23,7 @@ def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor):


def _count_gru_cell(input_size, hidden_size, bias=True):
"""Counts the total operations for a GRU cell based on input size, hidden size, and bias."""
"""Counts the total operations for a GRU cell based on input size, hidden size, and bias configuration."""
total_ops = 0
# r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\
# z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\
@@ -57,9 +57,7 @@ def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):


def _count_lstm_cell(input_size, hidden_size, bias=True):
"""Calculates the total operations for an LSTM cell during inference given input size, hidden size, and optional
bias.
"""
"""Counts LSTM cell operations during inference based on input size, hidden size, and bias configuration."""
total_ops = 0

# i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
@@ -82,9 +80,7 @@ def _count_lstm_cell(input_size, hidden_size, bias=True):


def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
"""Count the number of operations for a single LSTM cell in a given batch, updating the model's total operations
count.
"""
"""Counts and updates the total operations for an LSTM cell in a mini-batch during inference."""
total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
@@ -94,7 +90,7 @@ def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):


def count_rnn(m: nn.RNN, x, y):
"""Calculate and update the total number of operations for a single RNN cell in a given batch."""
"""Calculate and update the total number of operations for each RNN cell in a given batch."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
@@ -131,7 +127,7 @@ def count_rnn(m: nn.RNN, x, y):


def count_gru(m: nn.GRU, x, y):
"""Calculate the total number of operations for a GRU layer in a neural network model."""
"""Calculates total operations for a GRU layer, updating the model's operation count based on batch size."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
@@ -168,9 +164,7 @@ def count_lstm(m: nn.LSTM, x, y):


def count_lstm(m: nn.LSTM, x, y):
"""Calculate the total operations for LSTM layers in a network, accounting for input size, hidden size, bias, and
bidirectionality.
"""
"""Calculate total operations for LSTM layers, including bidirectional, updating model's total operations."""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
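The _count_rnn_cell hunk above shows its per-timestep formula in full. A worked example using input_size = 160 from the RNN benchmark earlier; the hidden_size of 512 is illustrative, not taken from the source:

```python
def _count_rnn_cell(input_size, hidden_size, bias=True):
    # h' = tanh(W_ih x + b_ih + W_hh h + b_hh): one matvec over [x; h] plus tanh
    total_ops = hidden_size * (input_size + hidden_size) + hidden_size
    if bias:
        total_ops += hidden_size * 2  # the two bias additions
    return total_ops

# 512 * (160 + 512) + 512 + 2 * 512 = 345600 ops per timestep per sequence.
assert _count_rnn_cell(160, 512) == 345600
```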
2 changes: 1 addition & 1 deletion thop/utils.py
@@ -22,7 +22,7 @@ def actual_call(*args, **kwargs):


def clever_format(nums, format="%.2f"):
"""Formats numerical values into a more readable string with units (K, M, G, T) based on their magnitude."""
"""Formats numbers into human-readable strings with units (K for thousand, M for million, etc.)."""
if not isinstance(nums, Iterable):
nums = [nums]
clever_nums = []
26 changes: 11 additions & 15 deletions thop/vision/basic_hooks.py
@@ -9,17 +9,17 @@


def count_parameters(m, x, y):
"""Calculate and update the total number of parameters in a given PyTorch model."""
"""Calculate and return the total number of learnable parameters in a given PyTorch model."""
m.total_params[0] = calculate_parameters(m.parameters())


def zero_ops(m, x, y):
"""Incrementally add the number of zero operations to the model's total operations count."""
"""Incrementally add zero operations to the model's total operations count."""
m.total_ops += calculate_zero_ops()


def count_convNd(m: _ConvNd, x, y: torch.Tensor):
"""Calculate and add the number of convolutional operations (FLOPs) to the model's total operations count."""
"""Calculate and add the number of convolutional operations (FLOPs) for a ConvNd layer to the model's total ops."""
x = x[0]

m.total_ops += calculate_conv2d_flops(
@@ -40,7 +40,7 @@ def count_convNd(m: _ConvNd, x, y: torch.Tensor):


def count_convNd_ver2(m: _ConvNd, x, y: torch.Tensor):
"""Calculates the total operations for a convolutional layer and updates the layer's total_ops attribute."""
"""Calculates and updates total operations (FLOPs) for a convolutional layer in a PyTorch model."""
x = x[0]

# N x H x W (exclude Cout)
@@ -56,9 +56,7 @@ def count_convNd_ver2(m: _ConvNd, x, y: torch.Tensor):


def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
"""Calculate and add the FLOPs for a batch normalization layer, considering elementwise operations and possible
affine parameters.
"""
"""Calculate and add the FLOPs for a batch normalization layer, including elementwise and affine operations."""
# https://github.com/Lyken17/pytorch-OpCounter/issues/124
# y = (x - mean) / sqrt(eps + var) * weight + bias
x = x[0]
@@ -80,7 +78,7 @@ def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):


def count_prelu(m, x, y):
"""Calculate and update the total operation counts for a PReLU layer."""
"""Calculate and update the total operation counts for a PReLU layer using input element number."""
x = x[0]

nelements = x.numel()
@@ -95,7 +93,7 @@ def count_prelu(m, x, y):


def count_softmax(m, x, y):
"""Calculate and update the total operation counts for a Softmax layer."""
"""Calculate and update the total operation counts for a Softmax layer in a PyTorch model."""
x = x[0]
nfeatures = x.size()[m.dim]
batch_size = x.numel() // nfeatures
@@ -104,15 +102,15 @@ def count_softmax(m, x, y):


def count_avgpool(m, x, y):
"""Calculate and update the total operation counts for an AvgPool layer."""
"""Calculate and update the total number of operations (FLOPs) for an AvgPool layer based on the output elements."""
# total_div = 1
# kernel_ops = total_add + total_div
num_elements = y.numel()
m.total_ops += calculate_avgpool(num_elements)


def count_adap_avgpool(m, x, y):
"""Calculate and update the total operation counts for an AdaptiveAvgPool layer."""
"""Calculate and update the total operation counts for an AdaptiveAvgPool layer using kernel and element counts."""
kernel = torch.div(torch.DoubleTensor([*(x[0].shape[2:])]), torch.DoubleTensor([*(y.shape[2:])]))
total_add = torch.prod(kernel)
num_elements = y.numel()
@@ -121,7 +119,7 @@ def count_adap_avgpool(m, x, y):

# TODO: verify the accuracy
def count_upsample(m, x, y):
"""Update the total operations counter in the given module for supported upsampling modes."""
"""Update total operations counter for upsampling layers based on the mode used."""
if m.mode not in (
"nearest",
"linear",
@@ -137,9 +135,7 @@ def count_upsample(m, x, y):

# nn.Linear
def count_linear(m, x, y):
"""Counts total operations for nn.Linear layers by calculating multiplications and additions based on input and
output elements.
"""
"""Counts total operations for nn.Linear layers using input and output element dimensions."""
total_mul = m.in_features
# total_add = m.in_features - 1
# total_add += 1 if m.bias is not None else 0
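The count_normalization hunk cites the formula y = (x - mean) / sqrt(eps + var) * weight + bias. A minimal sketch of the counting rule it implements: two ops per element to normalize, doubled when affine parameters are present (the doubling is an assumption consistent with the visible hook, not shown in the hunk):

```python
import torch.nn as nn

def batchnorm_flops(m: nn.modules.batchnorm._BatchNorm, num_elements: int) -> int:
    flops = 2 * num_elements  # subtract mean, divide by std
    if m.affine:
        flops *= 2            # scale by weight, then add bias
    return flops

bn = nn.BatchNorm2d(64)                       # affine=True by default
print(batchnorm_flops(bn, 1 * 64 * 56 * 56))  # 200704 elements -> 4 ops each = 802816
```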