Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cosmetic changes #35

Merged
merged 8 commits into from
Nov 25, 2024
23 changes: 19 additions & 4 deletions src/attacks/attack_base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from typing import Type

from base.datasets_processing import GeneralDataset


class Attacker:
name = "Attacker"

def __init__(self):
def __init__(
self
):
pass

def attack(self, **kwargs):
def attack(
self,
**kwargs
):
pass

def attack_diff(self):
def attack_diff(
self
):
pass

@staticmethod
def check_availability(gen_dataset, model_manager):
def check_availability(
gen_dataset: GeneralDataset,
model_manager: Type
):
return False


150 changes: 110 additions & 40 deletions src/attacks/evasion_attacks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import Type, Union

import torch
import torch.nn.functional as F
import numpy as np

from attacks.attack_base import Attacker
from base.datasets_processing import GeneralDataset

# Nettack imports
from src.attacks.nettack.nettack import Nettack
Expand All @@ -11,31 +14,50 @@
# PGD imports
from attacks.evasion_attacks_collection.pgd.utils import Projection, RandomSampling
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj, dense_to_sparse, k_hop_subgraph
from torch_geometric.utils import k_hop_subgraph
from tqdm import tqdm
from torch_geometric.nn import SGConv


class EvasionAttacker(Attacker):
def __init__(self, **kwargs):
class EvasionAttacker(
Attacker
):
def __init__(
self,
**kwargs
):
super().__init__()


class EmptyEvasionAttacker(EvasionAttacker):
class EmptyEvasionAttacker(
EvasionAttacker
):
name = "EmptyEvasionAttacker"

def attack(self, **kwargs):
def attack(
self,
**kwargs
):
pass


class FGSMAttacker(EvasionAttacker):
class FGSMAttacker(
EvasionAttacker
):
name = "FGSM"

def __init__(self, epsilon=0.1):
def __init__(
self,
epsilon: float = 0.1
):
super().__init__()
self.epsilon = epsilon

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: GeneralDataset,
mask_tensor: torch.Tensor
):
gen_dataset.data.x.requires_grad = True
output = model_manager.gnn(gen_dataset.data.x, gen_dataset.data.edge_index, gen_dataset.data.batch)
loss = model_manager.loss_function(output[mask_tensor],
Expand All @@ -49,16 +71,20 @@ def attack(self, model_manager, gen_dataset, mask_tensor):
return gen_dataset


class PGDAttacker(EvasionAttacker):
class PGDAttacker(
EvasionAttacker
):
name = "PGD"

def __init__(self,
is_feature_attack=False,
element_idx=0,
epsilon=0.5,
learning_rate=0.001,
num_iterations=100,
num_rand_trials=100):
def __init__(
self,
is_feature_attack: bool = False,
element_idx: int = 0,
epsilon: float = 0.5,
learning_rate: float = 0.001,
num_iterations: int = 100,
num_rand_trials: int = 100
):

super().__init__()
self.attack_diff = None
Expand All @@ -69,13 +95,22 @@ def __init__(self,
self.num_iterations = num_iterations
self.num_rand_trials = num_rand_trials

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: GeneralDataset,
mask_tensor: torch.Tensor
) -> None:
if gen_dataset.is_multi():
self._attack_on_graph(model_manager, gen_dataset)
else:
self._attack_on_node(model_manager, gen_dataset)

def _attack_on_node(self, model_manager, gen_dataset):
def _attack_on_node(
self,
model_manager: Type,
gen_dataset: GeneralDataset
) -> None:
node_idx = self.element_idx

edge_index = gen_dataset.data.edge_index
Expand Down Expand Up @@ -118,7 +153,11 @@ def _attack_on_node(self, model_manager, gen_dataset):
else: # structure attack
pass

def _attack_on_graph(self, model_manager, gen_dataset):
def _attack_on_graph(
self,
model_manager: Type,
gen_dataset: GeneralDataset
):
graph_idx = self.element_idx

edge_index = gen_dataset.dataset[graph_idx].edge_index
Expand Down Expand Up @@ -149,21 +188,26 @@ def _attack_on_graph(self, model_manager, gen_dataset):
else: # structure attack
pass

def attack_diff(self):
def attack_diff(
self
):
return self.attack_diff


class NettackEvasionAttacker(EvasionAttacker):
class NettackEvasionAttacker(
EvasionAttacker
):
name = "NettackEvasionAttacker"

def __init__(self,
node_idx=0,
n_perturbations=None,
perturb_features=True,
perturb_structure=True,
direct=True,
n_influencers=0
):
def __init__(
self,
node_idx: int = 0,
n_perturbations: Union[int, None] = None,
perturb_features: bool = True,
perturb_structure: bool = True,
direct: bool = True,
n_influencers: int = 0
):

super().__init__()
self.attack_diff = None
Expand All @@ -174,7 +218,12 @@ def __init__(self,
self.direct = direct
self.n_influencers = n_influencers

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: GeneralDataset,
mask_tensor: torch.Tensor
) -> GeneralDataset:
# Prepare
data = gen_dataset.data
_A_obs, _X_obs, _z_obs = data_to_csr_matrix(data)
Expand Down Expand Up @@ -222,11 +271,17 @@ def attack(self, model_manager, gen_dataset, mask_tensor):

return gen_dataset

def attack_diff(self):
def attack_diff(
self
):
return self.attack_diff

@staticmethod
def _evasion(gen_dataset, feature_perturbations, structure_perturbations):
def _evasion(
gen_dataset: GeneralDataset,
feature_perturbations,
structure_perturbations
):
cleaned_feat_pert = list(filter(None, feature_perturbations))
if cleaned_feat_pert: # list is not empty
x = gen_dataset.data.x.clone()
Expand All @@ -243,17 +298,27 @@ def _evasion(gen_dataset, feature_perturbations, structure_perturbations):
# add edges
for edge in cleaned_struct_pert:
edge_index = torch.cat((edge_index,
torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)
torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze(
1)), dim=1)
edge_index = torch.cat((edge_index,
torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1)
torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(
1)), dim=1)

