Skip to content

Commit

Permalink
edge_weight parameter fix
Browse files Browse the repository at this point in the history
  • Loading branch information
abhhfcgjk committed Oct 18, 2024
1 parent f9cb1ea commit 92148c8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 9 deletions.
8 changes: 4 additions & 4 deletions metainfo/modules_parameters.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
{
"import_info": ["SAGEConv", ["torch_geometric.nn"]],
"need_full_gnn_flag": false,
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
"forward_parameters": "x=x, edge_index=edge_index"
}
},
"GATConv": {
Expand All @@ -40,7 +40,7 @@
{
"import_info": ["GATConv", ["torch_geometric.nn"]],
"need_full_gnn_flag": false,
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
"forward_parameters": "x=x, edge_index=edge_index"
}
},
"SGConv": {
Expand All @@ -62,7 +62,7 @@
{
"import_info": ["GINConv", ["torch_geometric.nn"]],
"need_full_gnn_flag": false,
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
"forward_parameters": "x=x, edge_index=edge_index"
}
},
"TAGConv": {
Expand Down Expand Up @@ -118,7 +118,7 @@
{
"import_info": ["GMM", ["models_builder.custom_layers"]],
"need_full_gnn_flag": false,
"forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight"
"forward_parameters": "x=x, edge_index=edge_index"
}
},
"CGConv": {
Expand Down
5 changes: 4 additions & 1 deletion src/models_builder/gnn_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,9 @@ def arguments_read(*args, **kwargs):
else:
raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")
else:
x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight
if hasattr(data, "edge_weight"):
x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight
else:
x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, None

return x, edge_index, batch, edge_weight
12 changes: 8 additions & 4 deletions src/models_builder/gnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,9 +821,13 @@ def train_on_batch(self, batch, task_type=None):
if self.evasion_defender:
self.evasion_defender.pre_batch(model_manager=self, batch=batch)
loss = None
if hasattr(batch, "edge_weight"):
weight = batch.edge_weight
else:
weight = None
if task_type == "single-graph":
self.optimizer.zero_grad()
logits = self.gnn(batch.x, batch.edge_index)
logits = self.gnn(batch.x, batch.edge_index, weight)
loss = self.loss_function(logits, batch.y)
if self.clip is not None:
clip_grad_norm(self.gnn.parameters(), self.clip)
Expand All @@ -832,7 +836,7 @@ def train_on_batch(self, batch, task_type=None):
# self.optimizer.step()
elif task_type == "multiple-graphs":
self.optimizer.zero_grad()
logits = self.gnn(batch.x, batch.edge_index, batch.batch)
logits = self.gnn(batch.x, batch.edge_index, batch.batch, weight)
loss = self.loss_function(logits, batch.y)
# loss.backward()
# self.optimizer.step()
Expand All @@ -843,8 +847,8 @@ def train_on_batch(self, batch, task_type=None):
pos_edge_index = edge_index[:, batch.y == 1]
neg_edge_index = edge_index[:, batch.y == 0]

pos_out = self.gnn(batch.x, pos_edge_index)
neg_out = self.gnn(batch.x, neg_edge_index)
pos_out = self.gnn(batch.x, pos_edge_index, weight)
neg_out = self.gnn(batch.x, neg_edge_index, weight)

pos_loss = self.loss_function(pos_out, torch.ones_like(pos_out))
neg_loss = self.loss_function(neg_out, torch.zeros_like(neg_out))
Expand Down
3 changes: 3 additions & 0 deletions tests/explainers_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import collections
import collections.abc
collections.Callable = collections.abc.Callable
import sys
import os
sys.path.append(f"{os.getcwd()}/src")

import unittest
import warnings
Expand Down

0 comments on commit 92148c8

Please sign in to comment.