Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metrics for subdataset #22

Merged
merged 3 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ models
models_exp
results
data_info
explanations
#user_models_obj
#user_models_managers
user_datasets
Expand Down
22 changes: 16 additions & 6 deletions experiments/attack_defense_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,21 @@ def test_nettack_evasion():
steps=num_steps,
save_model_flag=False)

# Node for attack
node_idx = 1

# Evaluate model
mask_loc = Metric.create_mask_by_target_list(y_true=dataset.labels, target_list=[node_idx])
acc_test_loc = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy']

acc_train = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='train')])['train']['Accuracy']
acc_test = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask='test')])['test']['Accuracy']
print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")

# Node for attack
node_idx = 0
print(f"Accuracy on train: {acc_train}. Accuracy on test: {acc_test}")
print(f"Accuracy on test loc: {acc_test_loc}")

# Model prediction on a node before an evasion attack on it
gnn_model_manager.gnn.eval()
Expand Down Expand Up @@ -342,6 +348,9 @@ def test_nettack_evasion():

print(f"info_before_evasion_attack: {info_before_evasion_attack}")
print(f"info_after_evasion_attack: {info_after_evasion_attack}")
acc_test_loc = gnn_model_manager.evaluate_model(gen_dataset=dataset,
metrics=[Metric("Accuracy", mask=mask_loc)])[mask_loc]['Accuracy']
print(f"Accuracy on test loc: {acc_test_loc}")

def test_qattack():
from attacks.QAttack import qattack
Expand Down Expand Up @@ -453,6 +462,7 @@ def test_qattack():

if __name__ == '__main__':
#test_attack_defense()
torch.manual_seed(5000)
#test_meta()
test_qattack()
# torch.manual_seed(5000)
# test_meta()
# test_qattack()
test_nettack_evasion()
14 changes: 13 additions & 1 deletion src/models_builder/gnn_models.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe create mask by target_list can be done with one simple function (e.g. torch.scatter) but it's ok like this

Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def compute(self, y_true, y_pred):

raise NotImplementedError()

@staticmethod
def create_mask_by_target_list(y_true, target_list=None):
if target_list is None:
mask = [True] * len(y_true)
else:
mask = [False] * len(y_true)
for i in target_list:
if 0 <= i < len(mask):
mask[i] = True
return tensor(mask)
# return mask


class GNNModelManager:
""" class of basic functions over models:
Expand Down Expand Up @@ -1057,7 +1069,7 @@ def evaluate_model(self, gen_dataset, metrics):
'all': [True] * len(gen_dataset.labels),
}[mask]
except KeyError:
assert isinstance(mask, list)
assert isinstance(mask, torch.Tensor)
mask_tensor = mask
if self.evasion_attacker:
self.evasion_attacker.attack(model_manager=self, gen_dataset=gen_dataset, mask_tensor=mask_tensor)
Expand Down
Loading