gen_dataset.data.edge_index = edge_index

class NettackGroupEvasionAttacker(EvasionAttacker):

class NettackGroupEvasionAttacker(
EvasionAttacker
):
name = "NettackGroupEvasionAttacker"
def __init__(self,node_idxs, **kwargs):

def __init__(
self,
node_idxs: list,
**kwargs
):
super().__init__()
self.node_idxs = node_idxs # kwargs.get("node_idxs")
self.node_idxs = node_idxs # kwargs.get("node_idxs")
assert isinstance(self.node_idxs, list)
self.n_perturbations = kwargs.get("n_perturbations")
self.perturb_features = kwargs.get("perturb_features")
Expand All @@ -262,8 +327,13 @@ def __init__(self,node_idxs, **kwargs):
self.n_influencers = kwargs.get("n_influencers")
self.attacker = NettackEvasionAttacker(0, **kwargs)

def attack(self, model_manager, gen_dataset, mask_tensor):
def attack(
self,
model_manager: Type,
gen_dataset: GeneralDataset,
mask_tensor: torch.Tensor
) -> GeneralDataset:
for node_idx in self.node_idxs:
self.attacker.node_idx = node_idx
gen_dataset = self.attacker.attack(model_manager, gen_dataset, mask_tensor)
return gen_dataset
return gen_dataset
18 changes: 14 additions & 4 deletions src/attacks/mi_attacks.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from attacks.attack_base import Attacker


class MIAttacker(Attacker):
def __init__(self, **kwargs):
class MIAttacker(
Attacker
):
def __init__(
self,
**kwargs
):
super().__init__()


class EmptyMIAttacker(MIAttacker):
class EmptyMIAttacker(
MIAttacker
):
name = "EmptyMIAttacker"

def attack(self, **kwargs):
def attack(
self,
**kwargs
):
pass
Loading
Loading