2 changes: 1 addition & 1 deletion gpjax/kernels/non_euclidean/graph.py
@@ -101,7 +101,7 @@ def __init__(
     def __call__(  # TODO not consistent with general kernel interface
         self,
         x: Int[Array, "N 1"],
-        y: Int[Array, "N 1"],
+        y: Int[Array, "M 1"],
         *,
         S,
         **kwargs,
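Why this one-character change matters: the graph kernel's cross-covariance is evaluated between N inducing nodes and M test nodes, so the two index arrays are sized independently, and annotating both as "N 1" rejected valid calls whenever N differed from M. Below is a toy, runtime-checked sketch of the corrected contract; it uses jaxtyping with typeguard, and the kernel body is a stand-in rather than the real GraphKernel:

import jax.numpy as jnp
from jaxtyping import Array, Float, Int, jaxtyped
from typeguard import typechecked

@jaxtyped(typechecker=typechecked)
def cross_cov(x: Int[Array, "N 1"], y: Int[Array, "M 1"]) -> Float[Array, "N M"]:
    # Stand-in kernel on node indices; the real GraphKernel works in the
    # Laplacian eigenbasis instead.
    return jnp.exp(-jnp.abs(x - y.T).astype(jnp.float32))

Kzt = cross_cov(jnp.arange(10)[:, None], jnp.arange(15)[:, None])
assert Kzt.shape == (10, 15)  # would have failed the old y: Int[Array, "N 1"]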
68 changes: 63 additions & 5 deletions gpjax/variational_families.py
@@ -19,7 +19,10 @@
 from flax import nnx
 import jax.numpy as jnp
 import jax.scipy as jsp
-from jaxtyping import Float
+from jaxtyping import (
+    Float,
+    Int,
+)

 from gpjax.dataset import Dataset
 from gpjax.distributions import GaussianDistribution
@@ -108,6 +111,7 @@ def __init__(
         self,
         posterior: AbstractPosterior[P, L],
         inducing_inputs: tp.Union[
+            Int[Array, "N D"],
             Float[Array, "N D"],
             Real,
         ],
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
     def __init__(
         self,
         posterior: AbstractPosterior[P, L],
-        inducing_inputs: Float[Array, "N D"],
+        inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
         variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
         variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
         jitter: ScalarFloat = 1e-6,
@@ -156,6 +160,12 @@ def __init__(
         self.variational_mean = Real(variational_mean)
         self.variational_root_covariance = LowerTriangular(variational_root_covariance)

+    def _fmt_Kzt_Ktt(self, Kzt, Ktt):
+        return Kzt, Ktt
+
+    def _fmt_inducing_inputs(self):
+        return self.inducing_inputs.value
+
     def prior_kl(self) -> ScalarFloat:
         r"""Compute the prior KL divergence.

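The two `_fmt_*` hooks added above are deliberate no-ops here: `VariationalGaussian` keeps its existing behaviour, and the rewired call sites in `prior_kl` and `predict` (next hunks) mean a subclass only overrides input and covariance formatting instead of duplicating the linear algebra. A minimal sketch of the pattern (names follow this diff; the bodies are illustrative):

class Base:
    def _fmt_inducing_inputs(self):
        # default: unwrap the stored parameter
        return self.inducing_inputs.value

class GraphVariant(Base):
    def _fmt_inducing_inputs(self):
        # graph case: inducing inputs are already a raw integer index array
        return self.inducing_inputs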
@@ -178,7 +188,7 @@ def prior_kl(self) -> ScalarFloat:
         # Unpack variational parameters
         variational_mean = self.variational_mean.value
         variational_sqrt = self.variational_root_covariance.value
-        inducing_inputs = self.inducing_inputs.value
+        inducing_inputs = self._fmt_inducing_inputs()

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
@@ -202,7 +212,9 @@ def prior_kl(self) -> ScalarFloat:

         return q_inducing.kl_divergence(p_inducing)

-    def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
+    def predict(
+        self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
+    ) -> GaussianDistribution:
         r"""Compute the predictive distribution of the GP at the test inputs t.

         This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
@@ -222,7 +234,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         # Unpack variational parameters
         variational_mean = self.variational_mean.value
         variational_sqrt = self.variational_root_covariance.value
-        inducing_inputs = self.inducing_inputs.value
+        inducing_inputs = self._fmt_inducing_inputs()

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
@@ -241,6 +253,8 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         Kzt = kernel.cross_covariance(inducing_inputs, test_points)
         test_mean = mean_function(test_points)

+        Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
+
         # Lz⁻¹ Kzt
         Lz_inv_Kzt = solve(Lz, Kzt)
@@ -259,8 +273,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
         )

+        if hasattr(covariance, "to_dense"):
+            covariance = covariance.to_dense()

         covariance = add_jitter(covariance, self.jitter)
         covariance = Dense(covariance)

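For reference, the expression assembled above is the standard uncollapsed SVGP predictive covariance, $\Sigma_{tt} = K_{tt} - K_{tz} K_{zz}^{-1} K_{zt} + K_{tz} K_{zz}^{-1} S K_{zz}^{-1} K_{zt}$, computed via the Cholesky factor $L_z$ of $K_{zz}$ and the variational root $\sqrt{S}$ (so $S = \sqrt{S}\sqrt{S}^{\top}$). The new `to_dense` guard only materialises structured covariance operators, such as the ones the graph kernel produces, before jitter is added; the maths is unchanged.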
@@ -269,6 +285,48 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         )


+class GraphVariationalGaussian(VariationalGaussian[L]):
+    def __init__(
+        self,
+        posterior: AbstractPosterior[P, L],
+        inducing_inputs: Int[Array, "N D"],
+        variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
+        variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
+        jitter: ScalarFloat = 1e-6,
+    ):
+        super().__init__(
+            posterior,
+            inducing_inputs,
+            variational_mean,
+            variational_root_covariance,
+            jitter,
+        )
+        self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
+
+    def _ensure_2d(self, mat):
+        mat = jnp.asarray(mat)
+        if mat.ndim == 0:
+            return mat[None, None]
+        if mat.ndim == 1:
+            return mat[:, None]
+        return mat
+
+    def _fmt_Kzt_Ktt(self, Kzt, Ktt):
+        Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
+        Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
+        Ktt = self._ensure_2d(Ktt)
+        Kzt = self._ensure_2d(Kzt)
+        return Kzt, Ktt
+
+    def _fmt_inducing_inputs(self):
+        return self.inducing_inputs
+
+    @property
+    def num_inducing(self) -> int:
+        """The number of inducing inputs."""
+        return self.inducing_inputs.shape[0]
+
+
 class WhitenedVariationalGaussian(VariationalGaussian[L]):
     r"""The whitened variational Gaussian family of probability distributions.
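Taken together, a minimal end-to-end sketch of the new family (it mirrors the test added below; `GraphKernel`, `Prior`, and `Bernoulli` are existing GPJax components, while `GraphVariationalGaussian` is the class introduced in this diff):

import jax.numpy as jnp
import networkx as nx
import numpy as np

import gpjax as gpx
from gpjax.variational_families import GraphVariationalGaussian

# A GP over the nodes of a barbell graph, via the Laplacian-based graph kernel.
G = nx.barbell_graph(100, 0)
L = nx.laplacian_matrix(G).toarray()
kernel = gpx.kernels.GraphKernel(
    laplacian=L, lengthscale=2.3, variance=3.2, smoothness=6.1
)
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=kernel)
posterior = prior * gpx.likelihoods.Bernoulli(num_datapoints=G.number_of_nodes())

# Inducing and test inputs are integer node indices, not coordinates.
z = jnp.array(np.random.randint(0, G.number_of_nodes(), size=(20, 1)))
q = GraphVariationalGaussian(posterior=posterior, inducing_inputs=z)

t = jnp.arange(10).reshape(-1, 1)           # query the first ten nodes
dist = q(t)                                 # GaussianDistribution over f(t)
print(q.prior_kl().shape, dist.mean.shape)  # (), (10,)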
59 changes: 59 additions & 0 deletions tests/test_variational_families.py
@@ -25,6 +25,8 @@
     Array,
     Float,
 )
+import networkx as nx
+import numpy as np
 import numpyro.distributions as npd
 from numpyro.distributions import Distribution as NumpyroDistribution
 import pytest
@@ -35,6 +37,7 @@
     AbstractVariationalFamily,
     CollapsedVariationalGaussian,
     ExpectationVariationalGaussian,
+    GraphVariationalGaussian,
     NaturalVariationalGaussian,
     VariationalGaussian,
     WhitenedVariationalGaussian,
@@ -118,6 +121,7 @@ def test_variational_gaussians(
     )
     likelihood = gpx.likelihoods.Gaussian(123)
     inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
+
     test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)

     posterior = prior * likelihood
@@ -174,6 +178,61 @@ def test_variational_gaussians(
     assert sigma.shape == (n_test, n_test)


+@pytest.mark.parametrize("n_test", [10, 20])
+@pytest.mark.parametrize("n_inducing", [10, 20])
+@pytest.mark.parametrize(
+    "variational_family",
+    [
+        GraphVariationalGaussian,
+    ],
+)
+def test_graph_variational_gaussian(
+    n_test: int,
+    n_inducing: int,
+    variational_family: AbstractVariationalFamily,
+) -> None:
+    G = nx.barbell_graph(100, 0)
+    L = nx.laplacian_matrix(G).toarray()
+
+    kernel = gpx.kernels.GraphKernel(
+        laplacian=L,
+        lengthscale=2.3,
+        variance=3.2,
+        smoothness=6.1,
+    )
+    meanf = gpx.mean_functions.Constant()
+    prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+    likelihood = gpx.likelihoods.Bernoulli(num_datapoints=G.number_of_nodes())
+
+    inducing_inputs = jnp.array(
+        np.random.randint(low=1, high=100, size=(n_inducing, 1))
+    ).astype(jnp.int64)
+
+    # Sample test nodes across the whole graph (high is exclusive in
+    # np.random.randint, so high=1 would have yielded only node 0).
+    test_inputs = jnp.array(
+        np.random.randint(low=0, high=G.number_of_nodes(), size=(n_test, 1))
+    ).astype(jnp.int64)
+
+    posterior = prior * likelihood
+    q = variational_family(posterior=posterior, inducing_inputs=inducing_inputs)
+    # Test KL
+    kl = q.prior_kl()
+    assert isinstance(kl, jnp.ndarray)
+    assert kl.shape == ()
+    assert kl >= 0.0
+
+    # Test predictions
+    predictive_dist = q(test_inputs)
+    assert isinstance(predictive_dist, NumpyroDistribution)
+
+    mu = predictive_dist.mean
+    sigma = predictive_dist.covariance()
+
+    assert isinstance(mu, jnp.ndarray)
+    assert isinstance(sigma, jnp.ndarray)
+    assert mu.shape == (n_test,)
+    assert sigma.shape == (n_test, n_test)

 @pytest.mark.parametrize("n_test", [1, 10])
 @pytest.mark.parametrize("n_datapoints", [1, 10])
 @pytest.mark.parametrize("n_inducing", [1, 10, 20])
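The new test can be run in isolation with standard pytest selection, e.g. `pytest tests/test_variational_families.py -k test_graph_variational_gaussian`.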