From 28896b93f236b30ab999287a2d3db71a41e4ccf3 Mon Sep 17 00:00:00 2001 From: mikhail Date: Tue, 29 Oct 2024 17:26:39 +0300 Subject: [PATCH] add sign and some fixes --- experiments/attack_defense_test.py | 1 - src/attacks/evasion_attacks.py | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 9ecf663..190f23d 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -868,7 +868,6 @@ def test_pgd(): "predicted_class": predicted_class, "predicted_probability": predicted_probability, "real_class": real_class} - # ____________________________________________________________ # ______________________ Attack on graph _____________________ diff --git a/src/attacks/evasion_attacks.py b/src/attacks/evasion_attacks.py index 02a5e10..47a65a5 100644 --- a/src/attacks/evasion_attacks.py +++ b/src/attacks/evasion_attacks.py @@ -111,8 +111,7 @@ def _attack_on_node(self, model_manager, gen_dataset): optimizer.step() with torch.no_grad(): x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon)) - x_clamp = torch.clamp(x, -self.epsilon, self.epsilon) - x.copy_(x_clamp) + x.copy_(torch.clamp(x, -self.epsilon, self.epsilon)) # return the modified lines back to the original tensor x gen_dataset.data.x[subset] = x.detach() self.attack_diff = gen_dataset @@ -130,6 +129,7 @@ def _attack_on_graph(self, model_manager, gen_dataset): if self.is_feature_attack: # feature attack x = x.clone() + orig_x = x.clone() x.requires_grad = True optimizer = torch.optim.Adam([x], lr=self.learning_rate, weight_decay=5e-4) @@ -139,10 +139,11 @@ def _attack_on_graph(self, model_manager, gen_dataset): # print(loss) model.zero_grad() loss.backward() + x.grad.sign_() optimizer.step() with torch.no_grad(): - x_clamp = torch.clamp(x, 0, self.epsilon) - x.copy_(x_clamp) + x.copy_(torch.max(torch.min(x, orig_x + self.epsilon), orig_x - self.epsilon)) + x.copy_(torch.clamp(x, -self.epsilon, self.epsilon)) gen_dataset.dataset[graph_idx].x.copy_(x.detach()) self.attack_diff = gen_dataset else: # structure attack