
Commit

Code Refactor ruff check --fix --extend-select I (#25)
glenn-jocher committed Jun 16, 2024
1 parent 8a0dd6b commit 014798b
Showing 3 changed files with 5 additions and 14 deletions.
tests/test_relu.py (3 changes: 1 addition & 2 deletions)
@@ -9,8 +9,7 @@ def test_relu(self):
"""Tests the ReLU activation function to ensure it has zero FLOPs and checks parameter count using THOP
profiling.
"""
n, in_c, out_c = 1, 100, 200
data = torch.randn(n, in_c)
n, in_c, _out_c = 1, 100, 200
net = nn.ReLU()
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)
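Note on this hunk: out_c was unused (ruff renames it to _out_c) and data was dead code, since the input tensor is built inline in the profile call. For reference, a standalone version of the test body, runnable if torch and thop are installed (the zero-FLOP expectation comes from the docstring above):

import torch
import torch.nn as nn
from thop import profile

n, in_c = 1, 100
net = nn.ReLU()
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
print(flops, params)  # ReLU is element-wise: THOP reports zero ops, and it has no parameters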
thop/fx_profile.py (7 changes: 3 additions & 4 deletions)
@@ -33,10 +33,10 @@ def count_matmul(input_shapes, output_shapes):

 def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):
     """Calculates total operations (FLOPs) for a linear layer given input and output shapes."""
-    mul_flops = count_matmul(input_shapes, output_shapes)
+    flops = count_matmul(input_shapes, output_shapes)
     if "bias" in kwargs:
-        add_flops = output_shapes[0].numel()
-    return mul_flops
+        flops += output_shapes[0].numel()
+    return flops
 
 
 from .vision.calc_func import calculate_conv
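This hunk is a genuine bug fix, not just a rename: the old code computed add_flops when a bias was passed but never used it, returning only mul_flops, so bias additions were silently dropped from the total. The new code accumulates both terms in one flops variable. A hypothetical standalone sketch of the corrected accounting (linear_flops and the one-MAC-per-triple convention are illustrative assumptions, not THOP's exact implementation):

import torch

def linear_flops(x, weight, bias=None):
    # Matmul cost for y = x @ weight.t(): one multiply-accumulate per
    # (batch row, output feature, input feature) triple.
    flops = x.shape[0] * weight.shape[0] * weight.shape[1]
    if bias is not None:
        flops += x.shape[0] * weight.shape[0]  # one add per output element
    return flops

print(linear_flops(torch.randn(1, 100), torch.randn(200, 100), torch.randn(200)))
# 1 * 200 * 100 + 1 * 200 = 20200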
@@ -131,7 +131,6 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
     information if verbose.
     """
     gm: torch.fx.GraphModule = symbolic_trace(mod)
-    g = gm.graph
     ShapeProp(gm).propagate(input)
 
     fprint = null_print
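The deleted g = gm.graph was assigned but never read; ShapeProp operates on the traced GraphModule directly. For reference, the trace-then-propagate pattern used here, in plain PyTorch (symbolic_trace and ShapeProp are standard torch.fx APIs; the toy model is illustrative):

import torch
import torch.nn as nn
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

model = nn.Sequential(nn.Linear(100, 200), nn.ReLU())
gm = symbolic_trace(model)                    # trace the module into an FX graph
ShapeProp(gm).propagate(torch.randn(1, 100))  # annotate each node with tensor metadata
for node in gm.graph.nodes:
    meta = node.meta.get("tensor_meta")
    print(node.op, node.target, getattr(meta, "shape", None))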
thop/vision/basic_hooks.py (9 changes: 1 addition & 8 deletions)
@@ -10,7 +10,6 @@

 def count_parameters(m, x, y):
     """Calculate and update the total number of parameters in a given PyTorch model."""
-    total_params = sum(torch.DoubleTensor([p.numel()]) for p in m.parameters())
     m.total_params[0] = calculate_parameters(m.parameters())
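The removed total_params was computed and then discarded; m.total_params[0] is filled from calculate_parameters instead. Presumably that helper reduces to the usual parameter count, which is a one-liner (an assumption about the helper's behavior, not its verbatim source):

def count_params(module):
    # Total parameter count: sum of element counts over all parameter tensors.
    return sum(p.numel() for p in module.parameters())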


@@ -22,10 +21,7 @@ def zero_ops(m, x, y):
 def count_convNd(m: _ConvNd, x, y: torch.Tensor):
     """Calculate and add the number of convolutional operations (FLOPs) to the model's total operations count."""
     x = x[0]
 
-    kernel_ops = torch.zeros(m.weight.size()[2:]).numel()  # Kw x Kh
-    bias_ops = 1 if m.bias is not None else 0
-
     m.total_ops += calculate_conv2d_flops(
         input_size=list(x.shape),
         output_size=list(y.shape),
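kernel_ops and bias_ops were leftovers from an older counting scheme; calculate_conv2d_flops now derives everything from the sizes passed in. For orientation, the textbook MAC count for a 2D convolution, which the helper is assumed to approximate (a sketch under that assumption, not THOP's verbatim code):

def conv2d_macs(output_size, in_channels, kernel_hw, groups=1):
    # One multiply-accumulate per output element, per kernel element,
    # over the input channels visible to that element's group.
    out_numel = 1
    for d in output_size:  # e.g. [N, C_out, H_out, W_out]
        out_numel *= d
    return out_numel * (in_channels // groups) * kernel_hw[0] * kernel_hw[1]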
@@ -95,9 +91,6 @@ def count_prelu(m, x, y):
 def count_relu(m, x, y):
     """Calculate and update the total operation counts for a ReLU layer."""
     x = x[0]
-
-    nelements = x.numel()
-
     m.total_ops += calculate_relu_flops(list(x.shape))
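nelements was likewise unused once the hook delegated to calculate_relu_flops, which takes the shape rather than the tensor. No information is lost by passing the shape alone, since the element count is recoverable from it (a trivial identity, shown for clarity):

import math
import torch

x = torch.randn(1, 100)
assert x.numel() == math.prod(x.shape)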


