Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
UltralyticsAssistant committed Jun 30, 2024
1 parent c8ac820 commit 04ca8ce
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 2 deletions.
3 changes: 2 additions & 1 deletion benchmark/evaluate_famous_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import torch
from thop.profile import profile
from torchvision import models

from thop.profile import profile

model_names = sorted(
name
for name in models.__dict__
Expand Down
1 change: 1 addition & 0 deletions benchmark/evaluate_rnn_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn

from thop.profile import profile

input_size = 160
Expand Down
1 change: 1 addition & 0 deletions tests/test_conv2d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn

from thop import profile


Expand Down
1 change: 1 addition & 0 deletions tests/test_matmul.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn

from thop import profile


Expand Down
1 change: 1 addition & 0 deletions tests/test_relu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn

from thop import profile


Expand Down
4 changes: 3 additions & 1 deletion thop/fx_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def count_fn_linear(input_shapes, output_shapes, *args, **kwargs):


def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs):
"""Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using `calculate_conv`."""
"""Calculates total operations (FLOPs) for a 2D conv layer based on input and output shapes using
`calculate_conv`.
"""
inputs, weight, bias, stride, padding, dilation, groups = args
if len(input_shapes) == 2:
x_shape, k_shape = input_shapes
Expand Down

0 comments on commit 04ca8ce

Please sign in to comment.