DistillationDefender
LukyanovKirillML committed Nov 25, 2024
1 parent b12acce commit 9e51ade
Showing 4 changed files with 77 additions and 13 deletions.
13 changes: 11 additions & 2 deletions experiments/attack_defense_test.py
@@ -173,7 +173,7 @@ def test_attack_defense():
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
_config_class="EvasionAttackConfig",
_config_kwargs={
"epsilon": 0.01 * 1,
"epsilon": 0.001 * 12,
}
)

@@ -223,6 +223,15 @@ def test_attack_defense():
}
)

distillation_evasion_defense_config = ConfigPattern(
_class_name="DistillationDefender",
_import_path=EVASION_DEFENSE_PARAMETERS_PATH,
_config_class="EvasionDefenseConfig",
_config_kwargs={
"temperature": 0.5 * 20
}
)

fgsm_evasion_attack_config0 = ConfigPattern(
_class_name="FGSM",
_import_path=EVASION_ATTACK_PARAMETERS_PATH,
@@ -244,7 +253,7 @@ def test_attack_defense():
# gnn_model_manager.set_poison_attacker(poison_attack_config=random_poison_attack_config)
# gnn_model_manager.set_poison_defender(poison_defense_config=gnnguard_poison_defense_config)
gnn_model_manager.set_evasion_attacker(evasion_attack_config=fgsm_evasion_attack_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=gradientregularization_evasion_defense_config)
gnn_model_manager.set_evasion_defender(evasion_defense_config=distillation_evasion_defense_config)

warnings.warn("Start training")
dataset.train_test_split()
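For context, a rough sketch of how an evasion defender's post_batch hook could be wired into a training step. The hook signature matches the DistillationDefender added in this commit, but the evasion_defender attribute and the shape of the surrounding loop are assumptions for illustration, not taken from this repository.

    # Hypothetical wiring of a post_batch hook into one training step (names are illustrative).
    def training_step(model_manager, batch, base_loss):
        defender = model_manager.evasion_defender  # e.g. DistillationDefender(temperature=10.0)
        result = defender.post_batch(model_manager=model_manager, batch=batch, loss=base_loss)
        # The hook returns a dict carrying the modified loss; fall back to the original loss otherwise.
        return result["loss"] if result else base_loss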
3 changes: 3 additions & 0 deletions metainfo/evasion_defense_parameters.json
@@ -7,6 +7,9 @@
"QuantizationDefender": {
"num_levels": ["num_levels", "int", 32, {"min": 2, "step": 1}, "?"]
},
"DistillationDefender": {
"temperature": ["temperature", "float", 5.0, {"min": 1, "step": 0.01}, "?"]
},
"AdvTraining": {
"attack_name": ["attack_name", "str", "FGSM", {}, "?"]
}
36 changes: 35 additions & 1 deletion src/defense/evasion_defense.py
@@ -108,7 +108,7 @@ def pre_batch(

def quantize(
self,
x
x: torch.Tensor
):
x_min = x.min()
x_max = x.max()
@@ -118,6 +118,40 @@ def quantize(
return x_quantized


class DistillationDefender(
EvasionDefender
):
name = "DistillationDefender"

def __init__(
self,
temperature: float = 5.0
):
"""
"""
super().__init__()
self.temperature = temperature

def post_batch(
self,
model_manager,
batch,
loss: torch.Tensor
):
"""
"""
model = model_manager.gnn
logits = model(batch)
soft_targets = torch.softmax(logits / self.temperature, dim=1)
distillation_loss = torch.nn.functional.kl_div(
torch.log_softmax(logits / self.temperature, dim=1),
soft_targets,
reduction='batchmean'
) * (self.temperature ** 2)
modified_loss = loss + distillation_loss
return {"loss": modified_loss}


class AdvTraining(
EvasionDefender
):
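For reference, defensive distillation in the literature derives soft targets from a separately trained teacher, whereas DistillationDefender.post_batch above softens the defended model's own logits. A minimal standalone sketch of the temperature-scaled KL term, with a hypothetical teacher_logits tensor standing in for the teacher's outputs:

    import torch
    import torch.nn.functional as F

    def distillation_loss(student_logits: torch.Tensor,
                          teacher_logits: torch.Tensor,
                          temperature: float = 5.0) -> torch.Tensor:
        # Soft targets from the (hypothetical) teacher, softened by the temperature.
        soft_targets = F.softmax(teacher_logits.detach() / temperature, dim=1)
        # KL divergence between the student's softened log-probabilities and the soft targets,
        # scaled by temperature ** 2 as in post_batch above to keep gradient magnitudes comparable.
        return F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            soft_targets,
            reduction="batchmean",
        ) * (temperature ** 2)

    # Shape check only: 8 nodes, 3 classes.
    loss = distillation_loss(torch.randn(8, 3), torch.randn(8, 3), temperature=10.0)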
38 changes: 28 additions & 10 deletions src/models_builder/gnn_constructor.py
@@ -617,17 +617,35 @@ def arguments_read(
edge_weight = kwargs.get('edge_weight', None)
if batch is None:
batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device)
elif len(args) == 2:
x, edge_index = args[0], args[1]
batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
edge_weight = None
elif len(args) == 3:
x, edge_index, batch = args[0], args[1], args[2]
edge_weight = None
elif len(args) == 4:
x, edge_index, batch, edge_weight = args[0], args[1], args[2], args[3]
else:
raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")
if len(args) == 1:
args = args[0]
if 'x' in args and 'edge_index' in args:
x, edge_index = args.x, args.edge_index
else:
raise ValueError(f"forward's args should contain x and 3"
f" edge_index Tensors but {args.keys} doesn't content this Tensors")
if 'batch' in args:
batch = args.batch
else:
batch = torch.zeros(args.x.shape[0], dtype=torch.int64, device=x.device)
if 'edge_weight' in args:
edge_weight = args.edge_weight
else:
edge_weight = None
else:
if len(args) == 2:
x, edge_index = args[0], args[1]
batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
edge_weight = None
elif len(args) == 3:
x, edge_index, batch = args[0], args[1], args[2]
edge_weight = None
elif len(args) == 4:
x, edge_index, batch, edge_weight = args[0], args[1], args[2], args[3]
else:
raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")

else:
if hasattr(data, "edge_weight"):
x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight
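To summarize the dispatch above, arguments_read now accepts either a single Data-like object carrying x and edge_index (with optional batch and edge_weight) or two to four positional tensors. A small illustration of the accepted call forms, assuming torch_geometric's Data class (the diff reads the same attribute names, but the exact container type is an assumption):

    import torch
    from torch_geometric.data import Data

    x = torch.randn(4, 16)                                   # 4 nodes, 16 features
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])        # 3 directed edges
    batch = torch.zeros(x.size(0), dtype=torch.int64)        # single-graph batch vector
    edge_weight = torch.ones(edge_index.size(1))

    # Single Data-like argument: x and edge_index required, batch and edge_weight optional.
    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)

    # Positional variants handled by the elif branches:
    # (x, edge_index), (x, edge_index, batch) or (x, edge_index, batch, edge_weight).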
