
Commit

make better files in explainers
LukyanovKirillML committed Nov 21, 2024
1 parent ba86e9d commit 45220f9
Showing 5 changed files with 188 additions and 50 deletions.
55 changes: 43 additions & 12 deletions src/explainers/explainer.py
@@ -1,11 +1,12 @@
from pathlib import Path
from time import sleep
from abc import ABC, abstractmethod
from typing import Union, Callable, Any, Type

from flask_socketio import SocketIO
from tqdm import tqdm

from base.datasets_processing import GeneralDataset, DatasetManager
from base.datasets_processing import GeneralDataset


class ProgressBar(tqdm):
@@ -84,19 +85,21 @@ class Explainer(

@staticmethod
def check_availability(
gen_dataset: DatasetManager,
gen_dataset: GeneralDataset,
model_manager: Type
) -> bool:
""" Availability check for the given dataset and model manager. """
return False

def __init__(
self,
gen_dataset: GeneralDataset, model):
gen_dataset: GeneralDataset,
model: Type,
**kwargs
):
"""
:param gen_dataset: dataset
:param model: GNN model
:param kwargs: init args
"""
self.gen_dataset = gen_dataset
self.model = model
@@ -110,7 +113,12 @@ def __init__(

@finalize_decorator
@abstractmethod
def run(self, mode, kwargs, finalize=True):
def run(
self,
mode: str,
kwargs: dict,
finalize: bool = True
):
"""
Run explanation on a given element (node or graph).
finalize_decorator handles finalize() call when run() is finished.
@@ -123,15 +131,20 @@ def run(self, mode, kwargs, finalize=True):
pass

@abstractmethod
def _finalize(self):
def _finalize(
self
):
"""
Convert current explanation into inner framework json-able format.
:return:
"""
pass

def save(self, path):
def save(
self,
path: Union[str, Path]
) -> None:
"""
Dump explanation in json format at a given path.
@@ -141,24 +154,40 @@ def save(self, path):
self.explanation.save(path)


class DummyExplainer(Explainer):
class DummyExplainer(
Explainer
):
""" Dummy explainer for debugging
"""
name = '_Dummy'

@staticmethod
def check_availability(gen_dataset, model_manager):
def check_availability(
gen_dataset: GeneralDataset,
model_manager: Type
) -> bool:
""" Fits for all """
return True

def __init__(self, gen_dataset, model, init_arg=None, **kwargs):
def __init__(
self,
gen_dataset: GeneralDataset,
model: Type,
init_arg=None,
**kwargs
):
Explainer.__init__(self, gen_dataset, model)
self.init_arg = init_arg
self._local_explanation = None
self._global_explanation = None

@finalize_decorator
def run(self, mode, kwargs, finalize=True):
def run(
self,
mode: str,
kwargs: dict,
finalize: bool = True
) -> None:
self.pbar.reset(total=10, mode=mode)
if mode == "local":
assert self._global_explanation is not None
@@ -187,7 +216,9 @@ def run(self, mode, kwargs, finalize=True):
# Remove unpickable attributes
self.pbar = None

def _finalize(self):
def _finalize(
self
) -> None:
mode = self._run_mode
if mode == "local":
assert self._global_explanation is not None
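For orientation (not part of this commit), a concrete subclass of the updated `Explainer` interface would now look roughly like the sketch below, patterned on `DummyExplainer` above. The `explainers.explainer` import path, the `TrivialExplainer` class name, the importability of `finalize_decorator` from that module, and the `"element_idx"` key are assumptions for illustration only.

```python
from typing import Type

from base.datasets_processing import GeneralDataset
# Assumed import path; `finalize_decorator` is used inside explainer.py above,
# so it is assumed to be importable from the same module.
from explainers.explainer import Explainer, finalize_decorator


class TrivialExplainer(Explainer):
    """ Hypothetical explainer sketch following the typed interface above. """
    name = '_Trivial'

    @staticmethod
    def check_availability(
            gen_dataset: GeneralDataset,
            model_manager: Type
    ) -> bool:
        """ Fits for all, like DummyExplainer. """
        return True

    def __init__(
            self,
            gen_dataset: GeneralDataset,
            model: Type,
            **kwargs
    ):
        Explainer.__init__(self, gen_dataset, model)
        self._raw_result = None

    @finalize_decorator
    def run(
            self,
            mode: str,
            kwargs: dict,
            finalize: bool = True
    ) -> None:
        # A real explainer would query the model here; this sketch only records
        # which element was asked about ("element_idx" is a hypothetical key,
        # not one defined by the commit).
        self._raw_result = {"mode": mode, "element_idx": kwargs.get("element_idx")}

    def _finalize(
            self
    ) -> None:
        # Convert self._raw_result into the framework's json-able explanation
        # object here (structure omitted in this sketch).
        pass
```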
78 changes: 62 additions & 16 deletions src/explainers/explainer_metrics.py
@@ -1,10 +1,18 @@
from typing import Type

import numpy as np
import torch
from torch_geometric.utils import subgraph


class NodesExplainerMetric:
def __init__(self, model, graph, explainer, kwargs_dict):
def __init__(
self,
model: Type,
graph,
explainer,
kwargs_dict: dict
):
self.model = model
self.explainer = explainer
self.graph = graph
@@ -21,7 +29,10 @@ def __init__(self, model, graph, explainer, kwargs_dict):
self.dictionary = {
}

def evaluate(self, target_nodes_indices):
def evaluate(
self,
target_nodes_indices: list
) -> dict:
num_targets = len(target_nodes_indices)
sparsity = 0
stability = 0
@@ -46,7 +57,10 @@ def evaluate(self, target_nodes_indices):
self.dictionary["fidelity"] = fidelity
return self.dictionary

def calculate_fidelity(self, target_nodes_indices):
def calculate_fidelity(
self,
target_nodes_indices: list
) -> float:
original_answer = self.model.get_answer(self.x, self.edge_index)
same_answers_count = 0
for node_ind in target_nodes_indices:
@@ -62,19 +76,22 @@ def calculate_fidelity(self, target_nodes_indices):
fidelity = same_answers_count / len(target_nodes_indices)
return fidelity

def calculate_sparsity(self, node_ind):
def calculate_sparsity(
self,
node_ind: int
) -> float:
explanation = self.get_explanation(node_ind)
sparsity = 1 - (len(explanation["data"]["nodes"]) + len(explanation["data"]["edges"])) / (
len(self.x) + len(self.edge_index))
return sparsity

def calculate_stability(
self,
node_ind,
graph_perturbations_nums=10,
feature_change_percent=0.05,
node_removal_percent=0.05
):
node_ind: int,
graph_perturbations_nums: int = 10,
feature_change_percent: float = 0.05,
node_removal_percent: float = 0.05
) -> float:
base_explanation = self.get_explanation(node_ind)
stability = 0
for _ in range(graph_perturbations_nums):
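As a toy illustration of the two ratios computed in this hunk (made-up numbers, not repository code): fidelity is the share of target nodes whose model prediction is unchanged on the explanation-filtered subgraph, and sparsity is one minus the share of the graph retained by the explanation.

```python
import torch

# Made-up predictions for four target nodes: on the full graph vs. on the
# explanation-filtered subgraphs (values are purely illustrative).
original_answers = torch.tensor([1, 0, 2, 1])
filtered_answers = torch.tensor([1, 0, 1, 1])

same_answers_count = int((original_answers == filtered_answers).sum())
fidelity = same_answers_count / len(original_answers)  # 3 / 4 = 0.75

# Sparsity for one node: the explanation keeps 5 nodes and 8 edges out of a
# graph whose size measure is 100 + 400 (node count plus edge count here).
sparsity = 1 - (5 + 8) / (100 + 400)  # 0.974
print(fidelity, sparsity)
```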
@@ -90,7 +107,11 @@ def calculate_stability(
stability = stability / graph_perturbations_nums
return stability

def calculate_consistency(self, node_ind, num_explanation_runs=10):
def calculate_consistency(
self,
node_ind: int,
num_explanation_runs: int = 10
) -> float:
explanation = self.get_explanation(node_ind)
consistency = 0
for _ in range(num_explanation_runs):
@@ -103,13 +124,22 @@ def calculate_consistency(self, node_ind, num_explanation_runs=10):
consistency = consistency / num_explanation_runs
return consistency

def calculate_explanation(self, x, edge_index, node_idx, **kwargs):
def calculate_explanation(
self,
x: torch.Tensor,
edge_index: torch.Tensor,
node_idx: int,
**kwargs
):
print(f"Processing explanation calculation for node id {node_idx}.")
self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs)
print(f"Explanation calculation for node id {node_idx} completed.")
return self.explainer.explanation.dictionary

def get_explanation(self, node_ind):
def get_explanation(
self,
node_ind: int
):
if node_ind in self.nodes_explanations:
node_explanation = self.nodes_explanations[node_ind]
else:
@@ -118,7 +148,9 @@ def get_explanation(self, node_ind):
return node_explanation

@staticmethod
def parse_explanation(explanation):
def parse_explanation(
explanation: dict
) -> [dict, dict]:
important_nodes = {
int(node): float(weight) for node, weight in explanation["data"]["nodes"].items()
}
@@ -129,7 +161,12 @@ def parse_explanation(explanation):
return important_nodes, important_edges

@staticmethod
def filter_graph_by_explanation(x, edge_index, explanation, target_node):
def filter_graph_by_explanation(
x: torch.Tensor,
edge_index: torch.Tensor,
explanation: dict,
target_node: int
) -> [torch.Tensor, torch.Tensor, int]:
important_nodes, important_edges = NodesExplainerMetric.parse_explanation(explanation)
all_important_nodes = set(important_nodes.keys())
all_important_nodes.add(target_node)
@@ -147,7 +184,10 @@ def filter_graph_by_explanation(x, edge_index, explanation, target_node):
return new_x, new_edge_index, new_target_node

@staticmethod
def calculate_explanation_vectors(base_explanation, perturbed_explanation):
def calculate_explanation_vectors(
base_explanation,
perturbed_explanation
):
base_important_nodes, base_important_edges = NodesExplainerMetric.parse_explanation(
base_explanation
)
@@ -171,7 +211,13 @@ def calculate_explanation_vectors(base_explanation, perturbed_explanation):
return base_explanation_vector, perturbed_explanation_vector

@staticmethod
def perturb_graph(x, edge_index, node_ind, feature_change_percent, node_removal_percent):
def perturb_graph(
x: torch.Tensor,
edge_index: torch.Tensor,
node_ind: int,
feature_change_percent: float,
node_removal_percent: float
) -> [torch.Tensor, torch.Tensor]:
new_x = x.clone()
num_nodes = x.shape[0]
num_features = x.shape[1]
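The body of `perturb_graph` is collapsed above; the following is a self-contained sketch of the kind of perturbation its signature describes, not the repository's actual implementation. The function name `perturb_graph_sketch`, the feature-flipping scheme (which assumes binary features), and keeping the explained node are assumptions; only the use of `torch_geometric.utils.subgraph`, which this file imports, is taken from the diff.

```python
import torch
from torch_geometric.utils import subgraph


def perturb_graph_sketch(
        x: torch.Tensor,
        edge_index: torch.Tensor,
        node_ind: int,
        feature_change_percent: float,
        node_removal_percent: float
):
    new_x = x.clone()
    num_nodes, num_features = x.shape

    # Flip a random subset of feature entries (assumes binary features).
    num_changes = int(feature_change_percent * num_nodes * num_features)
    rows = torch.randint(0, num_nodes, (num_changes,))
    cols = torch.randint(0, num_features, (num_changes,))
    new_x[rows, cols] = 1 - new_x[rows, cols]

    # Remove a random subset of nodes, keeping the explained node.
    num_removed = int(node_removal_percent * num_nodes)
    candidates = torch.tensor([i for i in range(num_nodes) if i != node_ind])
    removed = candidates[torch.randperm(len(candidates))[:num_removed]]
    keep_mask = torch.ones(num_nodes, dtype=torch.bool)
    keep_mask[removed] = False

    # subgraph() keeps only edges whose endpoints both survive.
    new_edge_index, _ = subgraph(keep_mask, edge_index, relabel_nodes=False)
    return new_x, new_edge_index
```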
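Taken together, a caller would drive the metric class roughly as follows. The import path and the `evaluate_node_explanations` helper are assumptions for illustration; the expectations on `model.get_answer` and `explainer.evaluate_tensor_graph` come from the methods shown above.

```python
from typing import Type

# Assumed import path, based on the file edited above.
from explainers.explainer_metrics import NodesExplainerMetric


def evaluate_node_explanations(
        model: Type,
        graph,
        explainer,
        target_nodes_indices: list
) -> dict:
    """ Hypothetical helper showing how the metric class above is driven.

    `model.get_answer(x, edge_index)` is expected to return per-node
    predictions, and `explainer.evaluate_tensor_graph(...)` to fill
    `explainer.explanation.dictionary`, as the methods above rely on.
    """
    metric = NodesExplainerMetric(
        model=model,
        graph=graph,
        explainer=explainer,
        kwargs_dict={},  # metric options are parsed in the collapsed __init__ body
    )
    # Returns the dictionary filled in evaluate(): "fidelity" is set in the
    # visible code; "sparsity", "stability" and "consistency" are presumably
    # set in the collapsed part of that method.
    return metric.evaluate(target_nodes_indices)
```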
