From 92148c88f6460488a50871932117db9cdb4e2c06 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Fri, 18 Oct 2024 13:43:42 +0300 Subject: [PATCH] edge_weight parameter fix --- metainfo/modules_parameters.json | 8 ++++---- src/models_builder/gnn_constructor.py | 5 ++++- src/models_builder/gnn_models.py | 12 ++++++++---- tests/explainers_test.py | 3 +++ 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/metainfo/modules_parameters.json b/metainfo/modules_parameters.json index ae76138..5058167 100644 --- a/metainfo/modules_parameters.json +++ b/metainfo/modules_parameters.json @@ -23,7 +23,7 @@ { "import_info": ["SAGEConv", ["torch_geometric.nn"]], "need_full_gnn_flag": false, - "forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight" + "forward_parameters": "x=x, edge_index=edge_index" } }, "GATConv": { @@ -40,7 +40,7 @@ { "import_info": ["GATConv", ["torch_geometric.nn"]], "need_full_gnn_flag": false, - "forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight" + "forward_parameters": "x=x, edge_index=edge_index" } }, "SGConv": { @@ -62,7 +62,7 @@ { "import_info": ["GINConv", ["torch_geometric.nn"]], "need_full_gnn_flag": false, - "forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight" + "forward_parameters": "x=x, edge_index=edge_index" } }, "TAGConv": { @@ -118,7 +118,7 @@ { "import_info": ["GMM", ["models_builder.custom_layers"]], "need_full_gnn_flag": false, - "forward_parameters": "x=x, edge_index=edge_index, edge_weight=edge_weight" + "forward_parameters": "x=x, edge_index=edge_index" } }, "CGConv": { diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index 7c3abd4..02ada4d 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -554,6 +554,9 @@ def arguments_read(*args, **kwargs): else: raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}") else: - x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight + if hasattr(data, "edge_weight"): + x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, data.edge_weight + else: + x, edge_index, batch, edge_weight = data.x, data.edge_index, data.batch, None return x, edge_index, batch, edge_weight diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index a9aec12..acfe7b0 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -821,9 +821,13 @@ def train_on_batch(self, batch, task_type=None): if self.evasion_defender: self.evasion_defender.pre_batch(model_manager=self, batch=batch) loss = None + if hasattr(batch, "edge_weight"): + weight = batch.edge_weight + else: + weight = None if task_type == "single-graph": self.optimizer.zero_grad() - logits = self.gnn(batch.x, batch.edge_index) + logits = self.gnn(batch.x, batch.edge_index, weight) loss = self.loss_function(logits, batch.y) if self.clip is not None: clip_grad_norm(self.gnn.parameters(), self.clip) @@ -832,7 +836,7 @@ def train_on_batch(self, batch, task_type=None): # self.optimizer.step() elif task_type == "multiple-graphs": self.optimizer.zero_grad() - logits = self.gnn(batch.x, batch.edge_index, batch.batch) + logits = self.gnn(batch.x, batch.edge_index, batch.batch, weight) loss = self.loss_function(logits, batch.y) # loss.backward() # self.optimizer.step() @@ -843,8 +847,8 @@ def train_on_batch(self, batch, task_type=None): pos_edge_index = edge_index[:, batch.y == 1] neg_edge_index = edge_index[:, batch.y == 0] - pos_out = self.gnn(batch.x, pos_edge_index) - neg_out = self.gnn(batch.x, neg_edge_index) + pos_out = self.gnn(batch.x, pos_edge_index, weight) + neg_out = self.gnn(batch.x, neg_edge_index, weight) pos_loss = self.loss_function(pos_out, torch.ones_like(pos_out)) neg_loss = self.loss_function(neg_out, torch.zeros_like(neg_out)) diff --git a/tests/explainers_test.py b/tests/explainers_test.py index 15ddd78..514d994 100644 --- a/tests/explainers_test.py +++ b/tests/explainers_test.py @@ -1,6 +1,9 @@ import collections import collections.abc collections.Callable = collections.abc.Callable +import sys +import os +sys.path.append(f"{os.getcwd()}/src") import unittest import warnings