Skip to content

Commit

Permalink
Merge pull request #22 from ispras/metrics_for_subdataset
Browse files Browse the repository at this point in the history
Metrics for subdataset
  • Loading branch information
LukyanovKirillML authored Oct 15, 2024
2 parents afa3b24 + 3f082fb commit 211da4a
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ models
models_exp
results
data_info
explanations
#user_models_obj
#user_models_managers
user_datasets
Expand Down
22 changes: 16 additions & 6 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,21 @@ def test_nettack_evasion():
steps=num_steps,
save_model_flag=False)

# Node for attack
node_idx = 1

# Evaluate model
mask_loc = Metric.create_mask_by_target_list(y_true=dataset.labels, target_list=[node_idx])
acc_test_loc = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy']

acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")

# Node for attack
node_idx = 0
print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")
print(f"Accuracy on test loc: {acc_test_loc}")

# Model prediction on a node before an evasion attack on it
gnn_model_manager.gnn.eval()
Expand Down Expand Up @@ -342,6 +348,9 @@ def test_nettack_evasion():

print(f"info_before_evasion_attack: {info_before_evasion_attack}")
print(f"info_after_evasion_attack: {info_after_evasion_attack}")
acc_test_loc = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy']
print(f"Accuracy on test loc: {acc_test_loc}")

def test_qattack():
from attacks.QAttack import qattack
Expand Down Expand Up @@ -453,6 +462,7 @@ def test_qattack():

if __name__ == '__main__':
#test_attack_defense()
torch.manual_seed(5000)
#test_meta()
test_qattack()
# torch.manual_seed(5000)
# test_meta()
# test_qattack()
test_nettack_evasion()
14 changes: 13 additions & 1 deletion src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def compute(self, y_true, y_pred):

raise NotImplementedError()

@staticmethod
def create_mask_by_target_list(y_true, target_list=None):
if target_list is None:
mask = [True] * len(y_true)
else:
mask = [False] * len(y_true)
for i in target_list:
if 0 <= i < len(mask):
mask[i] = True
return tensor(mask)
# return mask


class GNNModelManager:
""" class of basic functions over models:
Expand Down Expand Up @@ -1057,7 +1069,7 @@ def evaluate_model(self, gen_dataset, metrics):
'all': [True] * len(gen_dataset.labels),
}[mask]
except KeyError:
assert isinstance(mask, list)
assert isinstance(mask, torch.Tensor)
mask_tensor = mask
if self.evasion_attacker:
self.evasion_attacker.attack(model_manager=self, gen_dataset=gen_dataset, mask_tensor=mask_tensor)
Expand Down

0 comments on commit 211da4a

Please sign in to comment.