From b12accee6aa1b4c99fe19104fa95f1f2f65ebab2 Mon Sep 17 00:00:00 2001
From: "lukyanov_kirya@bk.ru" <lukyanov_kirya@bk.ru>
Date: Mon, 25 Nov 2024 13:53:34 +0300
Subject: [PATCH] add QuantizationDefender

---
 experiments/attack_defense_test.py       | 17 +++++++---
 metainfo/evasion_defense_parameters.json |  2 +-
 src/defense/evasion_defense.py           | 43 ++++++++++++++++++------
 3 files changed, 47 insertions(+), 15 deletions(-)

diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py
index 6d2b6b6..046f888 100644
--- a/experiments/attack_defense_test.py
+++ b/experiments/attack_defense_test.py
@@ -131,7 +131,7 @@ def test_attack_defense():
         _import_path=POISON_ATTACK_PARAMETERS_PATH,
         _config_class="PoisonAttackConfig",
         _config_kwargs={
-            "n_edges_percent": 0.5,
+            "n_edges_percent": 1.0,
         }
     )
 
@@ -210,7 +210,16 @@ def test_attack_defense():
         _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
         _config_class="EvasionDefenseConfig",
         _config_kwargs={
-            "regularization_strength": 0.1 * 10
+            "regularization_strength": 0.1 * 1000
+        }
+    )
+
+    quantization_evasion_defense_config = ConfigPattern(
+        _class_name="QuantizationDefender",
+        _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
+        _config_class="EvasionDefenseConfig",
+        _config_kwargs={
+            "num_levels": 2
         }
     )
 
@@ -234,8 +243,8 @@ def test_attack_defense():
 
     # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
     # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
-    # gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
-    # gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
+    gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
+    gnn_model_manager.set_evasion_defender(evasion_defense_config=gradientregularization_evasion_defense_config)
 
     warnings.warn("Start training")
     dataset.train_test_split()
diff --git a/metainfo/evasion_defense_parameters.json b/metainfo/evasion_defense_parameters.json
index 682a4d2..b9514be 100644
--- a/metainfo/evasion_defense_parameters.json
+++ b/metainfo/evasion_defense_parameters.json
@@ -5,7 +5,7 @@
     "regularization_strength": ["regularization_strength", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
   },
   "QuantizationDefender": {
-    "qbit": ["qbit", "int", 8, {"min": 1, "step": 1}, "?"]
+    "num_levels": ["num_levels", "int", 32, {"min": 2, "step": 1}, "?"]
   },
   "AdvTraining": {
     "attack_name": ["attack_name", "str", "FGSM", {}, "?"]
diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py
index 18d2cca..9cd9c5a 100644
--- a/src/defense/evasion_defense.py
+++ b/src/defense/evasion_defense.py
@@ -54,14 +54,25 @@ def post_batch(
         pass
 
 
-class GradientRegularizationDefender(EvasionDefender):
+class GradientRegularizationDefender(
+    EvasionDefender
+):
     name = "GradientRegularizationDefender"
 
-    def __init__(self, regularization_strength=0.1):
+    def __init__(
+            self,
+            regularization_strength: float = 0.1
+    ):
         super().__init__()
         self.regularization_strength = regularization_strength
 
-    def post_batch(self, model_manager, batch, loss, **kwargs):
+    def post_batch(
+            self,
+            model_manager,
+            batch,
+            loss: torch.Tensor,
+            **kwargs
+    ) -> dict:
         batch.x.requires_grad = True
         outputs = model_manager.gnn(batch.x, batch.edge_index)
         loss_loc = model_manager.loss_function(outputs, batch.y)
@@ -69,10 +80,10 @@ def post_batch(self, model_manager, batch, loss, **kwargs):
                                         grad_outputs=torch.ones_like(loss_loc),
                                         create_graph=True, retain_graph=True,
                                         only_inputs=True)[0]
         gradient_penalty = torch.sum(gradients ** 2)
+        batch.x.requires_grad = False
         return {"loss": loss + self.regularization_strength * gradient_penalty}
 
 
-# TODO Kirill, add code in pre_batch
 class QuantizationDefender(
     EvasionDefender
 ):
@@ -80,17 +91,31 @@ class QuantizationDefender(
 
     def __init__(
             self,
-            qbit: int = 8
+            num_levels: int = 32
     ):
         super().__init__()
-        self.regularization_strength = qbit
+        self.num_levels = num_levels
 
     def pre_batch(
             self,
+            model_manager,
+            batch,
             **kwargs
     ):
-        # TODO Kirill
-        pass
+        x = batch.x
+        batch.x = self.quantize(x)
+        return batch
+
+    def quantize(
+            self,
+            x
+    ):
+        x_min = x.min()
+        x_max = x.max()
+        x_normalized = (x - x_min) / (x_max - x_min)
+        x_quantized = torch.round(x_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
+        x_quantized = x_quantized * (x_max - x_min) + x_min
+        return x_quantized
 
 
 class AdvTraining(
@@ -104,10 +129,8 @@ def __init__(
         attack_name: str = None,
         attack_config: EvasionAttackConfig = None,
         attack_type: str = None,
-        device: str = 'cpu'
     ):
         super().__init__()
-        assert device is not None, "Please specify 'device'!"
         if not attack_config:
             # build default config
             assert attack_name is not None
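
Note: below is a minimal, self-contained sketch of the uniform quantization
this patch introduces (same math as QuantizationDefender.quantize). The
constant-feature guard is an assumption added here for illustration; the
committed quantize() divides by (x_max - x_min) directly and would produce
NaNs on a constant feature matrix.

    import torch

    def quantize(x: torch.Tensor, num_levels: int = 32) -> torch.Tensor:
        # Uniform quantization over the global [min, max] range of x.
        x_min, x_max = x.min(), x.max()
        if x_max == x_min:  # assumption: guard not present in the patch
            return x.clone()
        x_norm = (x - x_min) / (x_max - x_min)             # map to [0, 1]
        x_q = torch.round(x_norm * (num_levels - 1)) / (num_levels - 1)
        return x_q * (x_max - x_min) + x_min               # back to input range

    x = torch.tensor([0.0, 0.3, 0.6, 1.0])
    print(quantize(x, num_levels=2))  # tensor([0., 0., 1., 1.])

With num_levels=2, the value used in attack_defense_test.py, every feature
collapses to x_min or x_max: small adversarial perturbations are erased, but
so is most of the signal. The metainfo default of 32 levels is a milder
trade-off between robustness and feature fidelity.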