# BaseGNN.py
"""
Base classes for Graph Neural Networks
"""
import torch
import torch.nn as nn
from torch_geometric.data.batch import Batch
from torch import Tensor
from Pooling import GlobalMeanPool, GlobalMaxPool, IdenticalPool
from torch.nn import Identity


class GNNBasic(torch.nn.Module):
    r"""
    Base class for graph neural networks.

    Args:
        *args (list): argument list for the use of :func:`~arguments_read`
        **kwargs (dict): keyword arguments for the use of :func:`~arguments_read`
    """

    def __init__(self, config, *args, **kwargs):
        super(GNNBasic, self).__init__()
        self.config = config

    def arguments_read(self, *args, **kwargs):
        r"""
        Argument reader that supports diverse model input formats.

        Supported formats are:
        ``model(x, edge_index)``,
        ``model(x, edge_index, batch)``,
        ``model(data=data)``.

        Notes:
            edge_weight is optional for node prediction tasks.

        Args:
            *args: [x, edge_index, [batch]]
            **kwargs: data, [edge_weight]

        Returns:
            Unpacked node features, sparse adjacency matrices, batch indicators, and optional edge weights.
        """
        data: Batch = kwargs.get('data')
        if data is None:
            if not args:
                assert 'x' in kwargs
                assert 'edge_index' in kwargs
                x, edge_index = kwargs['x'], kwargs['edge_index']
                batch = kwargs.get('batch')
                if batch is None:
                    # Single-graph fallback: assign every node to batch 0, on the same device as x
                    # (avoids assuming CUDA is available).
                    batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)
            elif len(args) == 2:
                x, edge_index = args[0], args[1]
                batch = torch.zeros(x.shape[0], dtype=torch.int64, device=x.device)
            elif len(args) == 3:
                x, edge_index, batch = args[0], args[1], args[2]
            else:
                raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")
        else:
            x, edge_index, batch = data.x, data.edge_index, data.batch

        if self.config.model.model_level != 'node':
            # --- Maybe batch size --- Reason: some methods may filter graphs, which can make the
            # batch size inconsistent with the dataloader's, so allow an explicit override.
            batch_size: int = kwargs.get('batch_size') or (batch[-1].item() + 1)

        if self.config.model.model_level == 'node':
            edge_weight = kwargs.get('edge_weight')
            return x, edge_index, edge_weight, batch
        elif self.config.dataset.dim_edge:
            # Edge attributes are only available when the input is passed as a Data/Batch object.
            edge_attr = data.edge_attr
            return x, edge_index, edge_attr, batch, batch_size
        return x, edge_index, batch, batch_size

    def probs(self, *args, **kwargs):
        # nodes x classes
        return self(*args, **kwargs).softmax(dim=1)
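

# Illustrative sketch: a minimal hypothetical subclass showing how a concrete model's forward()
# is expected to unpack its inputs through `arguments_read`. `HypotheticalGraphClassifier`,
# `in_dim`, `num_classes`, and the GCNConv/global_mean_pool choices are assumptions for
# demonstration; the import below is needed only for this sketch.
from torch_geometric.nn import GCNConv, global_mean_pool


class HypotheticalGraphClassifier(GNNBasic):

    def __init__(self, config, in_dim: int, num_classes: int):
        super().__init__(config)
        self.conv = GCNConv(in_dim, config.model.dim_hidden)
        self.classifier = nn.Linear(config.model.dim_hidden, num_classes)

    def forward(self, *args, **kwargs):
        # Assumes a graph-level task without edge features (config.dataset.dim_edge is falsy), so
        # arguments_read returns (x, edge_index, batch, batch_size) whether the caller passed a
        # Batch object, positional tensors, or keyword tensors.
        x, edge_index, batch, batch_size = self.arguments_read(*args, **kwargs)
        h = self.conv(x, edge_index).relu()
        # global_mean_pool stands in for the readout modules defined in Pooling.py.
        graph_repr = global_mean_pool(h, batch, size=batch_size)
        return self.classifier(graph_repr)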


class BasicEncoder(torch.nn.Module):
    r"""
    Base GNN feature encoder.

    Args:
        config (Union[CommonArgs, Munch]): munchified dictionary of args
            (:obj:`config.model.dim_hidden`, :obj:`config.model.model_layer`,
            :obj:`config.model.model_level`, :obj:`config.model.global_pool`,
            :obj:`config.model.dropout_rate`)

    .. code-block:: python

        config = munchify({'model': {'dim_hidden': int(300),
                                     'model_layer': int(5),
                                     'model_level': str('node'),
                                     'global_pool': str('mean'),
                                     'dropout_rate': float(0.5)}
                           })
    """

    def __init__(self, config, **kwargs):
        # Cooperative multiple inheritance: if the class right after BasicEncoder in the MRO is
        # torch.nn.Module itself, initialize it without arguments; otherwise forward config to
        # the next base class.
        if type(self).mro()[type(self).mro().index(__class__) + 1] is torch.nn.Module:
            super(BasicEncoder, self).__init__()
        else:
            super(BasicEncoder, self).__init__(config)
        num_layer = config.model.model_layer

        self.relu1 = nn.ReLU()
        self.relus = nn.ModuleList(
            [
                nn.ReLU()
                for _ in range(num_layer - 1)
            ]
        )
        if kwargs.get('no_bn'):
            self.batch_norm1 = Identity()
            # Registered as a ModuleList (like the BatchNorm branch) so both branches iterate identically.
            self.batch_norms = nn.ModuleList([
                Identity()
                for _ in range(num_layer - 1)
            ])
        else:
            self.batch_norm1 = nn.BatchNorm1d(config.model.dim_hidden)
            self.batch_norms = nn.ModuleList([
                nn.BatchNorm1d(config.model.dim_hidden)
                for _ in range(num_layer - 1)
            ])
        self.dropout1 = nn.Dropout(config.model.dropout_rate)
        self.dropouts = nn.ModuleList([
            nn.Dropout(config.model.dropout_rate)
            for _ in range(num_layer - 1)
        ])
        # Readout: node-level tasks keep per-node embeddings; graph-level tasks pool over nodes.
        if config.model.model_level == 'node':
            self.readout = IdenticalPool()
        elif config.model.global_pool == 'mean':
            self.readout = GlobalMeanPool()
        else:
            self.readout = GlobalMaxPool()
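

# Illustrative sketch: a hypothetical concrete encoder showing how BasicEncoder's registered
# sub-modules (relu*, batch_norm*, dropout*, readout) are typically chained with message-passing
# layers. `HypotheticalGCNEncoder` and its GCNConv layers are assumptions for demonstration
# (reusing the GCNConv import from the sketch above); the readout call assumes the Pooling
# readouts accept (x, batch, batch_size), so check Pooling.py for the actual signature.
class HypotheticalGCNEncoder(BasicEncoder):

    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)
        num_layer = config.model.model_layer
        # Assumes input node features already have dim_hidden channels; a real encoder would map
        # from the dataset's node-feature dimension in its first layer.
        self.conv1 = GCNConv(config.model.dim_hidden, config.model.dim_hidden)
        self.convs = nn.ModuleList([
            GCNConv(config.model.dim_hidden, config.model.dim_hidden)
            for _ in range(num_layer - 1)
        ])

    def forward(self, x, edge_index, batch, batch_size):
        # The first layer uses the *1 modules; the remaining layers iterate the ModuleLists in lockstep.
        post_conv = self.dropout1(self.relu1(self.batch_norm1(self.conv1(x, edge_index))))
        for conv, batch_norm, relu, dropout in zip(self.convs, self.batch_norms, self.relus, self.dropouts):
            post_conv = dropout(relu(batch_norm(conv(post_conv, edge_index))))
        return self.readout(post_conv, batch, batch_size)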