Skip to content

Commit 669c406

Browse files
small adds
1 parent c3ede99 commit 669c406

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

src/models_builder/models_zoo.py

+99
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,105 @@ def model_configs_zoo(
161161
)
162162
)
163163

164+
gat_gcn_sage_gcn_gcn = FrameworkGNNConstructor(
165+
model_config=ModelConfig(
166+
structure=ModelStructureConfig(
167+
[
168+
{
169+
'label': 'n',
170+
'layer': {
171+
'layer_name': 'GATConv',
172+
'layer_kwargs': {
173+
'in_channels': dataset.num_node_features,
174+
'out_channels': 16,
175+
'heads': 3,
176+
},
177+
},
178+
'batchNorm': {
179+
'batchNorm_name': 'BatchNorm1d',
180+
'batchNorm_kwargs': {
181+
'num_features': 48,
182+
'eps': 1e-05,
183+
}
184+
},
185+
'activation': {
186+
'activation_name': 'ReLU',
187+
'activation_kwargs': {
188+
"negative_slope": 0.01
189+
},
190+
},
191+
},
192+
193+
{
194+
'label': 'n',
195+
'layer': {
196+
'layer_name': 'GCNConv',
197+
'layer_kwargs': {
198+
'in_channels': 48,
199+
'out_channels': dataset.num_classes,
200+
},
201+
},
202+
'activation': {
203+
'activation_name': 'LeakyReLU',
204+
'activation_kwargs': None,
205+
},
206+
},
207+
208+
{
209+
'label': 'n',
210+
'layer': {
211+
'layer_name': 'SAGEConv',
212+
'layer_kwargs': {
213+
'in_channels': dataset.num_node_features,
214+
'out_channels': 16,
215+
},
216+
},
217+
'activation': {
218+
'activation_name': 'Tanh',
219+
'activation_kwargs': None,
220+
},
221+
'dropout': {
222+
'dropout_name': 'Dropout',
223+
'dropout_kwargs': {
224+
'p': 0.5,
225+
}
226+
}
227+
},
228+
229+
{
230+
'label': 'n',
231+
'layer': {
232+
'layer_name': 'GCNConv',
233+
'layer_kwargs': {
234+
'in_channels': dataset.num_node_features,
235+
'out_channels': 16,
236+
},
237+
},
238+
'activation': {
239+
'activation_name': 'Sigmoid',
240+
'activation_kwargs': None,
241+
},
242+
},
243+
244+
{
245+
'label': 'n',
246+
'layer': {
247+
'layer_name': 'GCNConv',
248+
'layer_kwargs': {
249+
'in_channels': dataset.num_node_features,
250+
'out_channels': 16,
251+
},
252+
},
253+
'activation': {
254+
'activation_name': 'LogSoftmax',
255+
'activation_kwargs': None,
256+
},
257+
},
258+
]
259+
)
260+
)
261+
)
262+
164263
test_gnn = FrameworkGNNConstructor(
165264
model_config=ModelConfig(
166265
structure=ModelStructureConfig(

tests/models_test.py

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import collections
2+
import collections.abc
3+
collections.Callable = collections.abc.Callable
14
import unittest
25
import shutil
36
import signal

0 commit comments

Comments
 (0)