Merge branch 'develop' of https://github.com/ispras/GNN-AID into develop
LukyanovKirillML committed Nov 27, 2024
2 parents 669c406 + 0a47e92 commit 0137cb8
Showing 4 changed files with 241 additions and 26 deletions.
39 changes: 34 additions & 5 deletions experiments/attack_defense_test.py
@@ -131,7 +131,7 @@ def test_attack_defense():
_import_path=POISON_ATTACK_PARAMETERS_PATH,
_config_class="PoisonAttackConfig",
_config_kwargs={
"n_edges_percent": 0.5,
"n_edges_percent": 1.0,
}
)

@@ -173,7 +173,7 @@ def test_attack_defense():
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"epsilon": 0.01 * 1,
"epsilon": 0.001 * 12,
}
)

@@ -210,7 +210,36 @@ def test_attack_defense():
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"regularization_strength": 0.1 * 10
"regularization_strength": 0.1 * 1000
}
)

quantization_evasion_defense_config = ConfigPattern(
_class_name="QuantizationDefender",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"num_levels": 2
}
)

distillation_evasion_defense_config = ConfigPattern(
_class_name="DistillationDefender",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"temperature": 0.5 * 20
}
)

autoencoder_evasion_defense_config = ConfigPattern(
_class_name="AutoEncoderDefender",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"hidden_dim": 300,
"bottleneck_dim": 100,
"reconstruction_loss_weight": 0.1,
}
)
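
Any of the three new defense configs above plugs into the pipeline the same way; for example, mirroring the `set_evasion_defender` call used further down in this file:

```python
# Illustrative only: swap in whichever defense config is under test.
gnn_model_manager.set_evasion_defender(
    evasion_defense_config=distillation_evasion_defense_config
)
```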

@@ -234,8 +263,8 @@ def test_attack_defense():

# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
- # gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
- # gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
+ gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
+ gnn_model_manager.set_evasion_defender(evasion_defense_config=autoencoder_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
10 changes: 9 additions & 1 deletion metainfo/evasion_defense_parameters.json
@@ -5,7 +5,15 @@
"regularization_strength": ["regularization_strength", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
},
"QuantizationDefender": {
"qbit": ["qbit", "int", 8, {"min": 1, "step": 1}, "?"]
"num_levels": ["num_levels", "int", 32, {"min": 2, "step": 1}, "?"]
},
"DistillationDefender": {
"temperature": ["temperature", "float", 5.0, {"min": 1, "step": 0.01}, "?"]
},
"AutoEncoderDefender": {
"hidden_dim": ["hidden_dim", "int", 5, {"min": 3, "step": 1}, "?"],
"bottleneck_dim": ["bottleneck_dim", "int", 3, {"min": 1, "step": 1}, "?"],
"reconstruction_loss_weight": ["reconstruction_loss_weight", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
},
"AdvTraining": {
"attack_name": ["attack_name", "str", "FGSM", {}, "?"]
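
Each entry in this metainfo file follows the same five-element layout: `[name, type, default, constraints, description]`. A minimal sketch of how such a spec could be unpacked (hypothetical reader code, not part of the commit):

```python
# Hypothetical consumer of the parameter-spec format shown above.
spec = ["num_levels", "int", 32, {"min": 2, "step": 1}, "?"]
name, type_name, default, constraints, description = spec
assert default >= constraints["min"]  # a default should respect its own bounds
```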
180 changes: 170 additions & 10 deletions src/defense/evasion_defense.py
@@ -54,43 +54,106 @@ def post_batch(
pass


- class GradientRegularizationDefender(EvasionDefender):
+ class GradientRegularizationDefender(
+     EvasionDefender
+ ):
name = "GradientRegularizationDefender"

- def __init__(self, regularization_strength=0.1):
+ def __init__(
+     self,
+     regularization_strength: float = 0.1
+ ):
super().__init__()
self.regularization_strength = regularization_strength

- def post_batch(self, model_manager, batch, loss, **kwargs):
+ def post_batch(
+     self,
+     model_manager,
+     batch,
+     loss: torch.Tensor,
+     **kwargs
+ ) -> dict:
batch.x.requires_grad = True
outputs = model_manager.gnn(batch.x, batch.edge_index)
loss_loc = model_manager.loss_function(outputs, batch.y)
gradients = torch.autograd.grad(outputs=loss_loc, inputs=batch.x,
grad_outputs=torch.ones_like(loss_loc),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradient_penalty = torch.sum(gradients ** 2)
batch.x.requires_grad = False
return {"loss": loss + self.regularization_strength * gradient_penalty}


- # TODO Kirill, add code in pre_batch
class QuantizationDefender(
EvasionDefender
):
name = "QuantizationDefender"

def __init__(
self,
- qbit: int = 8
+ num_levels: int = 32
):
super().__init__()
- self.regularization_strength = qbit
+ assert num_levels > 1
+ self.num_levels = num_levels

def pre_batch(
self,
model_manager,
batch,
**kwargs
):
- # TODO Kirill
- pass
+ x = batch.x
+ batch.x = self.quantize(x)
+ return batch

def quantize(
self,
x: torch.Tensor
):
x_min = x.min()
x_max = x.max()
if x_min != x_max:
x_normalized = (x - x_min) / (x_max - x_min)
x_quantized = torch.round(x_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
x_quantized = x_quantized * (x_max - x_min) + x_min
else:
x_quantized = x
return x_quantized
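
To see what `quantize` does, here is the same min-max computation on a toy tensor with `num_levels=2`, the value used in the test config above (standalone sketch):

```python
import torch

x = torch.tensor([0.0, 0.2, 0.8, 1.0])
num_levels = 2
x_min, x_max = x.min(), x.max()
x_norm = (x - x_min) / (x_max - x_min)
x_q = torch.round(x_norm * (num_levels - 1)) / (num_levels - 1)
x_q = x_q * (x_max - x_min) + x_min
print(x_q)  # tensor([0., 0., 1., 1.]): every feature snaps to one of two levels
```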


class DistillationDefender(
EvasionDefender
):
name = "DistillationDefender"

def __init__(
self,
temperature: float = 5.0
):
"""
"""
super().__init__()
self.temperature = temperature

def post_batch(
self,
model_manager,
batch,
loss: torch.Tensor
):
"""
"""
model = model_manager.gnn
logits = model(batch)
soft_targets = torch.softmax(logits / self.temperature, dim=1)
distillation_loss = torch.nn.functional.kl_div(
torch.log_softmax(logits / self.temperature, dim=1),
soft_targets,
reduction='batchmean'
) * (self.temperature ** 2)
modified_loss = loss + distillation_loss
return {"loss": modified_loss}


class AdvTraining(
@@ -104,10 +167,8 @@ def __init__(
attack_name: str = None,
attack_config: EvasionAttackConfig = None,
attack_type: str = None,
- device: str = 'cpu'
):
super().__init__()
- assert device is not None, "Please specify 'device'!"
if not attack_config:
# build default config
assert attack_name is not None
@@ -177,3 +238,102 @@ def post_batch(
outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index)
loss_loc = model_manager.loss_function(outputs, batch.y)
return {"loss": loss + loss_loc}


class SimpleAutoEncoder(
torch.nn.Module
):
def __init__(
self,
input_dim: int,
hidden_dim: int,
bottleneck_dim: int,
device: str = 'cpu'
):
"""
"""
super(SimpleAutoEncoder, self).__init__()
self.device = device
self.encoder = torch.nn.Sequential(
torch.nn.Linear(input_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, bottleneck_dim),
torch.nn.ReLU()
).to(self.device)
self.decoder = torch.nn.Sequential(
torch.nn.Linear(bottleneck_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, input_dim)
).to(self.device)

def forward(
self,
x: torch.Tensor
):
x = x.to(self.device)
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
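
A quick shape check for the autoencoder (the 1433-dimensional input is illustrative, sized like Cora node features; not part of the commit):

```python
import torch

ae = SimpleAutoEncoder(input_dim=1433, hidden_dim=300, bottleneck_dim=100)
x = torch.rand(8, 1433)        # 8 nodes, 1433 features each
assert ae(x).shape == x.shape  # the decoder maps back to the input dimension
```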


class AutoEncoderDefender(
EvasionDefender
):
name = "AutoEncoderDefender"

def __init__(
self,
hidden_dim: int,
bottleneck_dim: int,
reconstruction_loss_weight: float = 0.1,
):
"""
"""
super().__init__()
self.autoencoder = None
self.hidden_dim = hidden_dim
self.bottleneck_dim = bottleneck_dim
self.reconstruction_loss_weight = reconstruction_loss_weight

def post_batch(
self,
model_manager,
batch,
loss: torch.Tensor
) -> dict:
"""
"""
model_manager.gnn.eval()
if self.autoencoder is None:
self.init_autoencoder(batch.x)
self.autoencoder.train()
reconstructed_x = self.autoencoder(batch.x)
reconstruction_loss = torch.nn.functional.mse_loss(reconstructed_x, batch.x)
modified_loss = loss + self.reconstruction_loss_weight * reconstruction_loss.detach().clone()
autoencoder_optimizer = torch.optim.Adam(self.autoencoder.parameters(), lr=0.001)
autoencoder_optimizer.zero_grad()
reconstruction_loss.backward()
autoencoder_optimizer.step()
return {"loss": modified_loss}

def denoise_with_autoencoder(
self,
x: torch.Tensor
) -> torch.Tensor:
"""
"""
self.autoencoder.eval()
with torch.no_grad():
x_denoised = self.autoencoder(x)
return x_denoised

def init_autoencoder(
self,
x: torch.Tensor
) -> None:
self.autoencoder = SimpleAutoEncoder(
input_dim=x.shape[1],
bottleneck_dim=self.bottleneck_dim,
hidden_dim=self.hidden_dim,
device=x.device
)
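
One detail worth noting in `post_batch` above: the reconstruction term added to the classifier's loss is detached, so the model's backward pass never reaches the autoencoder, which instead trains on its own Adam step. A compact, runnable restatement (all names are stand-ins):

```python
import torch

ae = torch.nn.Linear(16, 16)                   # stand-in for the autoencoder
clf = torch.nn.Linear(16, 2)                   # stand-in for the classifier
x = torch.rand(8, 16)
task_loss = clf(x).sum()                       # stand-in for the task loss
opt_ae = torch.optim.Adam(ae.parameters(), lr=0.001)

recon = torch.nn.functional.mse_loss(ae(x), x)
model_loss = task_loss + 0.1 * recon.detach()  # detached: model_loss.backward() skips the autoencoder
opt_ae.zero_grad()
recon.backward()                               # only the autoencoder gets reconstruction gradients
opt_ae.step()
```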
38 changes: 28 additions & 10 deletions src/models_builder/gnn_constructor.py
@@ -617,17 +617,35 @@ def arguments_read(
edge_weight = kwargs.get('edge_weight', None)
if batch is None:
batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device)
- elif len(args) == 2:
-     x, edge_index = args[0], args[1]
-     batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
-     edge_weight = None
- elif len(args) == 3:
-     x, edge_index, batch = args[0], args[1], args[2]
-     edge_weight = None
- elif len(args) == 4:
-     x, edge_index, batch, edge_weight = args[0], args[1], args[2], args[3]
- else:
-     raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")
+ if len(args) == 1:
+     args = args[0]
+     if 'x' in args and 'edge_index' in args:
+         x, edge_index = args.x, args.edge_index
+     else:
+         raise ValueError(f"forward's args should contain x and"
+                          f" edge_index Tensors, but {args.keys} does not contain them")
+     if 'batch' in args:
+         batch = args.batch
+     else:
+         batch = torch.zeros(args.x.shape[0], dtype=torch.int64, device=x.device)
+     if 'edge_weight' in args:
+         edge_weight = args.edge_weight
+     else:
+         edge_weight = None
+ else:
+     if len(args) == 2:
+         x, edge_index = args[0], args[1]
+         batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
+         edge_weight = None
+     elif len(args) == 3:
+         x, edge_index, batch = args[0], args[1], args[2]
+         edge_weight = None
+     elif len(args) == 4:
+         x, edge_index, batch, edge_weight = args[0], args[1], args[2], args[3]
+     else:
+         raise ValueError(f"forward's args should take 2, 3 or 4 arguments but got {len(args)}")

else:
if hasattr(data, "edge_weight"):
x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight
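
After this change, `arguments_read` accepts either a single data-like object carrying `x`/`edge_index` (with optional `batch` and `edge_weight`) or 2-4 positional tensors. Illustrative call patterns (`gnn`, `data`, and the tensors are placeholders):

```python
# out = gnn(data)                          # object exposing .x / .edge_index, e.g. a PyG Data/Batch
# out = gnn(x, edge_index)                 # batch defaults to zeros, edge_weight to None
# out = gnn(x, edge_index, batch)
# out = gnn(x, edge_index, batch, edge_weight)
```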
