Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
UltralyticsAssistant committed May 31, 2024
1 parent f1a42d7 commit 451f0f7
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 16 deletions.
4 changes: 3 additions & 1 deletion tests/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

class TestUtils:
def test_conv2d_no_bias(self):
"""Tests a 2D convolutional layer without bias using THOP profiling with predefined input dimensions and convolution parameters."""
"""Tests a 2D convolutional layer without bias using THOP profiling with predefined input dimensions and
convolution parameters.
"""
n, in_c, ih, iw = 1, 3, 32, 32 # torch.randint(1, 10, (4,)).tolist()
out_c, kh, kw = 12, 5, 5
s, p, d, g = 1, 1, 1, 1
Expand Down
4 changes: 3 additions & 1 deletion tests/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

class TestUtils:
def test_relu(self):
"""Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP profiling."""
"""Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP
profiling.
"""
n, in_c, out_c = 1, 100, 200
data = torch.randn(n, in_c)
net = nn.ReLU()
Expand Down
4 changes: 3 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ def test_clever_format_returns_formatted_number(self):
assert clever_nums == "1.00B"

def test_clever_format_returns_formatted_numbers(self):
"""Tests that the clever_format function correctly formats a list of numbers as strings with a '1.00B' pattern."""
"""Tests that the clever_format function correctly formats a list of numbers as strings with a '1.00B'
pattern.
"""
nums = [1, 2]
format = "%.2f"
clever_nums = utils.clever_format(nums, format)
Expand Down
8 changes: 6 additions & 2 deletions thop/fx_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def null_print(*args, **kwargs):


def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
"""Profiles the given torch.nn Module to calculate total FLOPs for each operation and prints detailed node information if verbose."""
"""Profiles the given torch.nn Module to calculate total FLOPs for each operation and prints detailed node
information if verbose.
"""
gm: torch.fx.GraphModule = symbolic_trace(mod)
g = gm.graph
ShapeProp(gm).propagate(input)
Expand Down Expand Up @@ -226,7 +228,9 @@ def __init__(self):
self.myop = MyOP()

def forward(self, x):
"""Applies two linear transformations to the input tensor, clamps the second, then combines and processes with MyOP operator."""
"""Applies two linear transformations to the input tensor, clamps the second, then combines and processes
with MyOP operator.
"""
out1 = self.linear1(x)
out2 = self.linear2(x).clamp(min=0.0, max=1.0)
return self.myop(out1 + out2)
Expand Down
4 changes: 3 additions & 1 deletion thop/onnx_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def create_dict(self, weight, input, output):
return diction

def nodes_counter(self, diction, node):
"""Count nodes of a specific type in an ONNX graph, returning the count and associated node operation details."""
"""Count nodes of a specific type in an ONNX graph, returning the count and associated node operation
details.
"""
if node.op_type not in onnx_operators:
print("Sorry, we haven't add ", node.op_type, "into dictionary.")
return 0, None, None
Expand Down
4 changes: 3 additions & 1 deletion thop/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total operations and parameters."""
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total
operations and parameters.
"""
handler_collection = []
types_collection = set()
if custom_ops is None:
Expand Down
12 changes: 9 additions & 3 deletions thop/rnn_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor):


def _count_lstm_cell(input_size, hidden_size, bias=True):
"""Calculates the total operations for an LSTM cell during inference given input size, hidden size, and optional bias."""
"""Calculates the total operations for an LSTM cell during inference given input size, hidden size, and optional
bias.
"""
total_ops = 0

# i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\
Expand All @@ -80,7 +82,9 @@ def _count_lstm_cell(input_size, hidden_size, bias=True):


def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor):
"""Count the number of operations for a single LSTM cell in a given batch, updating the model's total operations count."""
"""Count the number of operations for a single LSTM cell in a given batch, updating the model's total operations
count.
"""
total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias)

