From 54128bd5460a95e3af0ff08ee9dc0e7827ef7dea Mon Sep 17 00:00:00 2001 From: Paula Derrenger <107626595+pderrenger@users.noreply.github.com> Date: Sun, 9 Jun 2024 05:09:53 +0200 Subject: [PATCH] Code Refactor for Speed and Readability (#21) Co-authored-by: Glenn Jocher Co-authored-by: UltralyticsAssistant --- benchmark/evaluate_famous_models.py | 7 ++----- benchmark/evaluate_rnn_models.py | 2 +- tests/test_matmul.py | 2 +- thop/fx_profile.py | 6 ++---- thop/vision/basic_hooks.py | 4 +--- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/benchmark/evaluate_famous_models.py b/benchmark/evaluate_famous_models.py index 451f97e..f641bda 100644 --- a/benchmark/evaluate_famous_models.py +++ b/benchmark/evaluate_famous_models.py @@ -11,13 +11,10 @@ and callable(models.__dict__[name]) ) -print("%s | %s | %s" % ("Model", "Params(M)", "FLOPs(G)")) +print("Model | Params(M) | FLOPs(G)") print("---|---|---") -device = "cpu" -if torch.cuda.is_available(): - device = "cuda" - +device = "cuda" if torch.cuda.is_available() else "cpu" for name in model_names: try: model = models.__dict__[name]().to(device) diff --git a/benchmark/evaluate_rnn_models.py b/benchmark/evaluate_rnn_models.py index 44cd3fb..d1fad55 100644 --- a/benchmark/evaluate_rnn_models.py +++ b/benchmark/evaluate_rnn_models.py @@ -24,7 +24,7 @@ "stacked-BiLSTM": nn.Sequential(nn.LSTM(input_size, hidden_size, bidirectional=True, num_layers=4)), } -print("{} | {} | {}".format("Model", "Params(M)", "FLOPs(G)")) +print("Model | Params(M) | FLOPs(G)") print("---|---|---") for name, model in models.items(): diff --git a/tests/test_matmul.py b/tests/test_matmul.py index 48391c9..83cc361 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -15,7 +15,7 @@ def test_matmul_case2(self): def test_matmul_case2(self): """Tests matrix multiplication to assert FLOPs and parameters of nn.Linear layer using random dimensions.""" - for i in range(10): + for _ in range(10): n, in_c, out_c = torch.randint(1, 500, (3,)).tolist() net = nn.Linear(in_c, out_c) flops, params = profile(net, inputs=(torch.randn(n, in_c),)) diff --git a/thop/fx_profile.py b/thop/fx_profile.py index 52a6c29..2d40b66 100644 --- a/thop/fx_profile.py +++ b/thop/fx_profile.py @@ -148,7 +148,6 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): node_flops = None input_shapes = [] - output_shapes = [] fprint("input_shape:", end="\t") for arg in node.args: if str(arg) not in v_maps: @@ -157,8 +156,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): input_shapes.append(v_maps[str(arg)]) fprint() fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}") - output_shapes.append(node.meta["tensor_meta"].shape) - + output_shapes = [node.meta["tensor_meta"].shape] if node.op in ["output", "placeholder"]: node_flops = 0 elif node.op == "call_function": @@ -194,7 +192,7 @@ def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): print("weight_shape: None") else: print(type(m)) - print(f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}") + print(f"weight_shape: {mod.state_dict()[f'{node.target}.weight'].shape}") v_maps[str(node.name)] = node.meta["tensor_meta"].shape if node_flops is not None: diff --git a/thop/vision/basic_hooks.py b/thop/vision/basic_hooks.py index 5a7897a..731b359 100644 --- a/thop/vision/basic_hooks.py +++ b/thop/vision/basic_hooks.py @@ -10,9 +10,7 @@ def count_parameters(m, x, y): """Calculate and update the total number of parameters in a given PyTorch model.""" - total_params = 0 - for p in m.parameters(): - total_params += torch.DoubleTensor([p.numel()]) + total_params = sum(torch.DoubleTensor([p.numel()]) for p in m.parameters()) m.total_params[0] = calculate_parameters(m.parameters())