Skip to content

Commit

Permalink
Merge branch 'develop' into metattack
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeratt authored Sep 17, 2024
2 parents 2a88ecf + 0196b9e commit 2784605
Show file tree
Hide file tree
Showing 8 changed files with 1,039 additions and 9 deletions.
108 changes: 103 additions & 5 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ def test_attack_defense():
}
)

# gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config)
gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)
# gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)
# gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
Expand Down Expand Up @@ -244,9 +244,107 @@ def test_meta():
Metric("Accuracy", mask='test')])
print(metric_loc)

def test_nettack_evasion():
my_device = device('cpu')

# Load dataset
full_name = ("single-graph", "Planetoid", 'Cora')
dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
full_name=full_name,
dataset_ver_ind=0
)

# Train model on original dataset and remember the model metric and node predicted probability
gcn_gcn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')

manager_config = ConfigPattern(
_config_class="ModelManagerConfig",
_config_kwargs={
"mask_features": [],
"optimizer": {
"_class_name": "Adam",
"_config_kwargs": {},
}
}
)

gnn_model_manager = FrameworkGNNModelManager(
gnn=gcn_gcn,
dataset_path=results_dataset_path,
manager_config=manager_config,
modification=ModelModificationConfig(model_ver_ind=0, epochs=0)
)

gnn_model_manager.gnn.to(my_device)

num_steps = 200
gnn_model_manager.train_model(gen_dataset=dataset,
steps=num_steps,
save_model_flag=False)

# Evaluate model
acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")

# Node for attack
node_idx = 0

# Model prediction on a node before an evasion attack on it
gnn_model_manager.gnn.eval()
with torch.no_grad():
probabilities = torch.exp(gnn_model_manager.gnn(dataset.data.x, dataset.data.edge_index))

predicted_class = probabilities[node_idx].argmax().item()
predicted_probability = probabilities[node_idx][predicted_class].item()
real_class = dataset.data.y[node_idx].item()

info_before_evasion_attack = {"node_idx": node_idx,
"predicted_class": predicted_class,
"predicted_probability": predicted_probability,
"real_class": real_class}

# Attack config
evasion_attack_config = ConfigPattern(
_class_name="NettackEvasionAttacker",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"node_idx": node_idx,
"n_perturbations": 20,
"perturb_features": True,
"perturb_structure": True,
"direct": True,
"n_influencers": 0
}
)

gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config)

# Attack
gnn_model_manager.evaluate_model(gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])

# Model prediction on a node after an evasion attack on it
with torch.no_grad():
probabilities = torch.exp(gnn_model_manager.gnn(gnn_model_manager.evasion_attacker.attack_diff.data.x,
gnn_model_manager.evasion_attacker.attack_diff.data.edge_index))

predicted_class = probabilities[node_idx].argmax().item()
predicted_probability = probabilities[node_idx][predicted_class].item()
real_class = dataset.data.y[node_idx].item()

info_after_evasion_attack = {"node_idx": node_idx,
"predicted_class": predicted_class,
"predicted_probability": predicted_probability,
"real_class": real_class}

print(f"info_before_evasion_attack: {info_before_evasion_attack}")
print(f"info_after_evasion_attack: {info_after_evasion_attack}")


if __name__ == '__main__':
#test_attack_defense()
torch.manual_seed(5000)
test_meta()


8 changes: 8 additions & 0 deletions metainfo/evasion_attack_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@
},
"FGSM": {
"epsilon": ["epsilon", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
},
"NettackEvasionAttacker": {
"node_idx": ["node_idx", "int", 0, {"min": 0, "step": 1}, "Index of the node to attack"],
"n_perturbations": ["n_perturbations", "int", null, {"min": 0, "step": 1}, "Number of perturbations. If None, then n_perturbations = degree(node_idx)"],
"perturb_features": ["perturb_features", "bool", true, {}, "Indicates whether the features can be changed"],
"perturb_structure": ["perturb_structure", "bool", true, {}, "Indicates whether the structure can be changed"],
"direct": ["direct", "bool", true, {}, "Indicates whether to directly modify edges/features of the node attacked or only those of influencers"],
"n_influencers": ["n_influencers", "int", 0, {"min": 0, "step": 1}, "Number of influencing nodes. Will be ignored if direct is True"]
}
}

