Skip to content

Commit 04535db

Browse files
Merge pull request #27 from abhhfcgjk/develop
GNNGuard and Nettack on a few nodes
2 parents c7e42d2 + 92148c8 commit 04535db

11 files changed

+607
-67
lines changed

experiments/attack_defense_test.py

+100-46
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from src.base.datasets_processing import DatasetManager
1313
from src.models_builder.models_zoo import model_configs_zoo
1414
from attacks.QAttack import qattack
15+
from defense.JaccardDefense import jaccard_def
16+
from attacks.metattack import meta_gradient_attack
17+
from defense.GNNGuard import gnnguard
1518

1619

1720
def test_attack_defense():
@@ -23,6 +26,7 @@ def test_attack_defense():
2326
# full_name = ("multiple-graphs", "TUDataset", 'MUTAG')
2427
# full_name = ("single-graph", "custom", 'karate')
2528
full_name = ("single-graph", "Planetoid", 'Cora')
29+
# full_name = ("single-graph", "Planetoid", 'CiteSeer')
2630
# full_name = ("multiple-graphs", "TUDataset", 'PROTEINS')
2731

2832
dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
@@ -113,7 +117,7 @@ def test_attack_defense():
113117
# }
114118
# )
115119

116-
poison_attack_config = ConfigPattern(
120+
metafull_poison_attack_config = ConfigPattern(
117121
_class_name="MetaAttackFull",
118122
_import_path=POISON_ATTACK_PARAMETERS_PATH,
119123
_config_class="PoisonAttackConfig",
@@ -122,68 +126,117 @@ def test_attack_defense():
122126
}
123127
)
124128

125-
# poison_attack_config = ConfigPattern(
126-
# _class_name="RandomPoisonAttack",
127-
# _import_path=POISON_ATTACK_PARAMETERS_PATH,
128-
# _config_class="PoisonAttackConfig",
129-
# _config_kwargs={
130-
# "n_edges_percent": 0.1,
131-
# }
132-
# )
129+
random_poison_attack_config = ConfigPattern(
130+
_class_name="RandomPoisonAttack",
131+
_import_path=POISON_ATTACK_PARAMETERS_PATH,
132+
_config_class="PoisonAttackConfig",
133+
_config_kwargs={
134+
"n_edges_percent": 0.5,
135+
}
136+
)
133137

