Commit: add QuantizationDefender

LukyanovKirillML committed Nov 25, 2024
1 parent c3ede99 commit b12acce

Showing 3 changed files with 47 additions and 15 deletions.
17 changes: 13 additions & 4 deletions experiments/attack_defense_test.py
@@ -131,7 +131,7 @@ def test_attack_defense():
         _import_path=POISON_ATTACK_PARAMETERS_PATH,
         _config_class="PoisonAttackConfig",
         _config_kwargs={
-            "n_edges_percent": 0.5,
+            "n_edges_percent": 1.0,
         }
     )

@@ -210,7 +210,16 @@ def test_attack_defense():
         _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
         _config_class="EvasionDefenseConfig",
         _config_kwargs={
-            "regularization_strength": 0.1 * 10
+            "regularization_strength": 0.1 * 1000
         }
     )
 
+    quantization_evasion_defense_config = ConfigPattern(
+        _class_name="QuantizationDefender",
+        _import_path=EVASION_DEFENSE_PARAMETERS_PATH,
+        _config_class="EvasionDefenseConfig",
+        _config_kwargs={
+            "num_levels": 2
+        }
+    )

@@ -234,8 +243,8 @@ def test_attack_defense():

     # gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
     # gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
-    # gnn_model_manager.set_evasion_attacker(evasion_attack_config=netattackgroup_evasion_attack_config)
-    # gnn_model_manager.set_evasion_defender(evasion_defense_config=at_evasion_defense_config)
+    gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
+    gnn_model_manager.set_evasion_defender(evasion_defense_config=gradientregularization_evasion_defense_config)
 
     warnings.warn("Start training")
     dataset.train_test_split()
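Note that this commit leaves the FGSM attack paired with the gradient-regularization defender; the quantization_evasion_defense_config built above is constructed but not yet passed to the model manager. A minimal sketch of switching the defense over, assuming the defender can be swapped in through the same set_evasion_defender call used above:

    gnn_model_manager.set_evasion_defender(evasion_defense_config=quantization_evasion_defense_config)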
2 changes: 1 addition & 1 deletion metainfo/evasion_defense_parameters.json
@@ -5,7 +5,7 @@
"regularization_strength": ["regularization_strength", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"]
},
"QuantizationDefender": {
"qbit": ["qbit", "int", 8, {"min": 1, "step": 1}, "?"]
"num_levels": ["num_levels", "int", 32, {"min": 2, "step": 1}, "?"]
},
"AdvTraining": {
"attack_name": ["attack_name", "str", "FGSM", {}, "?"]
43 changes: 33 additions & 10 deletions src/defense/evasion_defense.py
@@ -54,43 +54,68 @@ def post_batch(
         pass
 
 
-class GradientRegularizationDefender(EvasionDefender):
+class GradientRegularizationDefender(
+    EvasionDefender
+):
     name = "GradientRegularizationDefender"
 
-    def __init__(self, regularization_strength=0.1):
+    def __init__(
+            self,
+            regularization_strength: float = 0.1
+    ):
         super().__init__()
         self.regularization_strength = regularization_strength
 
-    def post_batch(self, model_manager, batch, loss, **kwargs):
+    def post_batch(
+            self,
+            model_manager,
+            batch,
+            loss: torch.Tensor,
+            **kwargs
+    ) -> dict:
         batch.x.requires_grad = True
         outputs = model_manager.gnn(batch.x, batch.edge_index)
         loss_loc = model_manager.loss_function(outputs, batch.y)
         gradients = torch.autograd.grad(outputs=loss_loc, inputs=batch.x,
                                         grad_outputs=torch.ones_like(loss_loc),
                                         create_graph=True, retain_graph=True, only_inputs=True)[0]
         gradient_penalty = torch.sum(gradients ** 2)
         batch.x.requires_grad = False
         return {"loss": loss + self.regularization_strength * gradient_penalty}


-# TODO Kirill, add code in pre_batch
 class QuantizationDefender(
     EvasionDefender
 ):
     name = "QuantizationDefender"
 
     def __init__(
             self,
-            qbit: int = 8
+            num_levels: int = 32
     ):
         super().__init__()
-        self.regularization_strength = qbit
+        self.num_levels = num_levels
 
     def pre_batch(
             self,
             model_manager,
             batch,
             **kwargs
     ):
-        # TODO Kirill
-        pass
+        x = batch.x
+        batch.x = self.quantize(x)
+        return batch
+
+    def quantize(
+            self,
+            x
+    ):
+        x_min = x.min()
+        x_max = x.max()
+        x_normalized = (x - x_min) / (x_max - x_min)
+        x_quantized = torch.round(x_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
+        x_quantized = x_quantized * (x_max - x_min) + x_min
+        return x_quantized


class AdvTraining(
@@ -104,10 +129,8 @@ def __init__(
         attack_name: str = None,
         attack_config: EvasionAttackConfig = None,
         attack_type: str = None,
-        device: str = 'cpu'
     ):
         super().__init__()
-        assert device is not None, "Please specify 'device'!"
         if not attack_config:
             # build default config
             assert attack_name is not None
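For context on the first defender changed above: GradientRegularizationDefender's post_batch augments the training objective with a squared-gradient penalty, roughly

    loss_total = loss + regularization_strength * sum_i (d loss / d x_i)^2

where the gradients are taken with respect to the node features batch.x; create_graph=True is what keeps the penalty itself differentiable so it can be backpropagated (double backpropagation).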
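The new quantize helper snaps every feature to the nearest of num_levels evenly spaced values between the batch minimum and maximum (note a constant feature tensor, where x_max == x_min, would divide by zero under this scheme). A self-contained sketch of the same arithmetic, assuming only torch:

    import torch

    def uniform_quantize(x: torch.Tensor, num_levels: int = 2) -> torch.Tensor:
        # Normalize to [0, 1], snap to num_levels evenly spaced points, rescale.
        x_min, x_max = x.min(), x.max()
        x_norm = (x - x_min) / (x_max - x_min)  # assumes x_max > x_min
        x_q = torch.round(x_norm * (num_levels - 1)) / (num_levels - 1)
        return x_q * (x_max - x_min) + x_min

    x = torch.tensor([0.0, 0.2, 0.5, 0.9, 1.0])
    print(uniform_quantize(x, num_levels=2))  # tensor([0., 0., 0., 1., 1.])

With num_levels=2, as in the test config above, features collapse to the two extremes of their observed range.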
