Showing 9 changed files with 979 additions and 0 deletions.
100 changes: 100 additions & 0 deletions
tests/test_kernels/test_graph_kernels/test_edge_histogram_kernel.py
""" | ||
Unit tests for the edge histogram graph kernel. | ||
""" | ||
|
||
import pytest | ||
import torch | ||
import gpytorch | ||
from gauche import SIGP, NonTensorialInputs | ||
from gauche.kernels.graph_kernels import EdgeHistogramKernel | ||
from gauche.dataloader import MolPropLoader | ||
import graphein.molecule as gm | ||
|
||
graphein_config = gm.MoleculeGraphConfig( | ||
node_metadata_functions=[gm.total_degree], | ||
edge_metadata_functions=[gm.add_bond_type], | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"edge_label", | ||
["bond_type", "bond_type_and_total_degree", None, 1], | ||
) | ||
def test_edge_histogram_kernel_edge_label(edge_label): | ||
""" | ||
Test if edge histogram kernel works as intended | ||
when using edge labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = EdgeHistogramKernel(edge_label=edge_label) | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
|
||
model.train() | ||
likelihood.train() | ||
|
||
if edge_label == "bond_type": | ||
output = model(X) | ||
else: | ||
with pytest.raises(Exception): | ||
output = model(X) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"node_label", | ||
["element", "total_degree", "XYZ", None, 1], | ||
) | ||
def test_edge_histogram_kernel_node_label(node_label): | ||
""" | ||
Test if edge histogram kernel fails consistently | ||
when also using node labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = EdgeHistogramKernel( | ||
edge_label="bond_type", node_label=node_label | ||
) | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
with pytest.raises(Exception): | ||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
|
||
model.train() | ||
likelihood.train() | ||
output = model(X) |
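The tests above only exercise a single forward pass. For context, a minimal sketch of how such a GraphGP might actually be fit, assuming SIGP models work with the standard ExactMarginalLogLikelihood objective and reusing the GraphGP, X, loader, and likelihood objects from the test above (an assumption for illustration, not something the tests verify):

import torch
import gpytorch

# Training-loop sketch; GraphGP, X, loader, and likelihood are assumed to
# come from the test above. Compatibility of SIGP with
# ExactMarginalLogLikelihood is assumed here, not verified by the tests.
model = GraphGP(X, loader.labels, likelihood)
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

y = torch.as_tensor(loader.labels).flatten().float()  # assumes (n, 1) labels
for _ in range(50):
    optimizer.zero_grad()
    loss = -mll(model(X), y)  # maximize the marginal log likelihood
    loss.backward()
    optimizer.step()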
125 changes: 125 additions & 0 deletions
tests/test_kernels/test_graph_kernels/test_graphlet_sampling_kernel.py
""" | ||
Unit tests for the graphlet sampling graph kernel. | ||
""" | ||
|
||
import pytest | ||
import torch | ||
import gpytorch | ||
from gauche import SIGP, NonTensorialInputs | ||
from gauche.kernels.graph_kernels import GraphletSamplingKernel | ||
from gauche.dataloader import MolPropLoader | ||
import graphein.molecule as gm | ||
|
||
graphein_config = gm.MoleculeGraphConfig( | ||
node_metadata_functions=[gm.total_degree], | ||
edge_metadata_functions=[gm.add_bond_type], | ||
) | ||
|
||
|
||
def test_graphlet_sampling_kernel(): | ||
""" | ||
Test if graphlet sampling kernel works as intended | ||
when not providing any labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = GraphletSamplingKernel() | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
model.train() | ||
likelihood.train() | ||
output = model(X) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"node_label", | ||
["element", "total_degree", "XYZ", None, 1], | ||
) | ||
def test_graphlet_sampling_kernel_node_label(node_label): | ||
""" | ||
Test if graphlet sampling kernel fails consistently | ||
when using node labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = GraphletSamplingKernel(node_label=node_label) | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
with pytest.raises(Exception): | ||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
model.train() | ||
likelihood.train() | ||
output = model(X) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"edge_label", | ||
["bond_type", "bond_type_and_total_degree", None, 1], | ||
) | ||
def test_graphlet_sampling_kernel_edge_label(edge_label): | ||
""" | ||
Test if graphlet sampling kernel fails consistently | ||
when using edge labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = GraphletSamplingKernel(edge_label=edge_label) | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
with pytest.raises(Exception): | ||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
model.train() | ||
likelihood.train() | ||
output = model(X) |
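After a forward pass succeeds as in test_graphlet_sampling_kernel, posterior prediction would follow the usual GPyTorch pattern. A sketch, assuming model, likelihood, and X from the test above are in scope and the model has been trained:

import torch
import gpytorch

# Prediction sketch: switch to evaluation mode and query the posterior at
# the training inputs. `model`, `likelihood`, and `X` are assumed to come
# from the test above.
model.eval()
likelihood.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    posterior = likelihood(model(X))
    mean = posterior.mean          # posterior predictive mean
    variance = posterior.variance  # posterior predictive variance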
100 changes: 100 additions & 0 deletions
tests/test_kernels/test_graph_kernels/test_neighborhood_hash_kernel.py
""" | ||
Unit tests for the neighborhood hash graph kernel. | ||
""" | ||
|
||
import pytest | ||
import torch | ||
import gpytorch | ||
from gauche import SIGP, NonTensorialInputs | ||
from gauche.kernels.graph_kernels import NeighborhoodHashKernel | ||
from gauche.dataloader import MolPropLoader | ||
import graphein.molecule as gm | ||
|
||
graphein_config = gm.MoleculeGraphConfig( | ||
node_metadata_functions=[gm.total_degree], | ||
edge_metadata_functions=[gm.add_bond_type], | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"node_label", | ||
["element", "total_degree", "XYZ", None, 1], | ||
) | ||
def test_neighborhood_hash_kernel_node_label(node_label): | ||
""" | ||
Test if neighborhood hash kernel works as intended | ||
when using node labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = NeighborhoodHashKernel(node_label=node_label) | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
|
||
model.train() | ||
likelihood.train() | ||
|
||
if node_label in ["element", "total_degree"]: | ||
output = model(X) | ||
else: | ||
with pytest.raises(Exception): | ||
output = model(X) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"edge_label", | ||
["bond_type", "bond_type_and_total_degree", None, 1], | ||
) | ||
def test_neighborhood_hash_kernel_edge_label(edge_label): | ||
""" | ||
Test if neighborhood hash kernel fails consistently | ||
when also using edge labels. | ||
""" | ||
|
||
class GraphGP(SIGP): | ||
def __init__(self, train_x, train_y, likelihood): | ||
super().__init__(train_x, train_y, likelihood) | ||
self.mean = gpytorch.means.ConstantMean() | ||
self.covariance = NeighborhoodHashKernel( | ||
node_label="element", edge_label=edge_label | ||
) | ||
|
||
def forward(self, x): | ||
mean = self.mean(torch.zeros(len(x), 1)).float() | ||
covariance = self.covariance(x) | ||
|
||
# for numerical stability | ||
jitter = max(covariance.diag().mean().detach().item() * 1e-4, 1e-4) | ||
covariance += torch.eye(len(x)) * jitter | ||
return gpytorch.distributions.MultivariateNormal(mean, covariance) | ||
|
||
loader = MolPropLoader() | ||
loader.load_benchmark("Photoswitch") | ||
loader.featurize("molecular_graphs", graphein_config=graphein_config) | ||
|
||
with pytest.raises(Exception): | ||
X = NonTensorialInputs(loader.features) | ||
likelihood = gpytorch.likelihoods.GaussianLikelihood() | ||
model = GraphGP(X, loader.labels, likelihood) | ||
|
||
model.train() | ||
likelihood.train() | ||
output = model(X) |
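All six GraphGP definitions in this commit repeat the same jitter stabilization inside forward. Purely as a refactoring sketch (the add_jitter name and its placement are hypothetical, not part of this commit), the pattern could be factored out for a dense covariance matrix as:

import torch

def add_jitter(covariance: torch.Tensor, min_jitter: float = 1e-4) -> torch.Tensor:
    """Hypothetical helper mirroring the stabilization used in each test:
    scale the jitter with the mean of the diagonal, floor it at
    `min_jitter`, and add it to the diagonal."""
    jitter = max(covariance.diag().mean().detach().item() * 1e-4, min_jitter)
    return covariance + torch.eye(covariance.shape[0]) * jitter

Each forward could then call this helper on the evaluated covariance matrix before constructing the MultivariateNormal, keeping the stabilization logic in one place.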