134-
poison_defense_config = ConfigPattern(
138+
gnnguard_poison_defense_config = ConfigPattern(
135139
_class_name="GNNGuard",
136140
_import_path=POISON_DEFENSE_PARAMETERS_PATH,
137141
_config_class="PoisonDefenseConfig",
138142
_config_kwargs={
139-
"n_edges_percent": 0.1,
143+
"lr": 0.01,
144+
"train_iters": 100,
145+
# "model": gnn_model_manager.gnn
140146
}
141147
)
142148

149+
jaccard_poison_defense_config = ConfigPattern(
150+
_class_name="JaccardDefender",
151+
_import_path=POISON_DEFENSE_PARAMETERS_PATH,
152+
_config_class="PoisonDefenseConfig",
153+
_config_kwargs={
154+
"threshold": 0.05,
155+
}
156+
)
143157

144-
evasion_attack_config = ConfigPattern(
158+
qattack_evasion_attack_config = ConfigPattern(
145159
_class_name="QAttack",
146160
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
147161
_config_class="EvasionAttackConfig",
148162
_config_kwargs={
149-
"population_size": 50,
150-
"individual_size": 30,
151-
"generations": 50,
163+
"population_size": 500,
164+
"individual_size": 100,
165+
"generations": 100,
152166
"prob_cross": 0.5,
153167
"prob_mutate": 0.02
154168
}
155169
)
156-
# evasion_attack_config = ConfigPattern(
157-
# _class_name="FGSM",
158-
# _import_path=EVASION_ATTACK_PARAMETERS_PATH,
159-
# _config_class="EvasionAttackConfig",
160-
# _config_kwargs={
161-
# "epsilon": 0.01 * 1,
162-
# }
163-
# )
164170

165-
# evasion_defense_config = ConfigPattern(
166-
# _class_name="GradientRegularizationDefender",
167-
# _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
168-
# _config_class="EvasionDefenseConfig",
169-
# _config_kwargs={
170-
# "regularization_strength": 0.1 * 10
171-
# }
172-
# )
173-
evasion_defense_config = ConfigPattern(
171+
fgsm_evasion_attack_config = ConfigPattern(
172+
_class_name="FGSM",
173+
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
174+
_config_class="EvasionAttackConfig",
175+
_config_kwargs={
176+
"epsilon": 0.01 * 1,
177+
}
178+
)
179+
180+
netattack_evasion_attack_config = ConfigPattern(
181+
_class_name="NettackEvasionAttacker",
182+
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
183+
_config_class="EvasionAttackConfig",
184+
_config_kwargs={
185+
"node_idx": 0, # Node for attack
186+
"n_perturbations": 20,
187+
"perturb_features": True,
188+
"perturb_structure": True,
189+
"direct": True,
190+
"n_influencers": 3
191+
}
192+
)
193+
194+
netattackgroup_evasion_attack_config = ConfigPattern(
195+
_class_name="NettackGroupEvasionAttacker",
196+
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
197+
_config_class="EvasionAttackConfig",
198+
_config_kwargs={
199+
"node_idxs": [random.randint(0, 500) for _ in range(20)], # Nodes for attack
200+
"n_perturbations": 50,
201+
"perturb_features": True,
202+
"perturb_structure": True,
203+
"direct": True,
204+
"n_influencers": 10
205+
}
206+
)
207+
208+
gradientregularization_evasion_defense_config = ConfigPattern(
209+
_class_name="GradientRegularizationDefender",
210+
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
211+
_config_class="EvasionDefenseConfig",
212+
_config_kwargs={
213+
"regularization_strength": 0.1 * 10
214+
}
215+
)
216+
217+
218+
fgsm_evasion_attack_config0 = ConfigPattern(
219+
_class_name="FGSM",
220+
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
221+
_config_class="EvasionAttackConfig",
222+
_config_kwargs={
223+
"epsilon": 0.1 * 1,
224+
}
225+
)
226+
at_evasion_defense_config = ConfigPattern(
174227
_class_name="AdvTraining",
175228
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
176229
_config_class="EvasionDefenseConfig",
177230
_config_kwargs={
178231
"attack_name": None,
179-
"attack_config": evasion_attack_config # evasion_attack_config
232+
"attack_config": fgsm_evasion_attack_config0 # evasion_attack_config
180233
}
181234
)
182235

183-
# gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
184-
# gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)
185-
gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
186-
# gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)
236+
# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
237+
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
238+
gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
239+
# gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
187240

188241
warnings.warn("Start training")
189242
dataset.train_test_split()
@@ -207,7 +260,8 @@ def test_attack_defense():
207260
warnings.warn("Training was successful")
208261

209262
metric_loc = gnn_model_manager.evaluate_model(
210-
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
263+
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'),
264+
Metric("Accuracy", mask='test')])
211265
print(metric_loc)
212266

213267
def test_meta():
@@ -326,12 +380,12 @@ def test_nettack_evasion():
326380
acc_test_loc = gnn_model_manager.evaluate_model(gen_dataset=dataset,
327381
metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy']
328382

329-
acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
330-
metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
331-
acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
332-
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
383+
# acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
384+
# metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
385+
# acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
386+
# metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
333387

334-
print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")
388+
# print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")
335389
print(f"Accuracy on test loc: {acc_test_loc}")
336390

337391
# Model prediction on a node before an evasion attack on it
@@ -724,8 +778,8 @@ def test_adv_training():
724778
if __name__ == '__main__':
725779
import random
726780
random.seed(10)
727-
test_attack_defense()
728-
torch.manual_seed(5000)
729-
# test_adv_training()
781+
#test_attack_defense()
782+
# torch.manual_seed(5000)
730783
# test_gnnguard()
731784
# test_jaccard()
785+
test_attack_defense()

metainfo/evasion_attack_parameters.json

+8
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@
1212
"direct": ["direct", "bool", true, {}, "Indicates whether to directly modify edges/features of the node attacked or only those of influencers"],
1313
"n_influencers": ["n_influencers", "int", 0, {"min": 0, "step": 1}, "Number of influencing nodes. Will be ignored if direct is True"]
1414
},
15+
"NettackGroupEvasionAttacker": {
16+
17+
"n_perturbations": ["n_perturbations", "int", null, {"min": 0, "step": 1}, "Number of perturbations. If None, then n_perturbations = degree(node_idx)"],
18+
"perturb_features": ["perturb_features", "bool", true, {}, "Indicates whether the features can be changed"],
19+
"perturb_structure": ["perturb_structure", "bool", true, {}, "Indicates whether the structure can be changed"],
20+
"direct": ["direct", "bool", true, {}, "Indicates whether to directly modify edges/features of the node attacked or only those of influencers"],
21+
"n_influencers": ["n_influencers", "int", 0, {"min": 0, "step": 1}, "Number of influencing nodes. Will be ignored if direct is True"]
22+
},
1523
"QAttack": {
1624
"population_size": ["Population size", "int", 50, {"min": 1, "step": 1}, "Number of genes in population"],
1725
"individual_size": ["Individual size", "int", 30, {"min": 1, "step": 1}, "Number of rewiring operations within one gene"],

metainfo/modules_parameters.json

+7-7
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
{
1111
"import_info": ["GCNConv", ["torch_geometric.nn"]],
1212
"need_full_gnn_flag": false,
13-
"forward_parameters": "x=x, edge_index=edge_index"
13+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
1414
}
1515
},
1616
"SAGEConv": {
@@ -53,7 +53,7 @@
5353
{
5454
"import_info": ["SGConv", ["torch_geometric.nn"]],
5555
"need_full_gnn_flag": false,
56-
"forward_parameters": "x=x, edge_index=edge_index"
56+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
5757
}
5858
},
5959
"GINConv": {
@@ -74,7 +74,7 @@
7474
{
7575
"import_info": ["TAGConv", ["torch_geometric.nn"]],
7676
"need_full_gnn_flag": false,
77-
"forward_parameters": "x=x, edge_index=edge_index"
77+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
7878
}
7979
},
8080
"ARMAConv": {
@@ -88,7 +88,7 @@
8888
{
8989
"import_info": ["ARMAConv", ["torch_geometric.nn"]],
9090
"need_full_gnn_flag": false,
91-
"forward_parameters": "x=x, edge_index=edge_index"
91+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
9292
}
9393
},
9494
"SSGConv": {
@@ -102,7 +102,7 @@
102102
{
103103
"import_info": ["SSGConv", ["torch_geometric.nn"]],
104104
"need_full_gnn_flag": false,
105-
"forward_parameters": "x=x, edge_index=edge_index"
105+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
106106
}
107107
},
108108
"GMM": {
@@ -130,7 +130,7 @@
130130
{
131131
"import_info": ["CGConv", ["torch_geometric.nn"]],
132132
"need_full_gnn_flag": false,
133-
"forward_parameters": "x=x, edge_index=edge_index"
133+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
134134
}
135135
},
136136
"APPNP": {
@@ -144,7 +144,7 @@
144144
{
145145
"import_info": ["APPNP", ["torch_geometric.nn"]],
146146
"need_full_gnn_flag": false,
147-
"forward_parameters": "x=x, edge_index=edge_index"
147+
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
148148
}
149149
},
150150
"Linear": {

metainfo/poison_defense_parameters.json

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
"BadRandomPoisonDefender": {
55
"n_edges_percent": ["n_edges_percent", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
66
},
7+
"JaccardDefender": {
8+
"threshold": ["Edge Threshold", "float", 0.35, {"min": 0, "max": 1, "step": 0.01}, "Jaccard index threshold for dropping edges"]
9+
},
710
"GNNGuard": {
811
"lr": ["lr", "float", 0.01, {"min": 0.0001, "step": 0.005}, "?"],
912
"attention": ["attention", "bool", true, {}, "?"],
10-
"drop": ["drop", "bool", true, {}, "?"]
11-
},
12-
"JaccardDefender": {
13-
"threshold": ["Edge Threshold", "float", 0.35, {"min": 0, "max": 1, "step": 0.01}, "Jaccard index threshold for dropping edges"]
13+
"drop": ["drop", "bool", true, {}, "?"],
14+
"train_iters": ["train_iters", "int", 50, {}, "?"]
1415
}
1516
}
1617

src/attacks/QAttack/qattack.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def fitness_individual(self, model, gen_dataset, gene):
7878
# Get labels from black-box
7979
labels = model.gnn.get_answer(dataset.x, dataset.edge_index)
8080
labeled_nodes = dict(enumerate(labels.tolist()))
81+
# labeled_nodes = {n: labels.tolist()[n-1] for n in adj_list.keys()} # FIXME check order for labels and node id consistency
8182

8283
# Calculate modularity
8384
Q = self.modularity(adj_list, labeled_nodes)
@@ -204,6 +205,7 @@ def mutation(self, gen_dataset):
204205
self.population[i][n]['del'] = np.random.choice(list(adj_list[n]), 1)
205206
else:
206207
selected_nodes = set(self.population[i].keys())
208+
#non_selected_nodes = non_isolated_nodes.difference(selected_nodes)
207209
non_drain_nodes = non_drain_nodes.difference(selected_nodes)
208210
new_node = np.random.choice(list(non_drain_nodes), size=1, replace=False)[0]
209211
self.population[i].pop(n)
@@ -238,4 +240,4 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
238240
set(adj_list[n]).union(set([int(rewiring[n]['add'])])).difference(set([int(rewiring[n]['del'])])))
239241

240242
gen_dataset.dataset.data.edge_index = from_adj_list(adj_list)
241-
return gen_dataset
243+
return gen_dataset

src/attacks/evasion_attacks.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,22 @@ def _evasion(gen_dataset, feature_perturbations, structure_perturbations):
137137
torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)
138138

139139
gen_dataset.data.edge_index = edge_index
140-
140+
141+
class NettackGroupEvasionAttacker(EvasionAttacker):
142+
name = "NettackGroupEvasionAttacker"
143+
def __init__(self,node_idxs, **kwargs):
144+
super().__init__()
145+
self.node_idxs = node_idxs # kwargs.get("node_idxs")
146+
assert isinstance(self.node_idxs, list)
147+
self.n_perturbations = kwargs.get("n_perturbations")
148+
self.perturb_features = kwargs.get("perturb_features")
149+
self.perturb_structure = kwargs.get("perturb_structure")
150+
self.direct = kwargs.get("direct")
151+
self.n_influencers = kwargs.get("n_influencers")
152+
self.attacker = NettackEvasionAttacker(0, **kwargs)
153+
154+
def attack(self, model_manager, gen_dataset, mask_tensor):
155+
for node_idx in self.node_idxs:
156+
self.attacker.node_idx = node_idx
157+
gen_dataset = self.attacker.attack(model_manager, gen_dataset, mask_tensor)
158+
return gen_dataset

0 commit comments

Comments
 (0)