batch_size = x[0].size(0)
Expand Down Expand Up @@ -166,7 +170,9 @@ def count_gru(m: nn.GRU, x, y):


def count_lstm(m: nn.LSTM, x, y):
"""Calculate the total operations for LSTM layers in a network, accounting for input size, hidden size, bias, and bidirectionality."""
"""Calculate the total operations for LSTM layers in a network, accounting for input size, hidden size, bias, and
bidirectionality.
"""
bias = m.bias
input_size = m.input_size
hidden_size = m.hidden_size
Expand Down
1 change: 1 addition & 0 deletions thop/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def colorful_print(fn_print, color=COLOR_RED):
"""A decorator to print text in the specified terminal color by wrapping the given print function."""

def actual_call(*args, **kwargs):
print(f"\033[{color}", end="")
fn_print(*args, **kwargs)
Expand Down
8 changes: 6 additions & 2 deletions thop/vision/basic_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def count_convNd_ver2(m: _ConvNd, x, y: torch.Tensor):


def count_normalization(m: nn.modules.batchnorm._BatchNorm, x, y):
"""Calculate and add the FLOPs for a batch normalization layer, considering elementwise operations and possible affine parameters."""
"""Calculate and add the FLOPs for a batch normalization layer, considering elementwise operations and possible
affine parameters.
"""
# https://github.com/Lyken17/pytorch-OpCounter/issues/124
# y = (x - mean) / sqrt(eps + var) * weight + bias
x = x[0]
Expand Down Expand Up @@ -146,7 +148,9 @@ def count_upsample(m, x, y):

# nn.Linear
def count_linear(m, x, y):
"""Counts total operations for nn.Linear layers by calculating multiplications and additions based on input and output elements."""
"""Counts total operations for nn.Linear layers by calculating multiplications and additions based on input and
output elements.
"""
total_mul = m.in_features
# total_add = m.in_features - 1
# total_add += 1 if m.bias is not None else 0
Expand Down
4 changes: 3 additions & 1 deletion thop/vision/calc_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def calculate_relu(input_size: torch.Tensor):


def calculate_softmax(batch_size, nfeatures):
"""Calculate the number of FLOPs required for a softmax activation function based on batch size and number of features."""
"""Calculate the number of FLOPs required for a softmax activation function based on batch size and number of
features.
"""
total_exp = nfeatures
total_add = nfeatures - 1
total_div = nfeatures
Expand Down
12 changes: 9 additions & 3 deletions thop/vision/onnx_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def onnx_counter_add(diction, node):


def onnx_counter_conv(diction, node):
"""Calculates MACs, output size, and name for an ONNX convolution node based on input tensor dimensions and node attributes."""
"""Calculates MACs, output size, and name for an ONNX convolution node based on input tensor dimensions and node
attributes.
"""
# bias,kernelsize,outputsize
dim_bias = 0
input_count = 0
Expand Down Expand Up @@ -125,7 +127,9 @@ def onnx_counter_relu(diction, node):


def onnx_counter_reducemean(diction, node):
"""Compute MACs, output size, and name for the ReduceMean ONNX node, adjusting dimensions based on the 'axes' and 'keepdims' attributes."""
"""Compute MACs, output size, and name for the ReduceMean ONNX node, adjusting dimensions based on the 'axes' and
'keepdims' attributes.
"""
keep_dim = 0
for attr in node.attribute:
if "axes" in attr.name:
Expand Down Expand Up @@ -330,7 +334,9 @@ def onnx_counter_concat(diction, node):


def onnx_counter_clip(diction, node):
"""Calculate MACs, output size, and output name for an ONNX node clip operation using provided dimensions and input size."""
"""Calculate MACs, output size, and output name for an ONNX node clip operation using provided dimensions and input
size.
"""
macs = calculate_zero_ops()
output_name = node.output[0]
input_size = diction[node.input[0]]
Expand Down

0 comments on commit 451f0f7

Please sign in to comment.