
Commit d133c93

Added bespoke wrappers for graph kernels that allow node/edge label checking.
1 parent 37d17ce commit d133c93

File tree

2 files changed: +296 -33 lines changed

Diff for: gauche/gp.py (+12 -33)

@@ -2,65 +2,44 @@
 from functools import lru_cache
 
 import torch
-import gpytorch
 
-from gpytorch import Module, settings
+from gpytorch import settings
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.likelihoods import _GaussianLikelihoodBase
 from gpytorch.models.exact_prediction_strategies import prediction_strategy
 from gpytorch.models import ExactGP
 
-Softplus = torch.nn.Softplus()
 
-
-class Inputs:
+class NonTensorialInputs:
     def __init__(self, data):
         self.data = data
 
     def append(self, new_data):
         self.data.extend(new_data.data)
 
+    def __iter__(self):
+        return iter(self.data)
 
-class GraphKernel(Module):
-    """
-    A class suporting externel kernels.
-    The external kernel must have a method `fit_transform`, which, when
-    evaluated on an `Inputs` instance `X`, returns a scaled kernel matrix
-    v * k(X, X).
-
-    As gradients are not propagated through to the external kernel, outputs are
-    cached to avoid repeated computation.
-    """
-
-    def __init__(self, graph_kernel, dtype=torch.float):
-        super().__init__()
-        self._scale_variance = torch.nn.Parameter(torch.tensor([0.1], dtype=dtype))
-        self.kernel = graph_kernel
-
-    def scale(self, S):
-        return Softplus(self._scale_variance) * S
-
-    def forward(self, X):
-        return self.scale(self.kern(X))
+    def __len__(self):
+        return len(self.data)
 
-    @lru_cache(maxsize=5)
-    def kern(self, X):
-        return torch.tensor(self.kernel.fit_transform(X.data)).float()
+    def __getitem__(self, idx):
+        return self.data[idx]
 
 
 class SIGP(ExactGP):
     """
-    A reimplementation of gpytorch(==1.7.0)'s ExactGP that allows for non-tensorial inputs.
-    The inputs to this class may be a gauche.gp.Inputs instance, with graphs stored within
-    the object's .data attribute.
+    A reimplementation of gpytorch's ExactGP that allows for non-tensorial inputs.
+    The inputs to this class may be a gauche.NonTensorialInputs instance, with graphs
+    stored within the object's .data attribute.
 
     In the longer term, if ExactGP can be refactored such that the validation checks ensuring
     that the inputs are torch.Tensors are optional, this class should subclass ExactGP without
     performing those checks.
     """
 
     def __init__(self, train_inputs, train_targets, likelihood):
-        if train_inputs is not None and type(train_inputs) is Inputs:
+        if train_inputs is not None and type(train_inputs) is NonTensorialInputs:
             train_inputs = (train_inputs,)
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("SIGP can only handle Gaussian likelihoods")
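
To show how the pieces of this diff fit together, here is a minimal usage sketch that is not part of the commit: the GraphGP subclass, its zero-tensor trick for the mean, the "element" node label, and the toy path graphs are all illustrative assumptions. Only SIGP, NonTensorialInputs, and the WeisfeilerLehmanKernel wrapper come from the changed files.

import networkx as nx
import torch
import gpytorch

from gauche.gp import SIGP, NonTensorialInputs
from gauche.kernels.graph_kernels.grakel_kernels import WeisfeilerLehmanKernel


class GraphGP(SIGP):
    """Illustrative SIGP subclass; the exact mean/covariance wiring is an assumption."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean = gpytorch.means.ConstantMean()
        self.covariance = WeisfeilerLehmanKernel(node_label="element")

    def forward(self, x):
        # The mean is evaluated on a dummy tensor because the inputs are graphs,
        # not tensors; the covariance matrix comes from the GraKel wrapper.
        mean = self.mean(torch.zeros(len(x), 1)).float()
        covariance = self.covariance(x)
        return gpytorch.distributions.MultivariateNormal(mean, covariance)


# Wrap labelled networkx graphs (e.g. built with Graphein) in the new container.
graphs = []
for n in (4, 5, 6):
    g = nx.path_graph(n)
    nx.set_node_attributes(g, {i: "C" for i in g.nodes}, "element")
    graphs.append(g)

train_x = NonTensorialInputs(graphs)
train_y = torch.randn(len(train_x))

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GraphGP(train_x, train_y, likelihood)

From here, training would follow the usual gpytorch exact-GP loop (e.g. an ExactMarginalLogLikelihood objective), since SIGP keeps the ExactGP interface.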

Diff for: gauche/kernels/graph_kernels/grakel_kernels.py (+284, new file)

@@ -0,0 +1,284 @@

from typing import List, Optional

import torch
import networkx as nx
from functools import lru_cache
from gpytorch import Module

from grakel import graph_from_networkx
from grakel.kernels import (
    VertexHistogram,
    EdgeHistogram,
    WeisfeilerLehman,
    NeighborhoodHash,
    RandomWalk,
    RandomWalkLabeled,
    ShortestPath,
    GraphletSampling,
)


class _GraphKernel(Module):
    """
    A base class supporting external graph kernels.
    The external kernel must have a method `fit_transform`, which, when
    evaluated on a `NonTensorialInputs` instance `X`, returns a scaled kernel matrix
    v * k(X, X).

    As gradients are not propagated through to the external kernel, outputs are
    cached to avoid repeated computation.
    """

    def __init__(
        self,
        dtype=torch.float,
    ) -> None:
        super().__init__()
        self.node_label = None
        self.edge_label = None
        self._scale_variance = torch.nn.Parameter(torch.tensor([0.1], dtype=dtype))

    def scale(self, S: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softplus(self._scale_variance) * S

    def forward(self, X) -> torch.Tensor:
        return self.scale(self.kernel(X))

    def kernel(self, X) -> torch.Tensor:
        raise NotImplementedError("Subclasses must implement this method.")


class VertexHistogramKernel(_GraphKernel):
    """
    A GraKel wrapper for the vertex histogram kernel.
    This kernel requires node labels to be specified.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/vertex_histogram.html
    for more details.
    """

    def __init__(
        self,
        node_label: str,
        dtype=torch.float,
    ):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(VertexHistogram(**grakel_kwargs).fit_transform(X)).float()


class EdgeHistogramKernel(_GraphKernel):
    """
    A GraKel wrapper for the edge histogram kernel.
    This kernel requires edge labels to be specified.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/edge_histogram.html
    for more details.
    """

    def __init__(self, edge_label, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.edge_label = edge_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(EdgeHistogram(**grakel_kwargs).fit_transform(X)).float()


class WeisfeilerLehmanKernel(_GraphKernel):
    """
    A GraKel wrapper for the Weisfeiler-Lehman kernel.
    This kernel needs node labels to be specified and
    can optionally use edge labels for the base kernel.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/weisfeiler_lehman.html
    for more details.
    """

    def __init__(
        self, node_label: str, edge_label: Optional[str] = None, dtype=torch.float
    ):
        super().__init__(dtype=dtype)
        self.node_label = node_label
        self.edge_label = edge_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(WeisfeilerLehman(**grakel_kwargs).fit_transform(X)).float()


class NeighborhoodHashKernel(_GraphKernel):
    """
    A GraKel wrapper for the neighborhood hash kernel.
    This kernel requires node labels to be specified.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/neighborhood_hash.html
    for more details.
    """

    def __init__(self, node_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(NeighborhoodHash(**grakel_kwargs).fit_transform(X)).float()


class RandomWalkKernel(_GraphKernel):
    """
    A GraKel wrapper for the random walk kernel.
    This kernel only works on unlabelled graphs.
    See RandomWalkLabeledKernel for labelled graphs.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/random_walk.html
    for more details.
    """

    def __init__(self, dtype=torch.float):
        super().__init__(dtype=dtype)

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(RandomWalk(**grakel_kwargs).fit_transform(X)).float()


class RandomWalkLabeledKernel(_GraphKernel):
    """
    A GraKel wrapper for the labelled random walk kernel.
    This kernel requires node labels to be specified.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/random_walk.html
    for more details.
    """

    def __init__(self, node_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(RandomWalkLabeled(**grakel_kwargs).fit_transform(X)).float()


class ShortestPathKernel(_GraphKernel):
    """
    A GraKel wrapper for the shortest path kernel.
    This kernel only works on unlabelled graphs.
    See ShortestPathLabeledKernel for labelled graphs.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/shortest_path.html
    for more details.
    """

    def __init__(self, dtype=torch.float):
        super().__init__(dtype=dtype)

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(
            ShortestPath(**grakel_kwargs, with_labels=False).fit_transform(X)
        ).float()


class ShortestPathLabeledKernel(_GraphKernel):
    """
    A GraKel wrapper for the shortest path kernel with node labels.
    This kernel requires node labels to be specified.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/shortest_path.html
    for more details.
    """

    def __init__(self, node_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(
            ShortestPath(**grakel_kwargs, with_labels=True).fit_transform(X)
        ).float()


class GraphletSamplingKernel(_GraphKernel):
    """
    A GraKel wrapper for the graphlet sampling kernel.
    This kernel only works on unlabelled graphs.

    See https://ysig.github.io/GraKeL/0.1a8/kernels/graphlet_sampling.html
    for more details.
    """

    def __init__(self, dtype=torch.float):
        super().__init__(dtype=dtype)

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(GraphletSampling(**grakel_kwargs).fit_transform(X)).float()
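
As a quick illustration of the node/edge label checking these wrappers enable, the following sketch (again, not part of the commit) calls two of them directly; the toy graphs and the "element" label name are assumptions. A labelled kernel needs its label tag to match a node attribute on the graphs, while an unlabelled kernel such as RandomWalkKernel takes the graphs as-is.

import networkx as nx

from gauche.gp import NonTensorialInputs
from gauche.kernels.graph_kernels.grakel_kernels import (
    VertexHistogramKernel,
    RandomWalkKernel,
)

# Two toy graphs with a node attribute acting as the node label.
g1, g2 = nx.path_graph(3), nx.cycle_graph(4)
for g in (g1, g2):
    nx.set_node_attributes(g, {i: "C" if i % 2 else "N" for i in g.nodes}, "element")

X = NonTensorialInputs([g1, g2])

# Labelled kernel: node_label must name an existing node attribute.
k_labelled = VertexHistogramKernel(node_label="element")
K = k_labelled.kernel(X)   # 2 x 2 torch tensor, cached by lru_cache
K_scaled = k_labelled(X)   # forward() additionally applies the softplus scale

# Unlabelled kernel: no node or edge labels are used.
k_unlabelled = RandomWalkKernel()
K_rw = k_unlabelled.kernel(X)

Passing a NonTensorialInputs container (rather than a bare list) keeps the lru_cache key hashable and lets grakel's graph_from_networkx iterate over the stored graphs via the new __iter__ method.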
