12
12
from src .base .datasets_processing import DatasetManager
13
13
from src .models_builder .models_zoo import model_configs_zoo
14
14
from attacks .QAttack import qattack
15
+ from defense .JaccardDefense import jaccard_def
16
+ from attacks .metattack import meta_gradient_attack
17
+ from defense .GNNGuard import gnnguard
15
18
16
19
17
20
def test_attack_defense ():
@@ -23,6 +26,7 @@ def test_attack_defense():
23
26
# full_name = ("multiple-graphs", "TUDataset", 'MUTAG')
24
27
# full_name = ("single-graph", "custom", 'karate')
25
28
full_name = ("single-graph" , "Planetoid" , 'Cora' )
29
+ # full_name = ("single-graph", "Planetoid", 'CiteSeer')
26
30
# full_name = ("multiple-graphs", "TUDataset", 'PROTEINS')
27
31
28
32
dataset , data , results_dataset_path = DatasetManager .get_by_full_name (
@@ -113,7 +117,7 @@ def test_attack_defense():
113
117
# }
114
118
# )
115
119
116
- poison_attack_config = ConfigPattern (
120
+ metafull_poison_attack_config = ConfigPattern (
117
121
_class_name = "MetaAttackFull" ,
118
122
_import_path = POISON_ATTACK_PARAMETERS_PATH ,
119
123
_config_class = "PoisonAttackConfig" ,
@@ -122,68 +126,117 @@ def test_attack_defense():
122
126
}
123
127
)
124
128
125
- # poison_attack_config = ConfigPattern(
126
- # _class_name="RandomPoisonAttack",
127
- # _import_path=POISON_ATTACK_PARAMETERS_PATH,
128
- # _config_class="PoisonAttackConfig",
129
- # _config_kwargs={
130
- # "n_edges_percent": 0.1 ,
131
- # }
132
- # )
129
+ random_poison_attack_config = ConfigPattern (
130
+ _class_name = "RandomPoisonAttack" ,
131
+ _import_path = POISON_ATTACK_PARAMETERS_PATH ,
132
+ _config_class = "PoisonAttackConfig" ,
133
+ _config_kwargs = {
134
+ "n_edges_percent" : 0.5 ,
135
+ }
136
+ )
133
137
134
- poison_defense_config = ConfigPattern (
138
+ gnnguard_poison_defense_config = ConfigPattern (
135
139
_class_name = "GNNGuard" ,
136
140
_import_path = POISON_DEFENSE_PARAMETERS_PATH ,
137
141
_config_class = "PoisonDefenseConfig" ,
138
142
_config_kwargs = {
139
- "n_edges_percent" : 0.1 ,
143
+ "lr" : 0.01 ,
144
+ "train_iters" : 100 ,
145
+ # "model": gnn_model_manager.gnn
140
146
}
141
147
)
142
148
149
+ jaccard_poison_defense_config = ConfigPattern (
150
+ _class_name = "JaccardDefender" ,
151
+ _import_path = POISON_DEFENSE_PARAMETERS_PATH ,
152
+ _config_class = "PoisonDefenseConfig" ,
153
+ _config_kwargs = {
154
+ "threshold" : 0.05 ,
155
+ }
156
+ )
143
157
144
- evasion_attack_config = ConfigPattern (
158
+ qattack_evasion_attack_config = ConfigPattern (
145
159
_class_name = "QAttack" ,
146
160
_import_path = EVASION_ATTACK_PARAMETERS_PATH ,
147
161
_config_class = "EvasionAttackConfig" ,
148
162
_config_kwargs = {
149
- "population_size" : 50 ,
150
- "individual_size" : 30 ,
151
- "generations" : 50 ,
163
+ "population_size" : 500 ,
164
+ "individual_size" : 100 ,
165
+ "generations" : 100 ,
152
166
"prob_cross" : 0.5 ,
153
167
"prob_mutate" : 0.02
154
168
}
155
169
)
156
- # evasion_attack_config = ConfigPattern(
157
- # _class_name="FGSM",
158
- # _import_path=EVASION_ATTACK_PARAMETERS_PATH,
159
- # _config_class="EvasionAttackConfig",
160
- # _config_kwargs={
161
- # "epsilon": 0.01 * 1,
162
- # }
163
- # )
164
170
165
- # evasion_defense_config = ConfigPattern(
166
- # _class_name="GradientRegularizationDefender",
167
- # _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
168
- # _config_class="EvasionDefenseConfig",
169
- # _config_kwargs={
170
- # "regularization_strength": 0.1 * 10
171
- # }
172
- # )
173
- evasion_defense_config = ConfigPattern (
171
+ fgsm_evasion_attack_config = ConfigPattern (
172
+ _class_name = "FGSM" ,
173
+ _import_path = EVASION_ATTACK_PARAMETERS_PATH ,
174
+ _config_class = "EvasionAttackConfig" ,
175
+ _config_kwargs = {
176
+ "epsilon" : 0.01 * 1 ,
177
+ }
178
+ )
179
+
180
+ netattack_evasion_attack_config = ConfigPattern (
181
+ _class_name = "NettackEvasionAttacker" ,
182
+ _import_path = EVASION_ATTACK_PARAMETERS_PATH ,
183
+ _config_class = "EvasionAttackConfig" ,
184
+ _config_kwargs = {
185
+ "node_idx" : 0 , # Node for attack
186
+ "n_perturbations" : 20 ,
187
+ "perturb_features" : True ,
188
+ "perturb_structure" : True ,
189
+ "direct" : True ,
190
+ "n_influencers" : 3
191
+ }
192
+ )
193
+
194
+ netattackgroup_evasion_attack_config = ConfigPattern (
195
+ _class_name = "NettackGroupEvasionAttacker" ,
196
+ _import_path = EVASION_ATTACK_PARAMETERS_PATH ,
197
+ _config_class = "EvasionAttackConfig" ,
198
+ _config_kwargs = {
199
+ "node_idxs" : [random .randint (0 , 500 ) for _ in range (20 )], # Nodes for attack
200
+ "n_perturbations" : 50 ,
201
+ "perturb_features" : True ,
202
+ "perturb_structure" : True ,
203
+ "direct" : True ,
204
+ "n_influencers" : 10
205
+ }
206
+ )
207
+
208
+ gradientregularization_evasion_defense_config = ConfigPattern (
209
+ _class_name = "GradientRegularizationDefender" ,
210
+ _import_path = EVASION_DEFENSE_PARAMETERS_PATH ,
211
+ _config_class = "EvasionDefenseConfig" ,
212
+ _config_kwargs = {
213
+ "regularization_strength" : 0.1 * 10
214
+ }
215
+ )
216
+
217
+
218
+ fgsm_evasion_attack_config0 = ConfigPattern (
219
+ _class_name = "FGSM" ,
220
+ _import_path = EVASION_ATTACK_PARAMETERS_PATH ,
221
+ _config_class = "EvasionAttackConfig" ,
222
+ _config_kwargs = {
223
+ "epsilon" : 0.1 * 1 ,
224
+ }
225
+ )
226
+ at_evasion_defense_config = ConfigPattern (
174
227
_class_name = "AdvTraining" ,
175
228
_import_path = EVASION_DEFENSE_PARAMETERS_PATH ,
176
229
_config_class = "EvasionDefenseConfig" ,
177
230
_config_kwargs = {
178
231
"attack_name" : None ,
179
- "attack_config" : evasion_attack_config # evasion_attack_config
232
+ "attack_config" : fgsm_evasion_attack_config0 # evasion_attack_config
180
233
}
181
234
)
182
235
183
- # gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config )
184
- # gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config )
185
- gnn_model_manager .set_evasion_attacker (evasion_attack_config = evasion_attack_config )
186
- # gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config )
236
+ # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config )
237
+ # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config )
238
+ gnn_model_manager .set_evasion_attacker (evasion_attack_config = netattackgroup_evasion_attack_config )
239
+ # gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config )
187
240
188
241
warnings .warn ("Start training" )
189
242
dataset .train_test_split ()
@@ -207,7 +260,8 @@ def test_attack_defense():
207
260
warnings .warn ("Training was successful" )
208
261
209
262
metric_loc = gnn_model_manager .evaluate_model (
210
- gen_dataset = dataset , metrics = [Metric ("F1" , mask = 'test' , average = 'macro' )])
263
+ gen_dataset = dataset , metrics = [Metric ("F1" , mask = 'test' , average = 'macro' ),
264
+ Metric ("Accuracy" , mask = 'test' )])
211
265
print (metric_loc )
212
266
213
267
def test_meta ():
@@ -326,12 +380,12 @@ def test_nettack_evasion():
326
380
acc_test_loc = gnn_model_manager .evaluate_model (gen_dataset = dataset ,
327
381
metrics = [Metric ("Accuracy" , mask = mask_loc )])[mask_loc ]['Accuracy' ]
328
382
329
- acc_train = gnn_model_manager .evaluate_model (gen_dataset = dataset ,
330
- metrics = [Metric ("Accuracy" , mask = 'train' )])['train' ]['Accuracy' ]
331
- acc_test = gnn_model_manager .evaluate_model (gen_dataset = dataset ,
332
- metrics = [Metric ("Accuracy" , mask = 'test' )])['test' ]['Accuracy' ]
383
+ # acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
384
+ # metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
385
+ # acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
386
+ # metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
333
387
334
- print (f"Accuracy on train: { acc_train } . Accuracy on test: { acc_test } " )
388
+ # print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")
335
389
print (f"Accuracy on test loc: { acc_test_loc } " )
336
390
337
391
# Model prediction on a node before an evasion attack on it
@@ -724,8 +778,8 @@ def test_adv_training():
724
778
if __name__ == '__main__' :
725
779
import random
726
780
random .seed (10 )
727
- test_attack_defense ()
728
- torch .manual_seed (5000 )
729
- # test_adv_training()
781
+ #test_attack_defense()
782
+ # torch.manual_seed(5000)
730
783
# test_gnnguard()
731
784
# test_jaccard()
785
+ test_attack_defense ()
0 commit comments