Skip to content

Commit

Permalink
Merge pull request #32 from ispras/gnnexplainer_check
Browse files Browse the repository at this point in the history
fix gnn_explainer work with features
  • Loading branch information
LukyanovKirillML authored Oct 29, 2024
2 parents f25a047 + 231326f commit 781c01a
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions src/explainers/GNNExplainer/torch_geom_our/out.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,46 +122,65 @@ def _finalize(self):

self.explanation = AttributionExplanation(
local=mode,
edges="continuous" if edge_mask is not None else False,
features="continuous" if node_mask is not None else False)
edges="continuous" if self.edge_mask_type=="object" else False,
nodes="continuous" if self.node_mask_type=="object" else False,
features="continuous" if self.node_mask_type=="common_attributes" else False)

important_edges = {}
important_nodes = {}
important_features = {}

# TODO What if edge_mask_type or node_mask_type is None, common_attributes, attributes?
if self.edge_mask_type is not None and self.node_mask_type is not None:

# Multi graphs check is not needed: the explanation format for
# graph classification and node classification is the same
eps = 0.001

# Edges
num_edges = edge_mask.size(0)
assert num_edges == self.edge_index.size(1)
edges = self.edge_index

for i in range(num_edges):
imp = float(edge_mask[i])
if not imp < eps:
edge = edges[0][i], edges[1][i]
important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f')
if self.edge_mask_type=="object":
num_edges = edge_mask.size(0)
assert num_edges == self.edge_index.size(1)
edges = self.edge_index

for i in range(num_edges):
imp = float(edge_mask[i])
if not imp < eps:
edge = edges[0][i], edges[1][i]
important_edges[f"{edge[0]},{edge[1]}"] = format(imp, '.4f')
else: # if "common_attributes" or "attributes"
raise NotImplementedError(f"Edge mask type '{self.edge_mask_type}' is not yet implemented.")

# Nodes
num_nodes = node_mask.size(0)
assert num_nodes == self.x.size(0)

for i in range(num_nodes):
imp = float(node_mask[i][0])
if not imp < eps:
important_nodes[i] = format(imp, '.4f')
if self.node_mask_type=="object":
num_nodes = node_mask.size(0)
assert num_nodes == self.x.size(0)

for i in range(num_nodes):
imp = float(node_mask[i][0])
if not imp < eps:
important_nodes[i] = format(imp, '.4f')
# Features
elif self.node_mask_type=="common_attributes":
num_features = node_mask.size(1)
assert num_features == self.x.size(1)

for i in range(num_features):
imp = float(node_mask[0][i])
if not imp < eps:
important_features[i] = format(imp, '.4f')
else: # if "attributes"
# TODO add functional if node_mask_type=="attributes"
raise NotImplementedError(f"Node mask type '{self.node_mask_type}' is not yet implemented.")

if self.gen_dataset.is_multi():
important_edges = {self.graph_idx: important_edges}
important_nodes = {self.graph_idx: important_nodes}
important_features = {self.graph_idx: important_features}

# TODO Write functions with output threshold
self.explanation.add_edges(important_edges)
self.explanation.add_nodes(important_nodes)
self.explanation.add_features(important_features)

# print(important_edges)
# print(important_nodes)
Expand Down

0 comments on commit 781c01a

Please sign in to comment.