Skip to content

Commit

Permalink
Code Refactor for Speed and Readability (#21)
Browse files Browse the repository at this point in the history
Co-authored-by: Glenn Jocher <[email protected]>
Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
3 people committed Jun 9, 2024
1 parent a879079 commit 54128bd
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 14 deletions.
7 changes: 2 additions & 5 deletions benchmark/evaluate_famous_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,10 @@
and callable(models.__dict__[name])
)

print("%s | %s | %s" % ("Model", "Params(M)", "FLOPs(G)"))
print("Model | Params(M) | FLOPs(G)")
print("---|---|---")

device = "cpu"
if torch.cuda.is_available():
device = "cuda"

device = "cuda" if torch.cuda.is_available() else "cpu"
for name in model_names:
try:
model = models.__dict__[name]().to(device)
Expand Down
2 changes: 1 addition & 1 deletion benchmark/evaluate_rnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"stacked-BiLSTM": nn.Sequential(nn.LSTM(input_size, hidden_size, bidirectional=True, num_layers=4)),
}

print("{} | {} | {}".format("Model", "Params(M)", "FLOPs(G)"))
print("Model | Params(M) | FLOPs(G)")
print("---|---|---")

for name, model in models.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_matmul_case2(self):

def test_matmul_case2(self):
"""Tests matrix multiplication to assert FLOPs and parameters of nn.Linear layer using random dimensions."""
for i in range(10):
for _ in range(10):
n, in_c, out_c = torch.randint(1, 500, (3,)).tolist()
net = nn.Linear(in_c, out_c)
flops, params = profile(net, inputs=(torch.randn(n, in_c),))
Expand Down
6 changes: 2 additions & 4 deletions thop/fx_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
node_flops = None

input_shapes = []
output_shapes = []
fprint("input_shape:", end="\t")
for arg in node.args:
if str(arg) not in v_maps:
Expand All @@ -157,8 +156,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
input_shapes.append(v_maps[str(arg)])
fprint()
fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}")
output_shapes.append(node.meta["tensor_meta"].shape)

output_shapes = [node.meta["tensor_meta"].shape]
if node.op in ["output", "placeholder"]:
node_flops = 0
elif node.op == "call_function":
Expand Down Expand Up @@ -194,7 +192,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False):
print("weight_shape: None")
else:
print(type(m))
print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}")
print(f"weight_shape: {mod.state_dict()[f'{node.target}.weight'].shape}")

v_maps[str(node.name)] = node.meta["tensor_meta"].shape
if node_flops is not None:
Expand Down
4 changes: 1 addition & 3 deletions thop/vision/basic_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

def count_parameters(m, x, y):
"""Calculate and update the total number of parameters in a given PyTorch model."""
total_params = 0
for p in m.parameters():
total_params += torch.DoubleTensor([p.numel()])
total_params = sum(torch.DoubleTensor([p.numel()]) for p in m.parameters())
m.total_params[0] = calculate_parameters(m.parameters())


Expand Down

0 comments on commit 54128bd

Please sign in to comment.