-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from ispras/attack-defense_pipeline
Attack defense pipeline
- Loading branch information
Showing
54 changed files
with
1,088 additions
and
445 deletions.
There are no files selected for viewing
Binary file removed
BIN
-4.99 KB
...tr":{"a":"as_is"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/data.pt
Binary file not shown.
Binary file removed
BIN
-437 Bytes
...a":"as_is"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_filter.pt
Binary file not shown.
Binary file removed
BIN
-443 Bytes
..."as_is"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_transform.pt
Binary file not shown.
Binary file removed
BIN
-1.71 KB
...tr":{"a":"as_is"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/data.pt
Binary file not shown.
Binary file removed
BIN
-437 Bytes
...a":"as_is"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_filter.pt
Binary file not shown.
Binary file removed
BIN
-443 Bytes
..."as_is"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_transform.pt
Binary file not shown.
Binary file removed
BIN
-1.77 KB
...b":"categorical"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/data.pt
Binary file not shown.
Binary file removed
BIN
-437 Bytes
...tegorical"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_filter.pt
Binary file not shown.
Binary file removed
BIN
-443 Bytes
...orical"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_transform.pt
Binary file not shown.
Binary file removed
BIN
-1.71 KB
..."a":"continuous"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/data.pt
Binary file not shown.
Binary file removed
BIN
-437 Bytes
...ontinuous"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_filter.pt
Binary file not shown.
Binary file removed
BIN
-443 Bytes
...inuous"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_transform.pt
Binary file not shown.
Binary file removed
BIN
-1.77 KB
...b":"categorical"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/data.pt
Binary file not shown.
Binary file removed
BIN
-437 Bytes
...tegorical"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_filter.pt
Binary file not shown.
Binary file removed
BIN
-443 Bytes
...orical"}}/labeling=binary/dataset_attack_type=original/dataset_ver_ind=0/pre_transform.pt
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import torch | ||
|
||
import warnings | ||
|
||
from torch import device | ||
|
||
from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH | ||
from src.models_builder.gnn_models import FrameworkGNNModelManager, Metric | ||
from src.aux.configs import ModelModificationConfig, ConfigPattern | ||
from src.base.datasets_processing import DatasetManager | ||
from src.models_builder.models_zoo import model_configs_zoo | ||
|
||
|
||
def test_attack_defense(): | ||
# my_device = device('cuda' if is_available() else 'cpu') | ||
my_device = device('cpu') | ||
|
||
full_name = None | ||
|
||
# full_name = ("multiple-graphs", "TUDataset", 'MUTAG') | ||
# full_name = ("single-graph", "custom", 'karate') | ||
full_name = ("single-graph", "Planetoid", 'Cora') | ||
# full_name = ("multiple-graphs", "TUDataset", 'PROTEINS') | ||
|
||
dataset, data, results_dataset_path = DatasetManager.get_by_full_name( | ||
full_name=full_name, | ||
dataset_ver_ind=0 | ||
) | ||
|
||
# dataset, data, results_dataset_path = DatasetManager.get_by_full_name( | ||
# full_name=("single-graph", "custom", "example",), | ||
# features={'attr': {'a': 'as_is', 'b': 'as_is'}}, | ||
# labeling='threeClasses', | ||
# dataset_ver_ind=0 | ||
# ) | ||
|
||
# dataset, data, results_dataset_path = DatasetManager.get_by_full_name( | ||
# # full_name=("single-graph", "vk_samples", "vk2-ff40-N100000-A.1612175945",), | ||
# full_name=("single-graph", "vk_samples", "vk2-ff20-N10000-A.1611943634",), | ||
# # full_name=("single-graph", "vk_samples", "vk2-ff20-N1000-U.1612273925",), | ||
# # features=('sex',), | ||
# features={'str_f': tuple(), 'str_g': None, 'attr': { | ||
# # "('personal', 'political')": 'one_hot', | ||
# # "('occupation', 'type')": 'one_hot', # Don't work now | ||
# # "('relation',)": 'one_hot', | ||
# # "('age',)": 'one_hot', | ||
# "('sex',)": 'one_hot', | ||
# }}, | ||
# # features={'str_f': tuple(), 'str_g': None, 'attr': {'sex': 'one_hot', }}, | ||
# labeling='sex1', | ||
# dataset_ver_ind=0 | ||
# ) | ||
|
||
# print(data.train_mask) | ||
|
||
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') | ||
# gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn_lin') | ||
# gnn = model_configs_zoo(dataset=dataset, model_name='test_gnn') | ||
# gnn = model_configs_zoo(dataset=dataset, model_name='gin_gin_gin_lin_lin') | ||
# gnn = model_configs_zoo(dataset=dataset, model_name='gin_gin_gin_lin_lin_prot') | ||
|
||
manager_config = ConfigPattern( | ||
_config_class="ModelManagerConfig", | ||
_config_kwargs={ | ||
"mask_features": [], | ||
"optimizer": { | ||
# "_config_class": "Config", | ||
"_class_name": "Adam", | ||
# "_import_path": OPTIMIZERS_PARAMETERS_PATH, | ||
# "_class_import_info": ["torch.optim"], | ||
"_config_kwargs": {}, | ||
} | ||
} | ||
) | ||
# manager_config = ModelManagerConfig(**{ | ||
# "mask_features": [], | ||
# "optimizer": { | ||
# # "_config_class": "Config", | ||
# "_class_name": "Adam", | ||
# # "_import_path": OPTIMIZERS_PARAMETERS_PATH, | ||
# # "_class_import_info": ["torch.optim"], | ||
# "_config_kwargs": {}, | ||
# } | ||
# } | ||
# ) | ||
|
||
# train_test_split = [0.8, 0.2] | ||
# train_test_split = [0.6, 0.4] | ||
steps_epochs = 200 | ||
gnn_model_manager = FrameworkGNNModelManager( | ||
gnn=gnn, | ||
dataset_path=results_dataset_path, | ||
manager_config=manager_config, | ||
modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) | ||
) | ||
|
||
# save_model_flag = False | ||
save_model_flag = True | ||
|
||
# data.x = data.x.float() | ||
gnn_model_manager.gnn.to(my_device) | ||
data = data.to(my_device) | ||
|
||
poison_attack_config = ConfigPattern( | ||
_class_name="RandomPoisonAttack", | ||
_import_path=POISON_ATTACK_PARAMETERS_PATH, | ||
_config_class="PoisonAttackConfig", | ||
_config_kwargs={ | ||
"n_edges_percent": 0.1, | ||
} | ||
) | ||
|
||
# poison_defense_config = ConfigPattern( | ||
# _class_name="BadRandomPoisonDefender", | ||
# _import_path=POISON_DEFENSE_PARAMETERS_PATH, | ||
# _config_class="PoisonDefenseConfig", | ||
# _config_kwargs={ | ||
# "n_edges_percent": 0.1, | ||
# } | ||
# ) | ||
poison_defense_config = ConfigPattern( | ||
_class_name="EmptyPoisonDefender", | ||
_import_path=POISON_DEFENSE_PARAMETERS_PATH, | ||
_config_class="PoisonDefenseConfig", | ||
_config_kwargs={ | ||
} | ||
) | ||
|
||
gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config) | ||
gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) | ||
|
||
warnings.warn("Start training") | ||
dataset.train_test_split() | ||
|
||
try: | ||
raise FileNotFoundError() | ||
# gnn_model_manager.load_model_executor() | ||
except FileNotFoundError: | ||
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 | ||
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, | ||
save_model_flag=save_model_flag, | ||
metrics=[Metric("F1", mask='train', average=None)]) | ||
|
||
if train_test_split_path is not None: | ||
dataset.save_train_test_mask(train_test_split_path) | ||
train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ | ||
:] | ||
dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask | ||
data.percent_train_class, data.percent_test_class = train_test_sizes | ||
|
||
warnings.warn("Training was successful") | ||
|
||
metric_loc = gnn_model_manager.evaluate_model( | ||
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')]) | ||
print(metric_loc) | ||
|
||
|
||
if __name__ == '__main__': | ||
test_attack_defense() | ||
|
||
|
Oops, something went wrong.