diff --git a/gpjax/kernels/non_euclidean/graph.py b/gpjax/kernels/non_euclidean/graph.py
index e3302d22..534894c2 100644
--- a/gpjax/kernels/non_euclidean/graph.py
+++ b/gpjax/kernels/non_euclidean/graph.py
@@ -101,7 +101,7 @@ def __init__(
     def __call__(  # TODO not consistent with general kernel interface
         self,
         x: Int[Array, "N 1"],
-        y: Int[Array, "N 1"],
+        y: Int[Array, "M 1"],
         *,
         S,
         **kwargs,
diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py
index 19cb2d30..cf4535aa 100644
--- a/gpjax/variational_families.py
+++ b/gpjax/variational_families.py
@@ -19,7 +19,10 @@
 from flax import nnx
 import jax.numpy as jnp
 import jax.scipy as jsp
-from jaxtyping import Float
+from jaxtyping import (
+    Float,
+    Int,
+)

 from gpjax.dataset import Dataset
 from gpjax.distributions import GaussianDistribution
@@ -108,6 +111,7 @@ def __init__(
         self,
         posterior: AbstractPosterior[P, L],
         inducing_inputs: tp.Union[
+            Int[Array, "N D"],
             Float[Array, "N D"],
             Real,
         ],
@@ -140,7 +144,7 @@ class VariationalGaussian(AbstractVariationalGaussian[L]):
     def __init__(
         self,
         posterior: AbstractPosterior[P, L],
-        inducing_inputs: Float[Array, "N D"],
+        inducing_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]],
         variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
         variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
         jitter: ScalarFloat = 1e-6,
@@ -156,6 +160,12 @@ def __init__(
         self.variational_mean = Real(variational_mean)
         self.variational_root_covariance = LowerTriangular(variational_root_covariance)

+    def _fmt_Kzt_Ktt(self, Kzt, Ktt):
+        return Kzt, Ktt
+
+    def _fmt_inducing_inputs(self):
+        return self.inducing_inputs.value
+
     def prior_kl(self) -> ScalarFloat:
         r"""Compute the prior KL divergence.

@@ -178,7 +188,7 @@ def prior_kl(self) -> ScalarFloat:
         # Unpack variational parameters
         variational_mean = self.variational_mean.value
         variational_sqrt = self.variational_root_covariance.value
-        inducing_inputs = self.inducing_inputs.value
+        inducing_inputs = self._fmt_inducing_inputs()

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
@@ -202,7 +212,9 @@ def prior_kl(self) -> ScalarFloat:

         return q_inducing.kl_divergence(p_inducing)

-    def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
+    def predict(
+        self, test_inputs: tp.Union[Int[Array, "N D"], Float[Array, "N D"]]
+    ) -> GaussianDistribution:
         r"""Compute the predictive distribution of the GP at the test inputs t.

         This is the integral $q(f(t)) = \int p(f(t)\mid u) q(u) \mathrm{d}u$, which
@@ -222,7 +234,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         # Unpack variational parameters
         variational_mean = self.variational_mean.value
         variational_sqrt = self.variational_root_covariance.value
-        inducing_inputs = self.inducing_inputs.value
+        inducing_inputs = self._fmt_inducing_inputs()

         # Unpack mean function and kernel
         mean_function = self.posterior.prior.mean_function
@@ -241,6 +253,8 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         Kzt = kernel.cross_covariance(inducing_inputs, test_points)
         test_mean = mean_function(test_points)

+        Kzt, Ktt = self._fmt_Kzt_Ktt(Kzt, Ktt)
+
         # Lz⁻¹ Kzt
         Lz_inv_Kzt = solve(Lz, Kzt)

@@ -259,8 +273,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
             - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt)
             + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T)
         )
+
         if hasattr(covariance, "to_dense"):
             covariance = covariance.to_dense()
+
         covariance = add_jitter(covariance, self.jitter)
         covariance = Dense(covariance)

@@ -269,6 +285,59 @@



+class GraphVariationalGaussian(VariationalGaussian[L]):
+    r"""A variational Gaussian defined over graph-structured inducing inputs.
+
+    This subclass adapts the :class:`VariationalGaussian` family to the
+    case where the inducing inputs are discrete graph node indices rather
+    than continuous spatial coordinates.
+
+    The main differences are:
+    * Inducing inputs are integer node IDs.
+    * Kernel matrices are ensured to be dense and 2D.
+    """
+
+    def __init__(
+        self,
+        posterior: AbstractPosterior[P, L],
+        inducing_inputs: Int[Array, "N D"],
+        variational_mean: tp.Union[Float[Array, "N 1"], None] = None,
+        variational_root_covariance: tp.Union[Float[Array, "N N"], None] = None,
+        jitter: ScalarFloat = 1e-6,
+    ):
+        super().__init__(
+            posterior,
+            inducing_inputs,
+            variational_mean,
+            variational_root_covariance,
+            jitter,
+        )
+        self.inducing_inputs = self.inducing_inputs.value.astype(jnp.int64)
+
+    def _ensure_2d(self, mat):
+        mat = jnp.asarray(mat)
+        if mat.ndim == 0:
+            return mat[None, None]
+        if mat.ndim == 1:
+            return mat[:, None]
+        return mat
+
+    def _fmt_Kzt_Ktt(self, Kzt, Ktt):
+        Ktt = Ktt.to_dense() if hasattr(Ktt, "to_dense") else Ktt
+        Kzt = Kzt.to_dense() if hasattr(Kzt, "to_dense") else Kzt
+        Ktt = self._ensure_2d(Ktt)
+        Kzt = self._ensure_2d(Kzt)
+        return Kzt, Ktt
+
+    def _fmt_inducing_inputs(self):
+        return self.inducing_inputs
+
+    @property
+    def num_inducing(self) -> int:
+        """The number of inducing inputs."""
+        return self.inducing_inputs.shape[0]
+
+
 class WhitenedVariationalGaussian(VariationalGaussian[L]):
     r"""The whitened variational Gaussian family of probability distributions.

@@ -811,6 +880,7 @@ def predict(
     "AbstractVariationalFamily",
     "AbstractVariationalGaussian",
     "VariationalGaussian",
+    "GraphVariationalGaussian",
     "WhitenedVariationalGaussian",
     "NaturalVariationalGaussian",
     "ExpectationVariationalGaussian",
diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py
index a25cd053..9463c7b8 100644
--- a/tests/test_variational_families.py
+++ b/tests/test_variational_families.py
@@ -25,6 +25,8 @@
     Array,
     Float,
 )
+import networkx as nx
+import numpy as np
 import numpyro.distributions as npd
 from numpyro.distributions import Distribution as NumpyroDistribution
 import pytest
@@ -35,6 +37,7 @@
     AbstractVariationalFamily,
     CollapsedVariationalGaussian,
     ExpectationVariationalGaussian,
+    GraphVariationalGaussian,
     NaturalVariationalGaussian,
     VariationalGaussian,
     WhitenedVariationalGaussian,
@@ -118,6 +121,7 @@ def test_variational_gaussians(
     )
     likelihood = gpx.likelihoods.Gaussian(123)
     inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1)
+    test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1)

     posterior = prior * likelihood

@@ -174,6 +178,61 @@
     assert sigma.shape == (n_test, n_test)


+@pytest.mark.parametrize("n_test", [10, 20])
+@pytest.mark.parametrize("n_inducing", [10, 20])
+@pytest.mark.parametrize(
+    "variational_family",
+    [
+        GraphVariationalGaussian,
+    ],
+)
+def test_graph_variational_gaussian(
+    n_test: int,
+    n_inducing: int,
+    variational_family: AbstractVariationalFamily,
+) -> None:
+    G = nx.barbell_graph(100, 0)
+    L = nx.laplacian_matrix(G).toarray()
+
+    kernel = gpx.kernels.GraphKernel(
+        laplacian=L,
+        lengthscale=2.3,
+        variance=3.2,
+        smoothness=6.1,
+    )
+    meanf = gpx.mean_functions.Constant()
+    prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
+    likelihood = gpx.likelihoods.Bernoulli(num_datapoints=G.number_of_nodes())
+
+    inducing_inputs = jnp.array(
+        np.random.randint(low=1, high=100, size=(n_inducing, 1))
+    ).astype(jnp.int64)
+
+    test_inputs = jnp.array(np.random.randint(low=0, high=1, size=(n_test, 1))).astype(
+        jnp.int64
+    )
+
+    posterior = prior * likelihood
+    q = variational_family(posterior=posterior, inducing_inputs=inducing_inputs)
+    # Test KL
+    kl = q.prior_kl()
+    assert isinstance(kl, jnp.ndarray)
+    assert kl.shape == ()
+    assert kl >= 0.0
+
+    # Test predictions
+    predictive_dist = q(test_inputs)
+    assert isinstance(predictive_dist, NumpyroDistribution)
+
+    mu = predictive_dist.mean
+    sigma = predictive_dist.covariance()
+
+    assert isinstance(mu, jnp.ndarray)
+    assert isinstance(sigma, jnp.ndarray)
+    assert mu.shape == (n_test,)
+    assert sigma.shape == (n_test, n_test)
+
+
 @pytest.mark.parametrize("n_test", [1, 10])
 @pytest.mark.parametrize("n_datapoints", [1, 10])
 @pytest.mark.parametrize("n_inducing", [1, 10, 20])
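Usage sketch (not part of the patch): the snippet below shows how the new GraphVariationalGaussian family would be exercised end to end, using only the API already touched by this diff (gpx.kernels.GraphKernel, gpx.gps.Prior, gpx.likelihoods.Bernoulli, prior_kl, and calling the family on integer test nodes). The graph, hyperparameter values, and choice of inducing nodes are illustrative, not prescribed by the change.

import jax.numpy as jnp
import networkx as nx

import gpjax as gpx
from gpjax.variational_families import GraphVariationalGaussian

# Toy graph and its Laplacian; the graph kernel is defined over integer node indices.
G = nx.barbell_graph(50, 0)
L = nx.laplacian_matrix(G).toarray()

kernel = gpx.kernels.GraphKernel(
    laplacian=L, lengthscale=1.0, variance=1.0, smoothness=2.0
)
prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Constant(), kernel=kernel)
likelihood = gpx.likelihoods.Bernoulli(num_datapoints=G.number_of_nodes())
posterior = prior * likelihood

# Inducing inputs are integer node IDs with shape (M, 1); every 10th node here.
inducing_inputs = jnp.arange(0, G.number_of_nodes(), 10).reshape(-1, 1)

q = GraphVariationalGaussian(posterior=posterior, inducing_inputs=inducing_inputs)

# KL(q(u) || p(u)) term of the ELBO.
kl = q.prior_kl()

# Marginal latent distribution of the GP at a handful of query nodes.
test_nodes = jnp.array([[0], [5], [60]])
latent_dist = q(test_nodes)
mean, cov = latent_dist.mean, latent_dist.covariance()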