1 change: 1 addition & 0 deletions requirements3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ torch-spline-conv==1.2.2

# For explainers
dive-into-graphs==1.1.0

# PGMExplainer
pgmpy==0.1.24

Expand Down
103 changes: 103 additions & 0 deletions src/attacks/evasion_attacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import torch
import torch.nn.functional as F
import numpy as np

from attacks.attack_base import Attacker

# Nettack imports
from src.attacks.nettack.nettack import Nettack
from src.attacks.nettack.utils import preprocess_graph, largest_connected_components, data_to_csr_matrix, train_w1_w2


class EvasionAttacker(Attacker):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -35,3 +40,101 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
perturbed_data_x = torch.clamp(perturbed_data_x, 0, 1)
gen_dataset.data.x = perturbed_data_x.detach()
return gen_dataset


class NettackEvasionAttacker(EvasionAttacker):
name = "NettackEvasionAttacker"

def __init__(self,
node_idx=0,
n_perturbations=None,
perturb_features=True,
perturb_structure=True,
direct=True,
n_influencers=0
):

super().__init__()
self.attack_diff = None
self.node_idx = node_idx
self.n_perturbations = n_perturbations
self.perturb_features = perturb_features
self.perturb_structure = perturb_structure
self.direct = direct
self.n_influencers = n_influencers

def attack(self, model_manager, gen_dataset, mask_tensor):
# Prepare
data = gen_dataset.data
_A_obs, _X_obs, _z_obs = data_to_csr_matrix(data)
_A_obs = _A_obs + _A_obs.T
_A_obs[_A_obs > 1] = 1
lcc = largest_connected_components(_A_obs)

_A_obs = _A_obs[lcc][:, lcc]

assert np.abs(_A_obs - _A_obs.T).sum() == 0, "Input graph is not symmetric"
assert _A_obs.max() == 1 and len(np.unique(_A_obs[_A_obs.nonzero()].A1)) == 1, "Graph must be unweighted"
assert _A_obs.sum(0).A1.min() > 0, "Graph contains singleton nodes"

_X_obs = _X_obs[lcc].astype('float32')
_z_obs = _z_obs[lcc]
_N = _A_obs.shape[0]
_K = _z_obs.max() + 1
_Z_obs = np.eye(_K)[_z_obs]
_An = preprocess_graph(_A_obs)
degrees = _A_obs.sum(0).A1

if self.n_perturbations is None:
self.n_perturbations = int(degrees[self.node_idx])
hidden = model_manager.gnn.GCNConv_0.out_channels
# End prepare

# Learn matrix W1 and W2
W1, W2 = train_w1_w2(dataset=gen_dataset, hidden=hidden)

# Attack
nettack = Nettack(_A_obs, _X_obs, _z_obs, W1, W2, self.node_idx, verbose=True)

nettack.reset()
nettack.attack_surrogate(n_perturbations=self.n_perturbations,
perturb_structure=self.perturb_structure,
perturb_features=self.perturb_features,
direct=self.direct,
n_influencers=self.n_influencers)

print(f'edges: {nettack.structure_perturbations}')
print(f'features: {nettack.feature_perturbations}')

self._evasion(gen_dataset, nettack.feature_perturbations, nettack.structure_perturbations)
self.attack_diff = gen_dataset

return gen_dataset

def attack_diff(self):
return self.attack_diff

@staticmethod
def _evasion(gen_dataset, feature_perturbations, structure_perturbations):
cleaned_feat_pert = list(filter(None, feature_perturbations))
if cleaned_feat_pert: # list is not empty
x = gen_dataset.data.x.clone()
for vertex, feature in cleaned_feat_pert:
if x[vertex, feature] == 0.0:
x[vertex, feature] = 1.0
elif x[vertex, feature] == 1.0:
x[vertex, feature] = 0.0
gen_dataset.data.x = x

cleaned_struct_pert = list(filter(None, structure_perturbations))
if cleaned_struct_pert: # list is not empty
edge_index = gen_dataset.data.edge_index.clone()
# add edges
for edge in cleaned_struct_pert:
edge_index = torch.cat((edge_index,
torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)
edge_index = torch.cat((edge_index,
torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)

gen_dataset.data.edge_index = edge_index

Loading

0 comments on commit 2784605

Please sign in to comment.