Commit

Added bespoke wrappers for graph kernels that allow node/edge label checking.
leojklarner committed Nov 7, 2023
1 parent 37d17ce commit d133c93
Showing 2 changed files with 296 additions and 33 deletions.
45 changes: 12 additions & 33 deletions gauche/gp.py
@@ -2,65 +2,44 @@
 from functools import lru_cache
 
 import torch
-import gpytorch
 
-from gpytorch import Module, settings
+from gpytorch import settings
 from gpytorch.distributions import MultivariateNormal
 from gpytorch.likelihoods import _GaussianLikelihoodBase
 from gpytorch.models.exact_prediction_strategies import prediction_strategy
 from gpytorch.models import ExactGP
 
-Softplus = torch.nn.Softplus()
-
 
-class Inputs:
+class NonTensorialInputs:
     def __init__(self, data):
         self.data = data
 
     def append(self, new_data):
         self.data.extend(new_data.data)
 
     def __iter__(self):
         return iter(self.data)
 
-
-class GraphKernel(Module):
-    """
-    A class suporting externel kernels.
-    The external kernel must have a method `fit_transform`, which, when
-    evaluated on an `Inputs` instance `X`, returns a scaled kernel matrix
-    v * k(X, X).
-    As gradients are not propagated through to the external kernel, outputs are
-    cached to avoid repeated computation.
-    """
-
-    def __init__(self, graph_kernel, dtype=torch.float):
-        super().__init__()
-        self._scale_variance = torch.nn.Parameter(torch.tensor([0.1], dtype=dtype))
-        self.kernel = graph_kernel
-
-    def scale(self, S):
-        return Softplus(self._scale_variance) * S
-
-    def forward(self, X):
-        return self.scale(self.kern(X))
+    def __len__(self):
+        return len(self.data)
+
-    @lru_cache(maxsize=5)
-    def kern(self, X):
-        return torch.tensor(self.kernel.fit_transform(X.data)).float()
+    def __getitem__(self, idx):
+        return self.data[idx]
 
 
 class SIGP(ExactGP):
     """
-    A reimplementation of gpytorch(==1.7.0)'s ExactGP that allows for non-tensorial inputs.
-    The inputs to this class may be a gauche.gp.Inputs instance, with graphs stored within
-    the object's .data attribute.
+    A reimplementation of gpytorch's ExactGP that allows for non-tensorial inputs.
+    The inputs to this class may be a gauche.NonTensorialInputs instance, with graphs
+    stored within the object's .data attribute.
     In the longer term, if ExactGP can be refactored such that the validation checks ensuring
     that the inputs are torch.Tensors are optional, this class should subclass ExactGP without
     performing those checks.
     """
 
     def __init__(self, train_inputs, train_targets, likelihood):
-        if train_inputs is not None and type(train_inputs) is Inputs:
+        if train_inputs is not None and type(train_inputs) is NonTensorialInputs:
             train_inputs = (train_inputs,)
         if not isinstance(likelihood, _GaussianLikelihoodBase):
             raise RuntimeError("SIGP can only handle Gaussian likelihoods")
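
To see how the refactored SIGP class and the new kernel wrappers (shown below) fit together, here is a minimal sketch of a graph GP model. It is not part of this commit: the GraphGP name, the choice of WeisfeilerLehmanKernel, the "element" node label, and the jitter value are illustrative assumptions.

import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.means import ConstantMean

from gauche.gp import SIGP
from gauche.kernels.graph_kernels.grakel_kernels import WeisfeilerLehmanKernel


class GraphGP(SIGP):
    def __init__(self, train_x, train_y, likelihood):
        # train_x is expected to be a NonTensorialInputs instance
        super().__init__(train_x, train_y, likelihood)
        self.mean = ConstantMean()
        self.covariance = WeisfeilerLehmanKernel(node_label="element")

    def forward(self, x):
        # the inputs are graphs rather than tensors, so the mean is
        # evaluated on a dummy tensor whose length is supplied by the
        # newly added NonTensorialInputs.__len__
        mean = self.mean(torch.zeros(len(x), 1)).float()
        covariance = self.covariance(x)
        # a small diagonal jitter keeps the covariance positive definite
        covariance += torch.eye(len(x)) * 1e-5
        return MultivariateNormal(mean, covariance)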
284 changes: 284 additions & 0 deletions gauche/kernels/graph_kernels/grakel_kernels.py
@@ -0,0 +1,284 @@
from typing import List, Optional

import torch
import networkx as nx
from functools import lru_cache
from gpytorch import Module

from grakel import graph_from_networkx
from grakel.kernels import (
    VertexHistogram,
    EdgeHistogram,
    WeisfeilerLehman,
    NeighborhoodHash,
    RandomWalk,
    RandomWalkLabeled,
    ShortestPath,
    GraphletSampling,
)


class _GraphKernel(Module):
    """
    A base class supporting external graph kernels.
    The external kernel must have a method `fit_transform`, which, when
    evaluated on a `NonTensorialInputs` instance `X`, returns a scaled
    kernel matrix v * k(X, X).
    As gradients are not propagated through to the external kernel,
    outputs are cached to avoid repeated computation.
    """

    def __init__(
        self,
        dtype=torch.float,
    ) -> None:
        super().__init__()
        self.node_label = None
        self.edge_label = None
        self._scale_variance = torch.nn.Parameter(torch.tensor([0.1], dtype=dtype))

    def scale(self, S: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.softplus(self._scale_variance) * S

    def forward(self, X: List[nx.Graph]) -> torch.Tensor:
        return self.scale(self.kernel(X))

    def kernel(self, X: List[nx.Graph]) -> torch.Tensor:
        raise NotImplementedError("Subclasses must implement this method.")


class VertexHistogramKernel(_GraphKernel):
    """
    A GraKeL wrapper for the vertex histogram kernel.
    This kernel requires node labels to be specified.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/vertex_histogram.html
    for more details.
    """

    def __init__(
        self,
        node_label: str,
        dtype=torch.float,
    ):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(VertexHistogram(**grakel_kwargs).fit_transform(X)).float()


class EdgeHistogramKernel(_GraphKernel):
    """
    A GraKeL wrapper for the edge histogram kernel.
    This kernel requires edge labels to be specified.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/edge_histogram.html
    for more details.
    """

    def __init__(self, edge_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.edge_label = edge_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(EdgeHistogram(**grakel_kwargs).fit_transform(X)).float()


class WeisfeilerLehmanKernel(_GraphKernel):
    """
    A GraKeL wrapper for the Weisfeiler-Lehman kernel.
    This kernel requires node labels to be specified and
    can optionally use edge labels for the base kernel.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/weisfeiler_lehman.html
    for more details.
    """

    def __init__(
        self, node_label: str, edge_label: Optional[str] = None, dtype=torch.float
    ):
        super().__init__(dtype=dtype)
        self.node_label = node_label
        self.edge_label = edge_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(WeisfeilerLehman(**grakel_kwargs).fit_transform(X)).float()


class NeighborhoodHashKernel(_GraphKernel):
    """
    A GraKeL wrapper for the neighborhood hash kernel.
    This kernel requires node labels to be specified.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/neighborhood_hash.html
    for more details.
    """

    def __init__(self, node_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(NeighborhoodHash(**grakel_kwargs).fit_transform(X)).float()


class RandomWalkKernel(_GraphKernel):
    """
    A GraKeL wrapper for the random walk kernel.
    This kernel only works on unlabelled graphs;
    see RandomWalkLabeledKernel for labelled graphs.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/random_walk.html
    for more details.
    """

    def __init__(self, dtype=torch.float):
        super().__init__(dtype=dtype)

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(RandomWalk(**grakel_kwargs).fit_transform(X)).float()


class RandomWalkLabeledKernel(_GraphKernel):
    """
    A GraKeL wrapper for the labelled random walk kernel.
    This kernel requires node labels to be specified.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/random_walk.html
    for more details.
    """

    def __init__(self, node_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(RandomWalkLabeled(**grakel_kwargs).fit_transform(X)).float()


class ShortestPathKernel(_GraphKernel):
    """
    A GraKeL wrapper for the shortest path kernel.
    This kernel only works on unlabelled graphs;
    see ShortestPathLabeledKernel for labelled graphs.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/shortest_path.html
    for more details.
    """

    def __init__(self, dtype=torch.float):
        super().__init__(dtype=dtype)

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(
            ShortestPath(**grakel_kwargs, with_labels=False).fit_transform(X)
        ).float()


class ShortestPathLabeledKernel(_GraphKernel):
    """
    A GraKeL wrapper for the labelled shortest path kernel.
    This kernel requires node labels to be specified.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/shortest_path.html
    for more details.
    """

    def __init__(self, node_label: str, dtype=torch.float):
        super().__init__(dtype=dtype)
        self.node_label = node_label

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(
            ShortestPath(**grakel_kwargs, with_labels=True).fit_transform(X)
        ).float()


class GraphletSamplingKernel(_GraphKernel):
    """
    A GraKeL wrapper for the graphlet sampling kernel.
    This kernel only works on unlabelled graphs.
    See https://ysig.github.io/GraKeL/0.1a8/kernels/graphlet_sampling.html
    for more details.
    """

    def __init__(self, dtype=torch.float):
        super().__init__(dtype=dtype)

    @lru_cache(maxsize=5)
    def kernel(self, X: List[nx.Graph], **grakel_kwargs) -> torch.Tensor:
        # extract required data from the networkx graphs
        # constructed with the Graphein utilities
        # this is cheap and will be cached
        X = graph_from_networkx(
            X, node_labels_tag=self.node_label, edge_labels_tag=self.edge_label
        )

        return torch.tensor(GraphletSampling(**grakel_kwargs).fit_transform(X)).float()
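
As a quick smoke test, the wrappers can also be evaluated directly on networkx graphs. This snippet is illustrative rather than part of the commit: the "element" label name and the toy graphs are made up, and the graphs are passed as a tuple because the lru_cache on kernel() only accepts hashable arguments.

import networkx as nx

from gauche.kernels.graph_kernels.grakel_kernels import VertexHistogramKernel

# two toy graphs carrying an "element" node label
g1 = nx.Graph()
g1.add_nodes_from([(0, {"element": "C"}), (1, {"element": "O"})])
g1.add_edge(0, 1)

g2 = nx.Graph()
g2.add_nodes_from([(0, {"element": "C"}), (1, {"element": "C"}), (2, {"element": "O"})])
g2.add_edges_from([(0, 1), (1, 2)])

kernel = VertexHistogramKernel(node_label="element")
K = kernel((g1, g2))  # forward() -> softplus-scaled 2x2 Gram matrix
print(K.shape)  # torch.Size([2, 2])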
