
Commit

Merge branch 'develop' of https://github.com/ispras/GNN-AID into gnnexplainer_check
mishabounty committed Oct 29, 2024
2 parents a06a195 + f25a047 commit a91ff1c
Showing 4 changed files with 379 additions and 2 deletions.
119 changes: 119 additions & 0 deletions experiments/interpretation_metrics_test.py
@@ -0,0 +1,119 @@
import random
import warnings

import torch

from aux.custom_decorators import timing_decorator
from aux.utils import EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_INIT_PARAMETERS_PATH
from explainers.explainers_manager import FrameworkExplainersManager
from models_builder.gnn_models import FrameworkGNNModelManager, Metric
from src.aux.configs import ModelModificationConfig, ConfigPattern
from src.base.datasets_processing import DatasetManager
from src.models_builder.models_zoo import model_configs_zoo


@timing_decorator
def run_interpretation_test():
full_name = ("single-graph", "Planetoid", 'Cora')
steps_epochs = 10
save_model_flag = False
my_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset, data, results_dataset_path = DatasetManager.get_by_full_name(
full_name=full_name,
dataset_ver_ind=0
)
gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn')
manager_config = ConfigPattern(
_config_class="ModelManagerConfig",
_config_kwargs={
"mask_features": [],
"optimizer": {
# "_config_class": "Config",
"_class_name": "Adam",
# "_import_path": OPTIMIZERS_PARAMETERS_PATH,
# "_class_import_info": ["torch.optim"],
"_config_kwargs": {},
}
}
)
gnn_model_manager = FrameworkGNNModelManager(
gnn=gnn,
dataset_path=results_dataset_path,
manager_config=manager_config,
modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs)
)
gnn_model_manager.gnn.to(my_device)
data.x = data.x.float()
data = data.to(my_device)

warnings.warn("Start training")
try:
raise FileNotFoundError()
except FileNotFoundError:
gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0
train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs,
save_model_flag=save_model_flag,
metrics=[Metric("F1", mask='train', average=None)])

if train_test_split_path is not None:
dataset.save_train_test_mask(train_test_split_path)
        train_mask, val_mask, test_mask, train_test_sizes = torch.load(
            train_test_split_path / 'train_test_split')[:]
dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask
data.percent_train_class, data.percent_test_class = train_test_sizes
warnings.warn("Training was successful")

metric_loc = gnn_model_manager.evaluate_model(
gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')])
print(metric_loc)

explainer_init_config = ConfigPattern(
_class_name="GNNExplainer(torch-geom)",
_import_path=EXPLAINERS_INIT_PARAMETERS_PATH,
_config_class="ExplainerInitConfig",
_config_kwargs={
"epochs": 10
}
)
explainer_metrics_run_config = ConfigPattern(
_config_class="ExplainerRunConfig",
_config_kwargs={
"mode": "local",
"kwargs": {
"_class_name": "GNNExplainer(torch-geom)",
"_import_path": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH,
"_config_class": "Config",
"_config_kwargs": {
"stability_graph_perturbations_nums": 10,
"stability_feature_change_percent": 0.05,
"stability_node_removal_percent": 0.05,
"consistency_num_explanation_runs": 10
},
}
}
)

explainer_GNNExpl = FrameworkExplainersManager(
init_config=explainer_init_config,
dataset=dataset, gnn_manager=gnn_model_manager,
explainer_name='GNNExplainer(torch-geom)',
)

num_explaining_nodes = 10
node_indices = random.sample(range(dataset.data.x.shape[0]), num_explaining_nodes)

# explainer_GNNExpl.explainer.pbar = ProgressBar(socket, "er", desc=f'{explainer_GNNExpl.explainer.name} explaining')
# explanation_metric = NodesExplainerMetric(
# model=explainer_GNNExpl.gnn,
# graph=explainer_GNNExpl.gen_dataset.data,
# explainer=explainer_GNNExpl.explainer
# )
# res = explanation_metric.evaluate(node_indices)
explanation_metrics = explainer_GNNExpl.evaluate_metrics(node_indices, explainer_metrics_run_config)
print(explanation_metrics)


if __name__ == '__main__':
random.seed(11)
run_interpretation_test()
11 changes: 11 additions & 0 deletions src/explainers/GNNExplainer/torch_geom_our/out.py
@@ -102,6 +102,17 @@ def run(self, mode, kwargs, finalize=True):
self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx)
self.pbar.close()

@finalize_decorator
def evaluate_tensor_graph(self, x, edge_index, node_idx, **kwargs):
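        # Run the explainer in local mode on explicitly supplied tensors (x, edge_index)
        # for a single node; called by NodesExplainerMetric.calculate_explanation.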
self._run_mode = "local"
self.node_idx = node_idx
self.x = x
self.edge_index = edge_index
self.pbar.reset(total=self.epochs, mode=self._run_mode)
self.explainer.algorithm.pbar = self.pbar
self.raw_explanation = self.explainer(self.x, self.edge_index, index=self.node_idx, **kwargs)
self.pbar.close()

def _finalize(self):
mode = self._run_mode
assert mode == "local"
203 changes: 203 additions & 0 deletions src/explainers/explainer_metrics.py
@@ -0,0 +1,203 @@
import numpy as np
import torch
from torch_geometric.utils import subgraph


class NodesExplainerMetric:
def __init__(self, model, graph, explainer, kwargs_dict):
self.model = model
self.explainer = explainer
self.graph = graph
self.x = self.graph.x
self.edge_index = self.graph.edge_index
self.kwargs_dict = {
"stability_graph_perturbations_nums": 10,
"stability_feature_change_percent": 0.05,
"stability_node_removal_percent": 0.05,
"consistency_num_explanation_runs": 10
}
self.kwargs_dict.update(kwargs_dict)
self.nodes_explanations = {} # explanations cache. node_ind -> explanation
        self.dictionary = {}  # computed metric values: metric name -> value

def evaluate(self, target_nodes_indices):
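        # Sparsity, stability and consistency are averaged over the target nodes;
        # fidelity is computed once over the whole set of targets.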
num_targets = len(target_nodes_indices)
sparsity = 0
stability = 0
consistency = 0
for node_ind in target_nodes_indices:
self.get_explanation(node_ind)
sparsity += self.calculate_sparsity(node_ind)
stability += self.calculate_stability(
node_ind,
graph_perturbations_nums=self.kwargs_dict["stability_graph_perturbations_nums"],
feature_change_percent=self.kwargs_dict["stability_feature_change_percent"],
node_removal_percent=self.kwargs_dict["stability_node_removal_percent"]
)
consistency += self.calculate_consistency(
node_ind,
num_explanation_runs=self.kwargs_dict["consistency_num_explanation_runs"]
)
fidelity = self.calculate_fidelity(target_nodes_indices)
self.dictionary["sparsity"] = sparsity / num_targets
self.dictionary["stability"] = stability / num_targets
self.dictionary["consistency"] = consistency / num_targets
self.dictionary["fidelity"] = fidelity
return self.dictionary

def calculate_fidelity(self, target_nodes_indices):
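        # Fidelity: fraction of target nodes whose model prediction is unchanged
        # when the graph is reduced to the explanation subgraph.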
original_answer = self.model.get_answer(self.x, self.edge_index)
same_answers_count = 0
for node_ind in target_nodes_indices:
node_explanation = self.get_explanation(node_ind)
new_x, new_edge_index, new_target_node = self.filter_graph_by_explanation(
self.x, self.edge_index, node_explanation, node_ind
)
filtered_answer = self.model.get_answer(new_x, new_edge_index)
matched = filtered_answer[new_target_node] == original_answer[node_ind]
print(f"Processed fidelity calculation for node id {node_ind}. Matched: {matched}")
if matched:
same_answers_count += 1
fidelity = same_answers_count / len(target_nodes_indices)
return fidelity

    def calculate_sparsity(self, node_ind):
        explanation = self.get_explanation(node_ind)
        # Sparsity: share of the graph (nodes + edges) left out of the explanation.
        # edge_index has shape [2, E], so shape[1] gives the number of edges.
        sparsity = 1 - (len(explanation["data"]["nodes"]) + len(explanation["data"]["edges"])) / (
                len(self.x) + self.edge_index.shape[1])
        return sparsity

def calculate_stability(
self,
node_ind,
graph_perturbations_nums=10,
feature_change_percent=0.05,
node_removal_percent=0.05
):
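        # Stability: average Euclidean distance between the base explanation and
        # explanations of randomly perturbed copies of the graph (lower = more stable).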
base_explanation = self.get_explanation(node_ind)
stability = 0
for _ in range(graph_perturbations_nums):
new_x, new_edge_index = self.perturb_graph(
self.x, self.edge_index, node_ind, feature_change_percent, node_removal_percent
)
perturbed_explanation = self.calculate_explanation(new_x, new_edge_index, node_ind)
base_explanation_vector, perturbed_explanation_vector = \
NodesExplainerMetric.calculate_explanation_vectors(base_explanation, perturbed_explanation)

stability += euclidean_distance(base_explanation_vector, perturbed_explanation_vector)

stability = stability / graph_perturbations_nums
return stability

def calculate_consistency(self, node_ind, num_explanation_runs=10):
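        # Consistency: average cosine similarity between successive explanation runs
        # on the unchanged graph (higher = more consistent).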
explanation = self.get_explanation(node_ind)
consistency = 0
for _ in range(num_explanation_runs):
perturbed_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind)
base_explanation_vector, perturbed_explanation_vector = \
NodesExplainerMetric.calculate_explanation_vectors(explanation, perturbed_explanation)
consistency += cosine_similarity(base_explanation_vector, perturbed_explanation_vector)
explanation = perturbed_explanation

consistency = consistency / num_explanation_runs
return consistency

def calculate_explanation(self, x, edge_index, node_idx, **kwargs):
print(f"Processing explanation calculation for node id {node_idx}.")
self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs)
print(f"Explanation calculation for node id {node_idx} completed.")
return self.explainer.explanation.dictionary

def get_explanation(self, node_ind):
if node_ind in self.nodes_explanations:
node_explanation = self.nodes_explanations[node_ind]
else:
node_explanation = self.calculate_explanation(self.x, self.edge_index, node_ind)
self.nodes_explanations[node_ind] = node_explanation
return node_explanation

@staticmethod
def parse_explanation(explanation):
important_nodes = {
int(node): float(weight) for node, weight in explanation["data"]["nodes"].items()
}
important_edges = {
tuple(map(int, edge.split(','))): float(weight)
for edge, weight in explanation["data"]["edges"].items()
}
return important_nodes, important_edges

@staticmethod
def filter_graph_by_explanation(x, edge_index, explanation, target_node):
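        # Keep only the nodes and edges mentioned in the explanation (plus the target
        # node) and relabel them into a compact subgraph.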
important_nodes, important_edges = NodesExplainerMetric.parse_explanation(explanation)
all_important_nodes = set(important_nodes.keys())
all_important_nodes.add(target_node)
for u, v in important_edges.keys():
all_important_nodes.add(u)
all_important_nodes.add(v)

        important_node_indices = sorted(all_important_nodes)  # sorted so .index() matches subgraph's relabeling order
node_mask = torch.zeros(x.size(0), dtype=torch.bool)
node_mask[important_node_indices] = True

new_edge_index, new_edge_weight = subgraph(node_mask, edge_index, relabel_nodes=True)
new_x = x[node_mask]
new_target_node = important_node_indices.index(target_node)
return new_x, new_edge_index, new_target_node

@staticmethod
def calculate_explanation_vectors(base_explanation, perturbed_explanation):
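        # Build two aligned vectors over the union of important nodes and edges so the
        # two explanations can be compared element-wise.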
base_important_nodes, base_important_edges = NodesExplainerMetric.parse_explanation(
base_explanation
)
perturbed_important_nodes, perturbed_important_edges = NodesExplainerMetric.parse_explanation(
perturbed_explanation
)
union_nodes = set(base_important_nodes.keys()) | set(perturbed_important_nodes.keys())
union_edges = set(base_important_edges.keys()) | set(perturbed_important_edges.keys())
explain_vector_len = len(union_nodes) + len(union_edges)
base_explanation_vector = np.zeros(explain_vector_len)
perturbed_explanation_vector = np.zeros(explain_vector_len)
i = 0
for expl_node_ind in union_nodes:
base_explanation_vector[i] = base_important_nodes.get(expl_node_ind, 0)
perturbed_explanation_vector[i] = perturbed_important_nodes.get(expl_node_ind, 0)
i += 1
for expl_edge in union_edges:
base_explanation_vector[i] = base_important_edges.get(expl_edge, 0)
perturbed_explanation_vector[i] = perturbed_important_edges.get(expl_edge, 0)
i += 1
return base_explanation_vector, perturbed_explanation_vector

@staticmethod
def perturb_graph(x, edge_index, node_ind, feature_change_percent, node_removal_percent):
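        # Perturb the graph around node_ind: flip a fraction of feature entries via
        # 1 - x (which assumes binary features) and remove a fraction of the node's
        # outgoing edges.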
new_x = x.clone()
num_nodes = x.shape[0]
num_features = x.shape[1]
num_features_to_change = int(feature_change_percent * num_nodes * num_features)
indices = torch.randint(0, num_nodes * num_features, (num_features_to_change,), device=x.device)
new_x.view(-1)[indices] = 1.0 - new_x.view(-1)[indices]

neighbors = edge_index[1][edge_index[0] == node_ind].unique()
num_nodes_to_remove = int(node_removal_percent * neighbors.shape[0])

if num_nodes_to_remove > 0:
nodes_to_remove = neighbors[
torch.randperm(neighbors.size(0), device=edge_index.device)[:num_nodes_to_remove]
]
            # Drop edges that go from node_ind to the sampled neighbours.
            mask = ~((edge_index[0] == node_ind) & torch.isin(edge_index[1], nodes_to_remove))
            new_edge_index = edge_index[:, mask]
else:
new_edge_index = edge_index

return new_x, new_edge_index


def cosine_similarity(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))


def euclidean_distance(a, b):
return np.linalg.norm(a - b)