From 61bd98b5fbd34d9ef0dd449a35f7ccf41cae06c7 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 17 Jan 2023 21:03:06 +0000 Subject: [PATCH 1/8] Restructure modules --- README.md | 4 + jaxkern/__init__.py | 29 +- jaxkern/_version.py | 157 ++++-- jaxkern/base.py | 220 ++++++++ jaxkern/computations.py | 282 ++++++++++ jaxkern/kernels.py | 1058 ----------------------------------- jaxkern/non_euclidean.py | 66 +++ jaxkern/nonstationary.py | 208 +++++++ jaxkern/stationary.py | 304 ++++++++++ jaxkern/utils.py | 39 ++ tests/test_base.py | 201 +++++++ tests/test_kernels.py | 596 -------------------- tests/test_non_euclidean.py | 51 ++ tests/test_nonstationary.py | 231 ++++++++ tests/test_stationary.py | 272 +++++++++ tests/test_utils.py | 26 + 16 files changed, 2021 insertions(+), 1723 deletions(-) create mode 100644 jaxkern/base.py create mode 100644 jaxkern/computations.py delete mode 100644 jaxkern/kernels.py create mode 100644 jaxkern/non_euclidean.py create mode 100644 jaxkern/nonstationary.py create mode 100644 jaxkern/stationary.py create mode 100644 jaxkern/utils.py create mode 100644 tests/test_base.py delete mode 100644 tests/test_kernels.py create mode 100644 tests/test_non_euclidean.py create mode 100644 tests/test_nonstationary.py create mode 100644 tests/test_stationary.py create mode 100644 tests/test_utils.py diff --git a/README.md b/README.md index bee39ef..5c01d7d 100644 --- a/README.md +++ b/README.md @@ -3,3 +3,7 @@ [![codecov](https://codecov.io/gh/JaxGaussianProcesses/JaxKern/branch/main/graph/badge.svg?token=8WD7YYMPFS)](https://codecov.io/gh/JaxGaussianProcesses/JaxKern) [![CircleCI](https://dl.circleci.com/status-badge/img/gh/JaxGaussianProcesses/JaxKern/tree/main.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/JaxGaussianProcesses/JaxKern/tree/main) +[![Documentation Status](https://readthedocs.org/projects/gpjax/badge/?version=latest)](https://gpjax.readthedocs.io/en/latest/?badge=latest) +[![PyPI version](https://badge.fury.io/py/jaxkern.svg)](https://badge.fury.io/py/jaxkern) +[![Downloads](https://pepy.tech/badge/jaxkern)](https://pepy.tech/project/jaxkern) +[![Slack Invite](https://img.shields.io/badge/Slack_Invite--blue?style=social&logo=slack)](https://join.slack.com/t/gpjax/shared_invite/zt-1da57pmjn-rdBCVg9kApirEEn2E5Q2Zw) \ No newline at end of file diff --git a/jaxkern/__init__.py b/jaxkern/__init__.py index 8672f5e..b1decac 100644 --- a/jaxkern/__init__.py +++ b/jaxkern/__init__.py @@ -1,17 +1,25 @@ """JaxKern.""" -from .kernels import ( +from .base import ProductKernel, SumKernel +from .computations import ( + ConstantDiagonalKernelComputation, + DenseKernelComputation, + DiagonalKernelComputation, + EigenKernelComputation, +) +from .nonstationary import ( + Linear, + Periodic, + Polynomial, +) +from .stationary import ( RBF, - GraphKernel, Matern12, Matern32, Matern52, - Polynomial, - ProductKernel, - SumKernel, - DenseKernelComputation, - DiagonalKernelComputation, - ConstantDiagonalKernelComputation, + RationalQuadratic, + PoweredExponential, ) +from .non_euclidean import GraphKernel __all__ = [ "RBF", @@ -19,12 +27,17 @@ "Matern12", "Matern32", "Matern52", + "Linear", "Polynomial", "ProductKernel", "SumKernel", "DenseKernelComputation", "DiagonalKernelComputation", "ConstantDiagonalKernelComputation", + "EigenKernelComputation", + "PoweredExponential", + "Periodic", + "RationalQuadratic", ] from . 
import _version diff --git a/jaxkern/_version.py b/jaxkern/_version.py index d4df35e..63c904a 100644 --- a/jaxkern/_version.py +++ b/jaxkern/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -12,12 +11,12 @@ """Git implementation of _version.py.""" import errno +import functools import os import re import subprocess import sys from typing import Callable, Dict -import functools def get_keywords(): @@ -61,17 +60,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) process = None @@ -87,10 +87,14 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError: e = sys.exc_info()[1] @@ -125,15 +129,21 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -192,7 +202,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -201,7 +211,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". 
- tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -209,24 +219,31 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -248,8 +265,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -257,10 +273,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -275,8 +300,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -316,17 +340,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -335,10 +358,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -387,8 +412,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -417,8 +441,7 @@ def render_pep440_branch(pieces): rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -579,11 +602,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -607,9 +632,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions(): @@ -623,8 +652,7 @@ def get_versions(): verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -633,13 +661,16 @@ def get_versions(): # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. 
-    for _ in cfg.versionfile_source.split('/'):
+    for _ in cfg.versionfile_source.split("/"):
         root = os.path.dirname(root)
     except NameError:
-        return {"version": "0+unknown", "full-revisionid": None,
-                "dirty": None,
-                "error": "unable to find root of source tree",
-                "date": None}
+        return {
+            "version": "0+unknown",
+            "full-revisionid": None,
+            "dirty": None,
+            "error": "unable to find root of source tree",
+            "date": None,
+        }
 
     try:
         pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
@@ -653,6 +684,10 @@ def get_versions():
     except NotThisMethod:
         pass
 
-    return {"version": "0+unknown", "full-revisionid": None,
-            "dirty": None,
-            "error": "unable to compute version", "date": None}
+    return {
+        "version": "0+unknown",
+        "full-revisionid": None,
+        "dirty": None,
+        "error": "unable to compute version",
+        "date": None,
+    }
diff --git a/jaxkern/base.py b/jaxkern/base.py
new file mode 100644
index 0000000..79a215d
--- /dev/null
+++ b/jaxkern/base.py
@@ -0,0 +1,220 @@
+import abc
+from typing import Callable, Dict, List, Optional, Sequence
+
+import deprecation
+import jax.numpy as jnp
+from jax.random import KeyArray
+from jaxtyping import Array, Float
+from jaxutils import PyTree
+
+from .computations import AbstractKernelComputation, DenseKernelComputation
+
+
+##########################################
+# Abstract classes
+##########################################
+class AbstractKernel(PyTree):
+    """Base kernel class."""
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "AbstractKernel",
+    ) -> None:
+        self.compute_engine = compute_engine
+        self.active_dims = active_dims
+        self.stationary = stationary
+        self.spectral = spectral
+        self.name = name
+        self.ndims = 1 if not self.active_dims else len(self.active_dims)
+        # Instantiate the engine and bind its methods onto the kernel.
+        engine = self.compute_engine(kernel_fn=self.__call__)
+        self.gram = engine.gram
+        self.cross_covariance = engine.cross_covariance
+
+    @abc.abstractmethod
+    def __call__(
+        self,
+        params: Dict,
+        x: Float[Array, "1 D"],
+        y: Float[Array, "1 D"],
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs.
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
+        """
+        raise NotImplementedError
+
+    def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]:
+        """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation.
+
+        Args:
+            x (Float[Array, "N D"]): The matrix or vector that is to be sliced.
+
+        Returns:
+            Float[Array, "N Q"]: A sliced form of the input matrix.
+        """
+        return x[..., self.active_dims]
+
+    def __add__(self, other: "AbstractKernel") -> "AbstractKernel":
+        """Add two kernels together.
+
+        Args:
+            other (AbstractKernel): The kernel to be added to the current kernel.
+
+        Returns:
+            AbstractKernel: A new kernel that is the sum of the two kernels.
+        """
+        return SumKernel(kernel_set=[self, other])
+
+    def __mul__(self, other: "AbstractKernel") -> "AbstractKernel":
+        """Multiply two kernels together.
+
+        Args:
+            other (AbstractKernel): The kernel to be multiplied with the current kernel.
+
+        Returns:
+            AbstractKernel: A new kernel that is the product of the two kernels.
+        """
+        return ProductKernel(kernel_set=[self, other])
+
+    @property
+    def ard(self):
+        """Boolean property as to whether the kernel is isotropic or of
+        automatic relevance determination form.
+
+        Returns:
+            bool: True if the kernel is an ARD kernel.
+        """
+        return self.ndims > 1
+
+    @abc.abstractmethod
+    def init_params(self, key: KeyArray) -> Dict:
+        """A template dictionary of the kernel's parameter set.
+
+        Args:
+            key (KeyArray): A PRNG key to be used for initialising
+                the kernel's parameters.
+
+        Returns:
+            Dict: A dictionary of the kernel's parameters.
+        """
+        raise NotImplementedError
+
+    @deprecation.deprecated(
+        deprecated_in="0.0.3",
+        removed_in="0.1.0",
+        details="Use `init_params` instead.",
+    )
+    def _initialise_params(self, key: KeyArray) -> Dict:
+        """A template dictionary of the kernel's parameter set.
+
+        Args:
+            key (KeyArray): A PRNG key to be used for initialising
+                the kernel's parameters.
+
+        Returns:
+            Dict: A dictionary of the kernel's parameters.
+        """
+        raise NotImplementedError
+
+
+class CombinationKernel(AbstractKernel):
+    """A base class for products or sums of kernels."""
+
+    def __init__(
+        self,
+        kernel_set: List[AbstractKernel],
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Combination kernel",
+    ) -> None:
+        super().__init__(compute_engine, active_dims, stationary, spectral, name)
+        self.kernel_set = kernel_set
+        self.combination_fn: Optional[Callable] = None
+
+        if not all(isinstance(k, AbstractKernel) for k in self.kernel_set):
+            raise TypeError("can only combine Kernel instances")  # pragma: no cover
+
+        self._set_kernels(self.kernel_set)
+
+    def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None:
+        """Combine multiple kernels. Based on GPFlow's Combination kernel."""
+        # add kernels to a list, flattening out instances of this class therein
+        kernels_list: List[AbstractKernel] = []
+        for k in kernels:
+            if isinstance(k, self.__class__):
+                kernels_list.extend(k.kernel_set)
+            else:
+                kernels_list.append(k)
+
+        self.kernel_set = kernels_list
+
+    def init_params(self, key: KeyArray) -> List[Dict]:
+        """A list containing a template parameter dictionary for each sub-kernel."""
+        return [kernel.init_params(key) for kernel in self.kernel_set]
+
+    def __call__(
+        self,
+        params: List[Dict],
+        x: Float[Array, "1 D"],
+        y: Float[Array, "1 D"],
+    ) -> Float[Array, "1"]:
+        """Evaluate the combination kernel on a pair of inputs.
+
+        Args:
+            params (List[Dict]): Parameter sets, one dictionary per sub-kernel.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
+ """ + return self.combination_fn( + jnp.stack([k(p, x, y) for k, p in zip(self.kernel_set, params)]) + ) + + +class SumKernel(CombinationKernel): + """A kernel that is the sum of a set of kernels.""" + + def __init__( + self, + kernel_set: List[AbstractKernel], + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Sum kernel", + ) -> None: + super().__init__( + kernel_set, compute_engine, active_dims, stationary, spectral, name + ) + self.combination_fn: Optional[Callable] = jnp.sum + + +class ProductKernel(CombinationKernel): + """A kernel that is the product of a set of kernels.""" + + def __init__( + self, + kernel_set: List[AbstractKernel], + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Product kernel", + ) -> None: + super().__init__( + kernel_set, compute_engine, active_dims, stationary, spectral, name + ) + self.combination_fn: Optional[Callable] = jnp.prod diff --git a/jaxkern/computations.py b/jaxkern/computations.py new file mode 100644 index 0000000..cfda975 --- /dev/null +++ b/jaxkern/computations.py @@ -0,0 +1,282 @@ +import abc +from typing import Callable, Dict + +import jax.numpy as jnp +from jax import vmap +from jaxlinop import ( + ConstantDiagonalLinearOperator, + DenseLinearOperator, + DiagonalLinearOperator, + LinearOperator, +) +from jaxtyping import Array, Float +from jaxutils import PyTree + + +class AbstractKernelComputation(PyTree): + """Abstract class for kernel computations.""" + + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + self._kernel_fn = kernel_fn + + @property + def kernel_fn( + self, + ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]: + return self._kernel_fn + + @kernel_fn.setter + def kernel_fn( + self, + kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array], + ) -> None: + self._kernel_fn = kernel_fn + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> LinearOperator: + + """Compute Gram covariance operator of the kernel function. + + Args: + kernel (AbstractKernel): The kernel function to be evaluated. + params (Dict): The parameters of the kernel function. + inputs (Float[Array, "N N"]): The inputs to the kernel function. + + Returns: + LinearOperator: Gram covariance operator of the kernel function. + """ + + matrix = self.cross_covariance(params, inputs, inputs) + + return DenseLinearOperator(matrix=matrix) + + @abc.abstractmethod + def cross_covariance( + self, + params: Dict, + x: Float[Array, "N D"], + y: Float[Array, "M D"], + ) -> Float[Array, "N M"]: + """For a given kernel, compute the NxM gram matrix on an a pair + of input matrices with shape NxD and MxD. + + Args: + kernel (AbstractKernel): The kernel for which the cross-covariance + matrix should be computed for. + params (Dict): The kernel's parameter set. + x (Float[Array,"N D"]): The first input matrix. + y (Float[Array,"M D"]): The second input matrix. + + Returns: + Float[Array, "N M"]: The computed square Gram matrix. 
+ """ + raise NotImplementedError + + def diagonal( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the variance + vector should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + LinearOperator: The computed diagonal variance entries. + """ + diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + +class DenseKernelComputation(AbstractKernelComputation): + """Dense kernel computation class. Operations with the kernel assume + a dense gram matrix structure. + """ + + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + """For a given kernel, compute the NxM covariance matrix on a pair of input + matrices of shape NxD and MxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram + matrix should be computed for. + params (Dict): The kernel's parameter set. + x (Float[Array,"N D"]): The input matrix. + y (Float[Array,"M D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. + """ + cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) + return cross_cov + + +class DiagonalKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a kernel with diagonal structure, compute the NxN gram matrix on + an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram matrix + should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. + """ + + diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + raise ValueError("Cross covariance not defined for diagonal kernels.") + + +class ConstantDiagonalKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> ConstantDiagonalLinearOperator: + """For a kernel with diagonal structure, compute the NxN gram matrix on + an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram matrix + should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. 
+ """ + + value = self.kernel_fn(params, inputs[0], inputs[0]) + + return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) + + def diagonal( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the variance + vector should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + LinearOperator: The computed diagonal variance entries. + """ + + diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + raise ValueError("Cross covariance not defined for constant diagonal kernels.") + + +class EigenKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + self._eigenvalues = None + self._eigenvectors = None + self._num_verticies = None + + # Define an eigenvalue setter and getter property + @property + def eigensystem(self) -> Float[Array, "N"]: + return self._eigenvalues, self._eigenvectors, self._num_verticies + + @eigensystem.setter + def eigensystem( + self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] + ) -> None: + self._eigenvalues = eigenvalues + self._eigenvectors = eigenvectors + + @property + def num_vertex(self) -> int: + return self._num_verticies + + @num_vertex.setter + def num_vertex(self, num_vertex: int) -> None: + self._num_verticies = num_vertex + + def _compute_S(self, params): + evals, evecs = self.eigensystem + S = jnp.power( + evals + + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], + -params["smoothness"], + ) + S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) + S = jnp.multiply(S, params["variance"]) + return S + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + S = self._compute_S(params=params) + matrix = self.kernel_fn(params, x, y, S=S) + return matrix diff --git a/jaxkern/kernels.py b/jaxkern/kernels.py deleted file mode 100644 index 4db5913..0000000 --- a/jaxkern/kernels.py +++ /dev/null @@ -1,1058 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
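The two new modules above replace the composition machinery in the monolithic kernels.py whose deletion begins here. As a minimal usage sketch of the composition API (assuming the package is installed with the layout this patch introduces, and using the RBF and Matern32 kernels re-exported from jaxkern/__init__.py):

```python
import jax.random as jr

from jaxkern import RBF, Matern32

# `+` and `*` on AbstractKernel build SumKernel / ProductKernel instances,
# so composite kernels are written directly in terms of their parts.
kernel = RBF() + Matern32()

# CombinationKernel.init_params returns one parameter dict per sub-kernel,
# here [{"lengthscale": ..., "variance": ...}, {...}].
params = kernel.init_params(key=jr.PRNGKey(42))
```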
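A companion sketch for the computation engines: AbstractKernel.__init__ binds the engine's gram and cross_covariance methods onto the kernel instance, and gram returns a jaxlinop LinearOperator. The to_dense() call below assumes the jaxlinop interface, which is not defined in this patch:

```python
import jax.random as jr

from jaxkern import RBF

kernel = RBF()
key = jr.PRNGKey(0)
params = kernel.init_params(key)
x = jr.uniform(key, shape=(100, 1))

# `gram` and `cross_covariance` are called on the kernel itself, since
# AbstractKernel.__init__ binds them from the compute engine.
Kxx = kernel.gram(params, x)                       # DenseLinearOperator
dense = Kxx.to_dense()                             # materialise the (100, 100) matrix
Kxy = kernel.cross_covariance(params, x, x[:10])   # (100, 10) array
```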
-# ============================================================================== - -import abc -from typing import Callable, Dict, List, Optional, Sequence - -import jax -import jax.numpy as jnp -from jax import vmap -from jax.random import KeyArray -from jaxlinop import ( - ConstantDiagonalLinearOperator, - DenseLinearOperator, - DiagonalLinearOperator, - LinearOperator, -) -from jaxtyping import Array, Float -from jaxutils import PyTree - -import deprecation - - -class AbstractKernelComputation(PyTree): - """Abstract class for kernel computations.""" - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - self._kernel_fn = kernel_fn - - @property - def kernel_fn( - self, - ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]: - return self._kernel_fn - - @kernel_fn.setter - def kernel_fn( - self, - kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array], - ) -> None: - self._kernel_fn = kernel_fn - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> LinearOperator: - - """Compute Gram covariance operator of the kernel function. - - Args: - kernel (AbstractKernel): The kernel function to be evaluated. - params (Dict): The parameters of the kernel function. - inputs (Float[Array, "N N"]): The inputs to the kernel function. - - Returns: - LinearOperator: Gram covariance operator of the kernel function. - """ - - matrix = self.cross_covariance(params, inputs, inputs) - - return DenseLinearOperator(matrix=matrix) - - @abc.abstractmethod - def cross_covariance( - self, - params: Dict, - x: Float[Array, "N D"], - y: Float[Array, "M D"], - ) -> Float[Array, "N M"]: - """For a given kernel, compute the NxM gram matrix on an a pair - of input matrices with shape NxD and MxD. - - Args: - kernel (AbstractKernel): The kernel for which the cross-covariance - matrix should be computed for. - params (Dict): The kernel's parameter set. - x (Float[Array,"N D"]): The first input matrix. - y (Float[Array,"M D"]): The second input matrix. - - Returns: - Float[Array, "N M"]: The computed square Gram matrix. - """ - raise NotImplementedError - - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - LinearOperator: The computed diagonal variance entries. - """ - diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - -class DenseKernelComputation(AbstractKernelComputation): - """Dense kernel computation class. Operations with the kernel assume - a dense gram matrix structure. - """ - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - """For a given kernel, compute the NxM covariance matrix on a pair of input - matrices of shape NxD and MxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram - matrix should be computed for. - params (Dict): The kernel's parameter set. 
- x (Float[Array,"N D"]): The input matrix. - y (Float[Array,"M D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) - return cross_cov - - -class DiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - raise ValueError("Cross covariance not defined for diagonal kernels.") - - -class ConstantDiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> ConstantDiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - - value = self.kernel_fn(params, inputs[0], inputs[0]) - - return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) - - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - LinearOperator: The computed diagonal variance entries. 
- """ - - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - raise ValueError("Cross covariance not defined for constant diagonal kernels.") - - -########################################## -# Abtract classes -########################################## -class AbstractKernel(PyTree): - """ - Base kernel class""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "AbstractKernel", - ) -> None: - self.compute_engine = compute_engine - self.active_dims = active_dims - self.stationary = stationary - self.spectral = spectral - self.name = name - self.ndims = 1 if not self.active_dims else len(self.active_dims) - compute_engine = self.compute_engine(kernel_fn=self.__call__) - self.gram = compute_engine.gram - self.cross_covariance = compute_engine.cross_covariance - - @abc.abstractmethod - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs. - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - raise NotImplementedError - - def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: - """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. - - Args: - x (Float[Array, "N D"]): The matrix or vector that is to be sliced. - Returns: - Float[Array, "N Q"]: A sliced form of the input matrix. - """ - return x[..., self.active_dims] - - def __add__(self, other: "AbstractKernel") -> "AbstractKernel": - """Add two kernels together. - Args: - other (AbstractKernel): The kernel to be added to the current kernel. - - Returns: - AbstractKernel: A new kernel that is the sum of the two kernels. - """ - return SumKernel(kernel_set=[self, other]) - - def __mul__(self, other: "AbstractKernel") -> "AbstractKernel": - """Multiply two kernels together. - - Args: - other (AbstractKernel): The kernel to be multiplied with the current kernel. - - Returns: - AbstractKernel: A new kernel that is the product of the two kernels. - """ - return ProductKernel(kernel_set=[self, other]) - - @property - def ard(self): - """Boolean property as to whether the kernel is isotropic or of - automatic relevance determination form. - - Returns: - bool: True if the kernel is an ARD kernel. - """ - return True if self.ndims > 1 else False - - @abc.abstractmethod - def init_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set. - - Args: - key (KeyArray): A PRNG key to be used for initialising - the kernel's parameters. - - Returns: - Dict: A dictionary of the kernel's parameters. - """ - raise NotImplementedError - - - @deprecation.deprecated( - deprecated_in="0.0.3", - removed_in="0.1.0", - ) - def _initialise_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set. - - Args: - key (KeyArray): A PRNG key to be used for initialising - the kernel's parameters. 
- - Returns: - Dict: A dictionary of the kernel's parameters. - """ - raise NotImplementedError - - - - -class CombinationKernel(AbstractKernel): - """A base class for products or sums of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "AbstractKernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - self.kernel_set = kernel_set - name: Optional[str] = "Combination kernel" - self.combination_fn: Optional[Callable] = None - - if not all(isinstance(k, AbstractKernel) for k in self.kernel_set): - raise TypeError("can only combine Kernel instances") # pragma: no cover - - self._set_kernels(self.kernel_set) - - def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: - """Combine multiple kernels. Based on GPFlow's Combination kernel.""" - # add kernels to a list, flattening out instances of this class therein - kernels_list: List[AbstractKernel] = [] - for k in kernels: - if isinstance(k, self.__class__): - kernels_list.extend(k.kernel_set) - else: - kernels_list.append(k) - - self.kernel_set = kernels_list - - def init_params(self, key: KeyArray) -> Dict: - """A template dictionary of the kernel's parameter set.""" - return [kernel.init_params(key) for kernel in self.kernel_set] - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate combination kernel on a pair of inputs. - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - return self.combination_fn( - jnp.stack([k(p, x, y) for k, p in zip(self.kernel_set, params)]) - ) - - -class SumKernel(CombinationKernel): - """A kernel that is the sum of a set of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Sum kernel", - ) -> None: - super().__init__( - kernel_set, compute_engine, active_dims, stationary, spectral, name - ) - self.combination_fn: Optional[Callable] = jnp.sum - - -class ProductKernel(CombinationKernel): - """A kernel that is the product of a set of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Product kernel", - ) -> None: - super().__init__( - kernel_set, compute_engine, active_dims, stationary, spectral, name - ) - self.combination_fn: Optional[Callable] = jnp.prod - - -########################################## -# Euclidean kernels -########################################## -class RBF(AbstractKernel): - """The Radial Basis Function (RBF) kernel.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Radial basis function kernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( \\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - params = { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params) - - -class Matern12(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 0.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matérn 1/2 kernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. 
math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -class Matern32(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 1.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matern 3/2", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(3.0) * tau) - * jnp.exp(-jnp.sqrt(3.0) * tau) - ) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -class Matern52(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 2.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matern 5/2", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) - * jnp.exp(-jnp.sqrt(5.0) * tau) - ) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -class PoweredExponential(AbstractKernel): - """The powered exponential family of kernels. - - Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". - - """ - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Powered exponential", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. - - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"]) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "power": jnp.array([1.0]), - } - - -class Linear(AbstractKernel): - """The linear kernel.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Linear", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 x^{T}y - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. 
- Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) - y = self.slice_input(y) - K = params["variance"] * jnp.matmul(x.T, y) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return {"variance": jnp.array([1.0])} - - -class Polynomial(AbstractKernel): - """The Polynomial kernel with variable degree.""" - - def __init__( - self, - degree: int = 1, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Polynomial", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - self.degree = degree - self.name = f"Polynomial Degree: {self.degree}" - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through - - .. math:: - k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - x = self.slice_input(x).squeeze() - y = self.slice_input(y).squeeze() - K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "shift": jnp.array([1.0]), - "variance": jnp.array([1.0] * self.ndims), - } - - -class White(AbstractKernel, ConstantDiagonalKernelComputation): - def __post_init__(self) -> None: - super(White, self).__post_init__() - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\delta(x-y) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - K = jnp.all(jnp.equal(x, y)) * params["variance"] - return K.squeeze() - - def init_params(self, key: Float[Array, "1 D"]) -> Dict: - """Initialise the kernel parameters. - - Args: - key (Float[Array, "1 D"]): The key to initialise the parameters with. - - Returns: - Dict: The initialised parameters. - """ - return {"variance": jnp.array([1.0])} - - -class RationalQuadratic(AbstractKernel): - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Rational Quadratic", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` - - .. 
math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * ( - 1 + 0.5 * squared_distance(x, y) / params["alpha"] - ) ** (-params["alpha"]) - return K.squeeze() - - def init_params(self, key: KeyArray) -> dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "alpha": jnp.array([1.0]), - } - - -class Periodic(AbstractKernel): - """The periodic kernel. - - Key reference is MacKay 1998 - "Introduction to Gaussian processes". - """ - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Periodic", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. 
- Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) - y = self.slice_input(y) - sine_squared = ( - jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"] - ) ** 2 - K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "period": jnp.array([1.0] * self.ndims), - } - - -########################################## -# Graph kernels -########################################## -class EigenKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - self._eigenvalues = None - self._eigenvectors = None - self._num_verticies = None - - # Define an eigenvalue setter and getter property - @property - def eigensystem(self) -> Float[Array, "N"]: - return self._eigenvalues, self._eigenvectors, self._num_verticies - - @eigensystem.setter - def eigensystem( - self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] - ) -> None: - self._eigenvalues = eigenvalues - self._eigenvectors = eigenvectors - - @property - def num_vertex(self) -> int: - return self._num_verticies - - @num_vertex.setter - def num_vertex(self, num_vertex: int) -> None: - self._num_verticies = num_vertex - - def _compute_S(self, params): - evals, evecs = self.eigensystem - S = jnp.power( - evals - + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], - -params["smoothness"], - ) - S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) - S = jnp.multiply(S, params["variance"]) - return S - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - S = self._compute_S(params=params) - matrix = self.kernel_fn(params, x, y, S=S) - return matrix - - -class GraphKernel(AbstractKernel): - def __init__( - self, - laplacian: Float[Array, "N N"], - compute_engine: EigenKernelComputation = EigenKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Graph kernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - self.laplacian = laplacian - evals, self.evecs = jnp.linalg.eigh(self.laplacian) - self.evals = evals.reshape(-1, 1) - self.compute_engine.eigensystem = self.evals, self.evecs - self.compute_engine.num_vertex = self.laplacian.shape[0] - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - **kwargs, - ) -> Float[Array, "1"]: - """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): Index of the ith vertex. - y (Float[Array, "1 D"]): Index of the jth vertex. - - Returns: - Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. 
- """ - S = kwargs["S"] - Kxx = (jax_gather_nd(self.evecs, x) * S[None, :]) @ jnp.transpose( - jax_gather_nd(self.evecs, y) - ) # shape (n,n) - return Kxx.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "smoothness": jnp.array([1.0]), - } - - @property - def num_vertex(self) -> int: - return self.compute_engine.num_vertex - - -def squared_distance( - x: Float[Array, "1 D"], y: Float[Array, "1 D"] -) -> Float[Array, "1"]: - """Compute the squared distance between a pair of inputs. - - Args: - x (Float[Array, "1 D"]): First input. - y (Float[Array, "1 D"]): Second input. - - Returns: - Float[Array, "1"]: The squared distance between the inputs. - """ - - return jnp.sum((x - y) ** 2) - - -def euclidean_distance( - x: Float[Array, "1 D"], y: Float[Array, "1 D"] -) -> Float[Array, "1"]: - """Compute the euclidean distance between a pair of inputs. - - Args: - x (Float[Array, "1 D"]): First input. - y (Float[Array, "1 D"]): Second input. - - Returns: - Float[Array, "1"]: The euclidean distance between the inputs. - """ - - return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) - - -def jax_gather_nd(params, indices): - tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) - return params[tuple_indices] - - -__all__ = [ - "AbstractKernel", - "CombinationKernel", - "SumKernel", - "ProductKernel", - "RBF", - "Matern12", - "Matern32", - "Matern52", - "Linear", - "Periodic", - "RationalQuadratic", - "Polynomial", - "White", - "GraphKernel", - "squared_distance", - "euclidean_distance", - "AbstractKernelComputation", - "DenseKernelComputation", - "DiagonalKernelComputation", -] diff --git a/jaxkern/non_euclidean.py b/jaxkern/non_euclidean.py new file mode 100644 index 0000000..ebb8231 --- /dev/null +++ b/jaxkern/non_euclidean.py @@ -0,0 +1,66 @@ +from typing import Dict, List, Optional + +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array, Float + +from .computations import EigenKernelComputation +from .nonstationary import AbstractKernel +from .utils import jax_gather_nd + + +########################################## +# Graph kernels +########################################## +class GraphKernel(AbstractKernel): + def __init__( + self, + laplacian: Float[Array, "N N"], + compute_engine: EigenKernelComputation = EigenKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Graph kernel", + ) -> None: + super().__init__( + compute_engine, active_dims, stationary, spectral, name + ) + self.laplacian = laplacian + evals, self.evecs = jnp.linalg.eigh(self.laplacian) + self.evals = evals.reshape(-1, 1) + self.compute_engine.eigensystem = self.evals, self.evecs + self.compute_engine.num_vertex = self.laplacian.shape[0] + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + **kwargs, + ) -> Float[Array, "1"]: + """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): Index of the ith vertex. + y (Float[Array, "1 D"]): Index of the jth vertex. + + Returns: + Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. 
+ """ + S = kwargs["S"] + Kxx = (jax_gather_nd(self.evecs, x) * S[None, :]) @ jnp.transpose( + jax_gather_nd(self.evecs, y) + ) # shape (n,n) + return Kxx.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + "smoothness": jnp.array([1.0]), + } + + @property + def num_vertex(self) -> int: + return self.compute_engine.num_vertex diff --git a/jaxkern/nonstationary.py b/jaxkern/nonstationary.py new file mode 100644 index 0000000..8b3a3f1 --- /dev/null +++ b/jaxkern/nonstationary.py @@ -0,0 +1,208 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Dict, List, Optional + +import jax +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array, Float + +from .base import AbstractKernel +from .computations import ( + AbstractKernelComputation, + ConstantDiagonalKernelComputation, + DenseKernelComputation, + DiagonalKernelComputation, +) +from .utils import euclidean_distance, squared_distance + + +########################################## +# Euclidean kernels +########################################## +class Linear(AbstractKernel): + """The linear kernel.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Linear", + ) -> None: + super().__init__( + compute_engine, active_dims, stationary, spectral, name + ) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` + + .. math:: + k(x, y) = \\sigma^2 x^{T}y + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. 
+        Returns:
+            Array: The value of :math:`k(x, y)`
+        """
+        x = self.slice_input(x)
+        y = self.slice_input(y)
+        K = params["variance"] * jnp.matmul(x.T, y)
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        return {"variance": jnp.array([1.0])}
+
+
+class Polynomial(AbstractKernel):
+    """The Polynomial kernel with variable degree."""
+
+    def __init__(
+        self,
+        degree: int = 1,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Polynomial",
+    ) -> None:
+        super().__init__(
+            compute_engine, active_dims, stationary, spectral, name
+        )
+        self.degree = degree
+        self.name = f"Polynomial Degree: {self.degree}"
+
+    def __call__(
+        self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through
+
+        .. math::
+            k(x, y) = \\Big( \\alpha + \\sigma^2 x^{T} y \\Big)^{d}
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
+        """
+        x = self.slice_input(x).squeeze()
+        y = self.slice_input(y).squeeze()
+        K = jnp.power(
+            params["shift"] + jnp.dot(x * params["variance"], y), self.degree
+        )
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        return {
+            "shift": jnp.array([1.0]),
+            "variance": jnp.array([1.0] * self.ndims),
+        }
+
+
+class White(AbstractKernel, ConstantDiagonalKernelComputation):
+    def __post_init__(self) -> None:
+        super(White, self).__post_init__()
+
+    def __call__(
+        self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma^2`
+
+        .. math::
+            k(x, y) = \\sigma^2 \\delta(x-y)
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
+        """
+        K = jnp.all(jnp.equal(x, y)) * params["variance"]
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        """Initialise the kernel parameters.
+
+        Args:
+            key (KeyArray): The key to initialise the parameters with.
+
+        Returns:
+            Dict: The initialised parameters.
+        """
+        return {"variance": jnp.array([1.0])}
+
+
+class Periodic(AbstractKernel):
+    """The periodic kernel.
+
+    Key reference is MacKay 1998 - "Introduction to Gaussian processes".
+    """
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Periodic",
+    ) -> None:
+        super().__init__(
+            compute_engine, active_dims, stationary, spectral, name
+        )
+
+    def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale :math:`\\ell`, period :math:`p` and variance :math:`\\sigma^2`
+
+        .. math::
+            k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Big( \\frac{\\sin(\\pi (x_i - y_i) / p)}{\\ell} \\Big)^2 \\Bigg)
+
+        Args:
+            x (jax.Array): The left hand argument of the kernel function's call.
+            y (jax.Array): The right hand argument of the kernel function's call.
+            params (dict): Parameter set for which the kernel should be evaluated on.
+        Returns:
+            Array: The value of :math:`k(x, y)`
+        """
+        x = self.slice_input(x)
+        y = self.slice_input(y)
+        sine_squared = (
+            jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"]
+        ) ** 2
+        K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0))
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        return {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+            "period": jnp.array([1.0] * self.ndims),
+        }
+
+
+__all__ = [
+    "Linear",
+    "Periodic",
+    "Polynomial",
+    "White",
+]
diff --git a/jaxkern/stationary.py b/jaxkern/stationary.py
new file mode 100644
index 0000000..c3f6e42
--- /dev/null
+++ b/jaxkern/stationary.py
@@ -0,0 +1,304 @@
+from typing import Dict, List, Optional
+
+import jax
+import jax.numpy as jnp
+from jax.random import KeyArray
+from jaxtyping import Array, Float
+
+from .base import AbstractKernel
+from .computations import (
+    AbstractKernelComputation,
+    DenseKernelComputation,
+)
+from .utils import euclidean_distance, squared_distance
+
+
+class RBF(AbstractKernel):
+    """The Radial Basis Function (RBF) kernel."""
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Radial basis function kernel",
+    ) -> None:
+        super().__init__(
+            compute_engine, active_dims, stationary, spectral, name
+        )
+
+    def __call__(
+        self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with
+        lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`
+
+        .. math::
+            k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg)
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
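+
+        Example (a minimal usage sketch; ``key`` is assumed to be a
+        ``jax.random.PRNGKey`` and ``x`` an ``N x D`` input array)::
+
+            kern = RBF(active_dims=[0])
+            params = kern.init_params(key)  # {"lengthscale": ..., "variance": ...}
+            Kxx = kern.gram(params, x)      # N x N Gram matrix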
+ """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + params = { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params) + + +class Matern12(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 0.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matérn 1/2 kernel", + ) -> None: + super().__init__( + compute_engine, active_dims, stationary, spectral, name + ) + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + + +class Matern32(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 1.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matern 3/2", + ) -> None: + super().__init__( + compute_engine, active_dims, stationary, spectral, name + ) + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. 
+ """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + tau = euclidean_distance(x, y) + K = ( + params["variance"] + * (1.0 + jnp.sqrt(3.0) * tau) + * jnp.exp(-jnp.sqrt(3.0) * tau) + ) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + + +class Matern52(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 2.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matern 5/2", + ) -> None: + super().__init__( + compute_engine, active_dims, stationary, spectral, name + ) + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + tau = euclidean_distance(x, y) + K = ( + params["variance"] + * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) + * jnp.exp(-jnp.sqrt(5.0) * tau) + ) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + + +class PoweredExponential(AbstractKernel): + """The powered exponential family of kernels. + + Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". + + """ + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Powered exponential", + ) -> None: + super().__init__( + compute_engine, active_dims, stationary, spectral, name + ) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. 
+
+        Returns:
+            Array: The value of :math:`k(x, y)`
+        """
+        x = self.slice_input(x) / params["lengthscale"]
+        y = self.slice_input(y) / params["lengthscale"]
+        K = params["variance"] * jnp.exp(
+            -euclidean_distance(x, y) ** params["power"]
+        )
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        return {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+            "power": jnp.array([1.0]),
+        }
+
+
+class RationalQuadratic(AbstractKernel):
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Rational Quadratic",
+    ) -> None:
+        super().__init__(
+            compute_engine, active_dims, stationary, spectral, name
+        )
+
+    def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale :math:`\\ell`, variance :math:`\\sigma^2` and shape parameter :math:`\\alpha`
+
+        .. math::
+            k(x, y) = \\sigma^2 \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg)^{-\\alpha}
+
+        Args:
+            x (jax.Array): The left hand argument of the kernel function's call.
+            y (jax.Array): The right hand argument of the kernel function's call.
+            params (dict): Parameter set for which the kernel should be evaluated on.
+        Returns:
+            Array: The value of :math:`k(x, y)`
+        """
+        x = self.slice_input(x) / params["lengthscale"]
+        y = self.slice_input(y) / params["lengthscale"]
+        K = params["variance"] * (
+            1 + 0.5 * squared_distance(x, y) / params["alpha"]
+        ) ** (-params["alpha"])
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> dict:
+        return {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+            "alpha": jnp.array([1.0]),
+        }
+
+
+__all__ = [
+    "RBF",
+    "Matern12",
+    "Matern32",
+    "Matern52",
+    "RationalQuadratic",
+    "PoweredExponential",
+]
diff --git a/jaxkern/utils.py b/jaxkern/utils.py
new file mode 100644
index 0000000..06fbe6c
--- /dev/null
+++ b/jaxkern/utils.py
@@ -0,0 +1,39 @@
+import jax.numpy as jnp
+from jaxtyping import Array, Float
+
+
+def squared_distance(
+    x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+) -> Float[Array, "1"]:
+    """Compute the squared distance between a pair of inputs.
+
+    Args:
+        x (Float[Array, "1 D"]): First input.
+        y (Float[Array, "1 D"]): Second input.
+
+    Returns:
+        Float[Array, "1"]: The squared distance between the inputs.
+    """
+
+    return jnp.sum((x - y) ** 2)
+
+
+def euclidean_distance(
+    x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+) -> Float[Array, "1"]:
+    """Compute the Euclidean distance between a pair of inputs.
+
+    Args:
+        x (Float[Array, "1 D"]): First input.
+        y (Float[Array, "1 D"]): Second input.
+
+    Returns:
+        Float[Array, "1"]: The Euclidean distance between the inputs.
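+
+    Example (values as exercised in ``tests/test_utils.py``)::
+
+        euclidean_distance(jnp.array([1.0]), jnp.array([-4.0]))  # ~= 5.0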
+ """ + + return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) + + +def jax_gather_nd(params, indices): + tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) + return params[tuple_indices] diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 0000000..5e8c67e --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,201 @@ +import jax.numpy as jnp +import jax.random as jr +import pytest +from jax.config import config +from jaxlinop import identity + +from jaxkern.base import ( + AbstractKernel, + CombinationKernel, + ProductKernel, + SumKernel, +) +from jaxkern.stationary import ( + RBF, + Matern12, + Matern32, + Matern52, + RationalQuadratic, +) +from jaxkern.nonstationary import Polynomial, Linear +from jax.random import KeyArray +from jaxtyping import Array, Float +from typing import Dict + + +# Enable Float64 for more stable matrix inversions. +config.update("jax_enable_x64", True) +_initialise_key = jr.PRNGKey(123) +_jitter = 1e-6 + + +def test_abstract_kernel(): + # Test initialising abstract kernel raises TypeError with unimplemented __call__ and _init_params methods: + with pytest.raises(TypeError): + AbstractKernel() + + # Create a dummy kernel class with __call__ and _init_params methods implemented: + class DummyKernel(AbstractKernel): + def __call__( + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + ) -> Float[Array, "1"]: + return x * params["test"] * y + + def init_params(self, key: KeyArray) -> Dict: + return {"test": 1.0} + + # Initialise dummy kernel class and test __call__ and _init_params methods: + dummy_kernel = DummyKernel() + assert dummy_kernel.init_params(_initialise_key) == {"test": 1.0} + assert ( + dummy_kernel(jnp.array([1.0]), jnp.array([2.0]), {"test": 2.0}) == 4.0 + ) + + +@pytest.mark.parametrize("combination_type", [SumKernel, ProductKernel]) +@pytest.mark.parametrize( + "kernel", + [RBF, RationalQuadratic, Linear, Matern12, Matern32, Matern52, Polynomial], +) +@pytest.mark.parametrize("n_kerns", [2, 3, 4]) +def test_combination_kernel( + combination_type: CombinationKernel, kernel: AbstractKernel, n_kerns: int +) -> None: + + # Create inputs + n = 20 + x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) + + # Create list of kernels + kernel_set = [kernel() for _ in range(n_kerns)] + + # Create combination kernel + combination_kernel = combination_type(kernel_set=kernel_set) + + # Initialise default parameters + params = combination_kernel.init_params(_initialise_key) + + # Check params are a list of dictionaries + assert len(params) == n_kerns + + for p in params: + assert isinstance(p, dict) + + # Check combination kernel set + assert len(combination_kernel.kernel_set) == n_kerns + assert isinstance(combination_kernel.kernel_set, list) + assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) + + # Compute gram matrix + Kxx = combination_kernel.gram(params, x) + + # Check shapes + assert Kxx.shape[0] == Kxx.shape[1] + assert Kxx.shape[1] == n + + # Check positive definiteness + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0).all() + + +@pytest.mark.parametrize( + "k1", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] +) +@pytest.mark.parametrize( + "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] +) +def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: + # Create inputs + n = 10 + x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) + + # Create sum kernel + sum_kernel = 
SumKernel(kernel_set=[k1, k2]) + + # Initialise default parameters + params = sum_kernel.init_params(_initialise_key) + + # Compute gram matrix + Kxx = sum_kernel.gram(params, x) + + # NOW we do the same thing manually and check they are equal: + # Initialise default parameters + k1_params = k1.init_params(_initialise_key) + k2_params = k2.init_params(_initialise_key) + + # Compute gram matrix + Kxx_k1 = k1.gram(k1_params, x) + Kxx_k2 = k2.gram(k2_params, x) + + # Check manual and automatic gram matrices are equal + assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() + Kxx_k2.to_dense()) + + +@pytest.mark.parametrize( + "k1", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + Polynomial(), + Linear(), + Polynomial(), + RationalQuadratic(), + ], +) +@pytest.mark.parametrize( + "k2", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + Polynomial(), + Linear(), + Polynomial(), + RationalQuadratic(), + ], +) +def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: + + # Create inputs + n = 10 + x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) + + # Create product kernel + prod_kernel = ProductKernel(kernel_set=[k1, k2]) + + # Initialise default parameters + params = prod_kernel.init_params(_initialise_key) + + # Compute gram matrix + Kxx = prod_kernel.gram(params, x) + + # NOW we do the same thing manually and check they are equal: + + # Initialise default parameters + k1_params = k1.init_params(_initialise_key) + k2_params = k2.init_params(_initialise_key) + + # Compute gram matrix + Kxx_k1 = k1.gram(k1_params, x) + Kxx_k2 = k2.gram(k2_params, x) + + # Check manual and automatic gram matrices are equal + assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() * Kxx_k2.to_dense()) + + +@pytest.mark.parametrize( + "kernel", + [RBF, Matern12, Matern32, Matern52, Polynomial, Linear, RationalQuadratic], +) +def test_combination_kernel_type(kernel: AbstractKernel) -> None: + prod_kern = kernel() * kernel() + assert isinstance(prod_kern, ProductKernel) + assert isinstance(prod_kern, CombinationKernel) + + add_kern = kernel() + kernel() + assert isinstance(add_kern, SumKernel) + assert isinstance(add_kern, CombinationKernel) diff --git a/tests/test_kernels.py b/tests/test_kernels.py deleted file mode 100644 index 171b8fd..0000000 --- a/tests/test_kernels.py +++ /dev/null @@ -1,596 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - - -from itertools import permutations -from typing import Dict, List - -import jax -import jax.numpy as jnp -import jax.random as jr -import networkx as nx -import pytest -from jaxutils.parameters import initialise -from jax.config import config -from jax.random import KeyArray -from jaxlinop import LinearOperator, identity -from jaxtyping import Array, Float - -from jaxkern.kernels import ( - RBF, - AbstractKernel, - CombinationKernel, - GraphKernel, - Linear, - Matern12, - Matern32, - Matern52, - Periodic, - Polynomial, - PoweredExponential, - ProductKernel, - RationalQuadratic, - SumKernel, - euclidean_distance, -) - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) -_jitter = 1e-6 - - -def test_abstract_kernel(): - # Test initialising abstract kernel raises TypeError with unimplemented __call__ and _init_params methods: - with pytest.raises(TypeError): - AbstractKernel() - - # Create a dummy kernel class with __call__ and _init_params methods implemented: - class DummyKernel(AbstractKernel): - def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict - ) -> Float[Array, "1"]: - return x * params["test"] * y - - def init_params(self, key: KeyArray) -> Dict: - return {"test": 1.0} - - # Initialise dummy kernel class and test __call__ and _init_params methods: - dummy_kernel = DummyKernel() - assert dummy_kernel.init_params(_initialise_key) == {"test": 1.0} - assert dummy_kernel(jnp.array([1.0]), jnp.array([2.0]), {"test": 2.0}) == 4.0 - - -@pytest.mark.parametrize( - "a, b, distance_to_3dp", - [ - ([1.0], [-4.0], 5.0), - ([1.0, -2.0], [-4.0, 3.0], 7.071), - ([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], 2.236), - ], -) -def test_euclidean_distance( - a: List[float], b: List[float], distance_to_3dp: float -) -> None: - - # Convert lists to JAX arrays: - a: Float[Array, "D"] = jnp.array(a) - b: Float[Array, "D"] = jnp.array(b) - - # Test distance is correct to 3dp: - assert jnp.round(euclidean_distance(a, b), 3) == distance_to_3dp - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("n", [1, 2, 10]) -def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: - - # Gram constructor static method: - kernel.gram - - # Inputs x: - x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - - # Default kernel parameters: - params = kernel.init_params(_initialise_key) - - # Test gram matrix: - Kxx = kernel.gram(params, x) - assert isinstance(Kxx, LinearOperator) - assert Kxx.shape == (n, n) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -@pytest.mark.parametrize("num_a", [1, 2, 5]) -@pytest.mark.parametrize("num_b", [1, 2, 5]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_cross_covariance( - kernel: AbstractKernel, num_a: int, num_b: int, dim: int -) -> None: - # Inputs a, b: - a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) - b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) - - # Default kernel parameters: - params = kernel.init_params(_initialise_key) - - # Test cross covariance, Kab: - Kab = kernel.cross_covariance(params, a, b) - assert isinstance(Kab, jnp.ndarray) - assert Kab.shape == (num_a, num_b) - - 
-@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_call(kernel: AbstractKernel, dim: int) -> None: - - # Datapoint x and datapoint y: - x = jnp.array([[1.0] * dim]) - y = jnp.array([[0.5] * dim]) - - # Defualt parameters: - params = kernel.init_params(_initialise_key) - - # Test calling gives an autocovariance value of no dimension between the inputs: - kxy = kernel(params, x, y) - - assert isinstance(kxy, jax.Array) - assert kxy.shape == () - - -@pytest.mark.parametrize("kern", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def( - kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int -) -> None: - kern = kern(active_dims=list(range(dim))) - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = {"lengthscale": jnp.array([ell]), "variance": jnp.array([sigma])} - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("kern", [Linear, Polynomial]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("shift", [0.0, 0.5, 2.0]) -@pytest.mark.parametrize("sigma", [0.1, 0.2, 0.5]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_lin_poly( - kern: AbstractKernel, dim: int, shift: float, sigma: float, n: int -) -> None: - kern = kern(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = {"variance": jnp.array([sigma]), "shift": jnp.array([shift])} - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: - kern = RationalQuadratic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "alpha": jnp.array([alpha]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_power_exp( - dim: int, ell: float, sigma: float, power: float, n: int -) -> None: - kern = PoweredExponential(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "power": jnp.array([power]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = 
kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_periodic( - dim: int, ell: float, sigma: float, period: float, n: int -) -> None: - kern = Periodic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "period": jnp.array([period]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) -def test_initialisation(kernel: AbstractKernel, dim: int) -> None: - - if dim is None: - kern = kernel() - assert kern.ndims == 1 - - else: - kern = kernel(active_dims=[i for i in range(dim)]) - params = kern.init_params(_initialise_key) - - assert list(params.keys()) == ["lengthscale", "variance"] - assert all(params["lengthscale"] == jnp.array([1.0] * dim)) - assert params["variance"] == jnp.array([1.0]) - - if dim > 1: - assert kern.ard - else: - assert not kern.ard - - -@pytest.mark.parametrize( - "kernel", - [ - RBF, - Matern12, - Matern32, - Matern52, - Linear, - Polynomial, - RationalQuadratic, - PoweredExponential, - Periodic, - ], -) -def test_dtype(kernel: AbstractKernel) -> None: - parameter_state = initialise(kernel(), _initialise_key) - params, *_ = parameter_state.unpack() - for k, v in params.items(): - assert v.dtype == jnp.float64 - assert isinstance(k, str) - - -@pytest.mark.parametrize("degree", [1, 2, 3]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("variance", [0.1, 1.0, 2.0]) -@pytest.mark.parametrize("shift", [1e-6, 0.1, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_polynomial( - degree: int, dim: int, variance: float, shift: float, n: int -) -> None: - - # Define inputs - x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - - # Define kernel - kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)]) - - # Check name - assert kern.name == f"Polynomial Degree: {degree}" - - # Initialise parameters - params = kern.init_params(_initialise_key) - params["shift"] * shift - params["variance"] * variance - - # Check parameter keys - assert list(params.keys()) == ["shift", "variance"] - - # Compute gram matrix - Kxx = kern.gram(params, x) - - # Check shapes - assert Kxx.shape[0] == x.shape[0] - assert Kxx.shape[0] == Kxx.shape[1] - - # Test positive definiteness - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0).all() - - -@pytest.mark.parametrize( - "kernel", - [RBF, Matern12, Matern32, Matern52, Linear, Polynomial, RationalQuadratic], -) -def test_active_dim(kernel: AbstractKernel) -> None: - dim_list = [0, 1, 2, 3] - perm_length = 2 - dim_pairs = list(permutations(dim_list, r=perm_length)) - n_dims = len(dim_list) - - # Generate random inputs - x = jr.normal(_initialise_key, shape=(20, n_dims)) - - for dp in dim_pairs: - # Take slice of x - slice = x[..., dp] - - # Define kernels - ad_kern = 
kernel(active_dims=dp) - manual_kern = kernel(active_dims=[i for i in range(perm_length)]) - - # Get initial parameters - ad_params = ad_kern.init_params(_initialise_key) - manual_params = manual_kern.init_params(_initialise_key) - - # Compute gram matrices - ad_Kxx = ad_kern.gram(ad_params, x) - manual_Kxx = manual_kern.gram(manual_params, slice) - - # Test gram matrices are equal - assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) - - -@pytest.mark.parametrize("combination_type", [SumKernel, ProductKernel]) -@pytest.mark.parametrize( - "kernel", - [RBF, RationalQuadratic, Linear, Matern12, Matern32, Matern52, Polynomial], -) -@pytest.mark.parametrize("n_kerns", [2, 3, 4]) -def test_combination_kernel( - combination_type: CombinationKernel, kernel: AbstractKernel, n_kerns: int -) -> None: - - # Create inputs - n = 20 - x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - - # Create list of kernels - kernel_set = [kernel() for _ in range(n_kerns)] - - # Create combination kernel - combination_kernel = combination_type(kernel_set=kernel_set) - - # Initialise default parameters - params = combination_kernel.init_params(_initialise_key) - - # Check params are a list of dictionaries - assert len(params) == n_kerns - - for p in params: - assert isinstance(p, dict) - - # Check combination kernel set - assert len(combination_kernel.kernel_set) == n_kerns - assert isinstance(combination_kernel.kernel_set, list) - assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) - - # Compute gram matrix - Kxx = combination_kernel.gram(params, x) - - # Check shapes - assert Kxx.shape[0] == Kxx.shape[1] - assert Kxx.shape[1] == n - - # Check positive definiteness - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0).all() - - -@pytest.mark.parametrize( - "k1", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] -) -@pytest.mark.parametrize( - "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] -) -def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: - # Create inputs - n = 10 - x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - - # Create sum kernel - sum_kernel = SumKernel(kernel_set=[k1, k2]) - - # Initialise default parameters - params = sum_kernel.init_params(_initialise_key) - - # Compute gram matrix - Kxx = sum_kernel.gram(params, x) - - # NOW we do the same thing manually and check they are equal: - # Initialise default parameters - k1_params = k1.init_params(_initialise_key) - k2_params = k2.init_params(_initialise_key) - - # Compute gram matrix - Kxx_k1 = k1.gram(k1_params, x) - Kxx_k2 = k2.gram(k2_params, x) - - # Check manual and automatic gram matrices are equal - assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() + Kxx_k2.to_dense()) - - -@pytest.mark.parametrize( - "k1", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Polynomial(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -@pytest.mark.parametrize( - "k2", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Polynomial(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: - - # Create inputs - n = 10 - x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - - # Create product kernel - prod_kernel = ProductKernel(kernel_set=[k1, k2]) - - # Initialise default parameters - params = prod_kernel.init_params(_initialise_key) - - # Compute gram matrix - Kxx = prod_kernel.gram(params, x) - - # NOW we do the same thing manually and 
check they are equal: - - # Initialise default parameters - k1_params = k1.init_params(_initialise_key) - k2_params = k2.init_params(_initialise_key) - - # Compute gram matrix - Kxx_k1 = k1.gram(k1_params, x) - Kxx_k2 = k2.gram(k2_params, x) - - # Check manual and automatic gram matrices are equal - assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() * Kxx_k2.to_dense()) - - -def test_graph_kernel(): - # Create a random graph, G, and verice labels, x, - n_verticies = 20 - n_edges = 40 - G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) - x = jnp.arange(n_verticies).reshape(-1, 1) - - # Compute graph laplacian - L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 - - # Create graph kernel - kern = GraphKernel(laplacian=L) - assert isinstance(kern, GraphKernel) - assert kern.num_vertex == n_verticies - assert kern.evals.shape == (n_verticies, 1) - assert kern.evecs.shape == (n_verticies, n_verticies) - - # Unpack kernel computation - kern.gram - - # Initialise default parameters - params = kern.init_params(_initialise_key) - assert isinstance(params, dict) - assert list(sorted(list(params.keys()))) == [ - "lengthscale", - "smoothness", - "variance", - ] - - # Compute gram matrix - Kxx = kern.gram(params, x) - assert Kxx.shape == (n_verticies, n_verticies) - - # Check positive definiteness - Kxx += identity(n_verticies) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert all(eigen_values > 0) - - -@pytest.mark.parametrize( - "kernel", - [RBF, Matern12, Matern32, Matern52, Polynomial, Linear, RationalQuadratic], -) -def test_combination_kernel_type(kernel: AbstractKernel) -> None: - prod_kern = kernel() * kernel() - assert isinstance(prod_kern, ProductKernel) - assert isinstance(prod_kern, CombinationKernel) - - add_kern = kernel() + kernel() - assert isinstance(add_kern, SumKernel) - assert isinstance(add_kern, CombinationKernel) diff --git a/tests/test_non_euclidean.py b/tests/test_non_euclidean.py new file mode 100644 index 0000000..cc3df37 --- /dev/null +++ b/tests/test_non_euclidean.py @@ -0,0 +1,51 @@ +import jax.numpy as jnp +import jax.random as jr +import networkx as nx +from jax.config import config +from jaxlinop import identity + +from jaxkern.non_euclidean import GraphKernel + +# Enable Float64 for more stable matrix inversions. 
+config.update("jax_enable_x64", True)
+_initialise_key = jr.PRNGKey(123)
+_jitter = 1e-6
+
+
+def test_graph_kernel():
+    # Create a random graph, G, and vertex labels, x
+    n_verticies = 20
+    n_edges = 40
+    G = nx.gnm_random_graph(n_verticies, n_edges, seed=123)
+    x = jnp.arange(n_verticies).reshape(-1, 1)
+
+    # Compute the graph Laplacian
+    L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12
+
+    # Create graph kernel
+    kern = GraphKernel(laplacian=L)
+    assert isinstance(kern, GraphKernel)
+    assert kern.num_vertex == n_verticies
+    assert kern.evals.shape == (n_verticies, 1)
+    assert kern.evecs.shape == (n_verticies, n_verticies)
+
+    # Unpack kernel computation
+    kern.gram
+
+    # Initialise default parameters
+    params = kern.init_params(_initialise_key)
+    assert isinstance(params, dict)
+    assert list(sorted(list(params.keys()))) == [
+        "lengthscale",
+        "smoothness",
+        "variance",
+    ]
+
+    # Compute gram matrix
+    Kxx = kern.gram(params, x)
+    assert Kxx.shape == (n_verticies, n_verticies)
+
+    # Check positive definiteness
+    Kxx += identity(n_verticies) * _jitter
+    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
+    assert all(eigen_values > 0)
diff --git a/tests/test_nonstationary.py b/tests/test_nonstationary.py
new file mode 100644
index 0000000..50b44c6
--- /dev/null
+++ b/tests/test_nonstationary.py
@@ -0,0 +1,231 @@
+# Copyright 2022 The GPJax Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from itertools import permutations
+from typing import Dict, List
+
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import pytest
+from jax.config import config
+from jax.random import KeyArray
+from jaxlinop import LinearOperator, identity
+from jaxtyping import Array, Float
+from jaxutils.parameters import initialise
+
+from jaxkern.base import AbstractKernel
+from jaxkern.nonstationary import (
+    Linear,
+    Periodic,
+    Polynomial,
+)
+
+# Enable Float64 for more stable matrix inversions.
+config.update("jax_enable_x64", True) +_initialise_key = jr.PRNGKey(123) +_jitter = 1e-6 + + +@pytest.mark.parametrize( + "kernel", + [ + Linear(), + Polynomial(), + ], +) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("n", [1, 2, 10]) +def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: + + # Gram constructor static method: + kernel.gram + + # Inputs x: + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + + # Default kernel parameters: + params = kernel.init_params(_initialise_key) + + # Test gram matrix: + Kxx = kernel.gram(params, x) + assert isinstance(Kxx, LinearOperator) + assert Kxx.shape == (n, n) + + +@pytest.mark.parametrize( + "kernel", + [ + Linear(), + Polynomial(), + ], +) +@pytest.mark.parametrize("num_a", [1, 2, 5]) +@pytest.mark.parametrize("num_b", [1, 2, 5]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +def test_cross_covariance( + kernel: AbstractKernel, num_a: int, num_b: int, dim: int +) -> None: + # Inputs a, b: + a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) + b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) + + # Default kernel parameters: + params = kernel.init_params(_initialise_key) + + # Test cross covariance, Kab: + Kab = kernel.cross_covariance(params, a, b) + assert isinstance(Kab, jnp.ndarray) + assert Kab.shape == (num_a, num_b) + + +@pytest.mark.parametrize("kern", [Linear, Polynomial]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("shift", [0.0, 0.5, 2.0]) +@pytest.mark.parametrize("sigma", [0.1, 0.2, 0.5]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def( + kern: AbstractKernel, dim: int, shift: float, sigma: float, n: int +) -> None: + kern = kern(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = {"variance": jnp.array([sigma]), "shift": jnp.array([shift])} + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize( + "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)] +) +@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_periodic( + dim: int, ell: float, sigma: float, period: float, n: int +) -> None: + kern = Periodic(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = { + "lengthscale": jnp.array([ell]), + "variance": jnp.array([sigma]), + "period": jnp.array([period]), + } + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize( + "kernel", + [ + Linear, + Polynomial, + Periodic, + ], +) +def test_dtype(kernel: AbstractKernel) -> None: + parameter_state = initialise(kernel(), _initialise_key) + params, *_ = parameter_state.unpack() + for k, v in params.items(): + assert v.dtype == jnp.float64 + assert isinstance(k, str) + + +@pytest.mark.parametrize("degree", [1, 2, 3]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("variance", [0.1, 1.0, 2.0]) +@pytest.mark.parametrize("shift", [1e-6, 0.1, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_polynomial( + degree: 
int, dim: int, variance: float, shift: float, n: int
+) -> None:
+
+    # Define inputs
+    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
+
+    # Define kernel
+    kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)])
+
+    # Check name
+    assert kern.name == f"Polynomial Degree: {degree}"
+
+    # Initialise parameters and scale the shift and variance
+    params = kern.init_params(_initialise_key)
+    params["shift"] = params["shift"] * shift
+    params["variance"] = params["variance"] * variance
+
+    # Check parameter keys
+    assert list(params.keys()) == ["shift", "variance"]
+
+    # Compute gram matrix
+    Kxx = kern.gram(params, x)
+
+    # Check shapes
+    assert Kxx.shape[0] == x.shape[0]
+    assert Kxx.shape[0] == Kxx.shape[1]
+
+    # Test positive definiteness
+    Kxx += identity(n) * _jitter
+    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
+    assert (eigen_values > 0).all()
+
+
+@pytest.mark.parametrize(
+    "kernel",
+    [Linear, Polynomial],
+)
+def test_active_dim(kernel: AbstractKernel) -> None:
+    dim_list = [0, 1, 2, 3]
+    perm_length = 2
+    dim_pairs = list(permutations(dim_list, r=perm_length))
+    n_dims = len(dim_list)
+
+    # Generate random inputs
+    x = jr.normal(_initialise_key, shape=(20, n_dims))
+
+    for dp in dim_pairs:
+        # Take slice of x
+        slice = x[..., dp]
+
+        # Define kernels
+        ad_kern = kernel(active_dims=dp)
+        manual_kern = kernel(active_dims=[i for i in range(perm_length)])
+
+        # Get initial parameters
+        ad_params = ad_kern.init_params(_initialise_key)
+        manual_params = manual_kern.init_params(_initialise_key)
+
+        # Compute gram matrices
+        ad_Kxx = ad_kern.gram(ad_params, x)
+        manual_Kxx = manual_kern.gram(manual_params, slice)
+
+        # Test gram matrices are equal
+        assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense())
diff --git a/tests/test_stationary.py b/tests/test_stationary.py
new file mode 100644
index 0000000..676fbde
--- /dev/null
+++ b/tests/test_stationary.py
@@ -0,0 +1,272 @@
+# Copyright 2022 The GPJax Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from itertools import permutations
+from typing import Dict, List
+
+import jax
+import jax.numpy as jnp
+import jax.random as jr
+import pytest
+from jax.config import config
+from jax.random import KeyArray
+from jaxlinop import LinearOperator, identity
+from jaxtyping import Array, Float
+from jaxutils.parameters import initialise
+
+from jaxkern.base import AbstractKernel
+from jaxkern.stationary import (
+    RBF,
+    Matern12,
+    Matern32,
+    Matern52,
+    PoweredExponential,
+    RationalQuadratic,
+)
+from jaxkern.utils import euclidean_distance
+
+# Enable Float64 for more stable matrix inversions.
+config.update("jax_enable_x64", True)
+_initialise_key = jr.PRNGKey(123)
+_jitter = 1e-6
+
+
+@pytest.mark.parametrize(
+    "kernel",
+    [
+        RBF(),
+        Matern12(),
+        Matern32(),
+        Matern52(),
+        RationalQuadratic(),
+    ],
+)
+@pytest.mark.parametrize("dim", [1, 2, 5])
+@pytest.mark.parametrize("n", [1, 2, 10])
+def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None:

+    # Gram constructor static method:
+    kernel.gram
+
+    # Inputs x:
+    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
+
+    # Default kernel parameters:
+    params = kernel.init_params(_initialise_key)
+
+    # Test gram matrix:
+    Kxx = kernel.gram(params, x)
+    assert isinstance(Kxx, LinearOperator)
+    assert Kxx.shape == (n, n)
+
+
+@pytest.mark.parametrize(
+    "kernel",
+    [
+        RBF(),
+        Matern12(),
+        Matern32(),
+        Matern52(),
+        RationalQuadratic(),
+    ],
+)
+@pytest.mark.parametrize("num_a", [1, 2, 5])
+@pytest.mark.parametrize("num_b", [1, 2, 5])
+@pytest.mark.parametrize("dim", [1, 2, 5])
+def test_cross_covariance(
+    kernel: AbstractKernel, num_a: int, num_b: int, dim: int
+) -> None:
+    # Inputs a, b:
+    a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim)
+    b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim)
+
+    # Default kernel parameters:
+    params = kernel.init_params(_initialise_key)
+
+    # Test cross covariance, Kab:
+    Kab = kernel.cross_covariance(params, a, b)
+    assert isinstance(Kab, jnp.ndarray)
+    assert Kab.shape == (num_a, num_b)
+
+
+@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()])
+@pytest.mark.parametrize("dim", [1, 2, 5])
+def test_call(kernel: AbstractKernel, dim: int) -> None:
+
+    # Datapoint x and datapoint y:
+    x = jnp.array([[1.0] * dim])
+    y = jnp.array([[0.5] * dim])
+
+    # Default parameters:
+    params = kernel.init_params(_initialise_key)
+
+    # Test that calling the kernel returns a zero-dimensional autocovariance value for the inputs:
+    kxy = kernel(params, x, y)
+
+    assert isinstance(kxy, jax.Array)
+    assert kxy.shape == ()
+
+
+@pytest.mark.parametrize("kern", [RBF, Matern12, Matern32, Matern52])
+@pytest.mark.parametrize("dim", [1, 2, 5])
+@pytest.mark.parametrize(
+    "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]
+)
+@pytest.mark.parametrize("n", [1, 2, 5])
+def test_pos_def(
+    kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int
+) -> None:
+    kern = kern(active_dims=list(range(dim)))
+
+    # Create inputs x:
+    x = jr.uniform(_initialise_key, (n, dim))
+    params = {"lengthscale": jnp.array([ell]), "variance": jnp.array([sigma])}
+
+    # Test gram matrix eigenvalues are positive:
+    Kxx = kern.gram(params, x)
+    Kxx += identity(n) * _jitter
+    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
+    assert (eigen_values > 0.0).all()
+
+
+@pytest.mark.parametrize("dim", [1, 2, 5])
+@pytest.mark.parametrize(
+    "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]
+)
+@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0])
+@pytest.mark.parametrize("n", [1, 2, 5])
+def test_pos_def_rq(
+    dim: int, ell: float, sigma: float, alpha: float, n: int
+) -> None:
+    kern = RationalQuadratic(active_dims=list(range(dim)))
+    # Gram constructor static method:
+    kern.gram
+
+    # Create inputs x:
+    x = jr.uniform(_initialise_key, (n, dim))
+    params = {
+        "lengthscale": jnp.array([ell]),
+        "variance": jnp.array([sigma]),
+        "alpha": jnp.array([alpha]),
+    }
+
+    # Test gram matrix eigenvalues are positive:
+    Kxx = kern.gram(params, x)
+    Kxx += identity(n) * _jitter
+    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
+    assert (eigen_values > 0.0).all()
+
+
+@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize( + "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)] +) +@pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_power_exp( + dim: int, ell: float, sigma: float, power: float, n: int +) -> None: + kern = PoweredExponential(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = { + "lengthscale": jnp.array([ell]), + "variance": jnp.array([sigma]), + "power": jnp.array([power]), + } + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) +def test_initialisation(kernel: AbstractKernel, dim: int) -> None: + + if dim is None: + kern = kernel() + assert kern.ndims == 1 + + else: + kern = kernel(active_dims=[i for i in range(dim)]) + params = kern.init_params(_initialise_key) + + assert list(params.keys()) == ["lengthscale", "variance"] + assert all(params["lengthscale"] == jnp.array([1.0] * dim)) + assert params["variance"] == jnp.array([1.0]) + + if dim > 1: + assert kern.ard + else: + assert not kern.ard + + +@pytest.mark.parametrize( + "kernel", + [ + RBF, + Matern12, + Matern32, + Matern52, + RationalQuadratic, + PoweredExponential, + ], +) +def test_dtype(kernel: AbstractKernel) -> None: + parameter_state = initialise(kernel(), _initialise_key) + params, *_ = parameter_state.unpack() + for k, v in params.items(): + assert v.dtype == jnp.float64 + assert isinstance(k, str) + + +@pytest.mark.parametrize( + "kernel", + [RBF, Matern12, Matern32, Matern52, RationalQuadratic], +) +def test_active_dim(kernel: AbstractKernel) -> None: + dim_list = [0, 1, 2, 3] + perm_length = 2 + dim_pairs = list(permutations(dim_list, r=perm_length)) + n_dims = len(dim_list) + + # Generate random inputs + x = jr.normal(_initialise_key, shape=(20, n_dims)) + + for dp in dim_pairs: + # Take slice of x + slice = x[..., dp] + + # Define kernels + ad_kern = kernel(active_dims=dp) + manual_kern = kernel(active_dims=[i for i in range(perm_length)]) + + # Get initial parameters + ad_params = ad_kern.init_params(_initialise_key) + manual_params = manual_kern.init_params(_initialise_key) + + # Compute gram matrices + ad_Kxx = ad_kern.gram(ad_params, x) + manual_Kxx = manual_kern.gram(manual_params, slice) + + # Test gram matrices are equal + assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..ed316b4 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,26 @@ +from typing import List + +import jax.numpy as jnp +import pytest +from jaxtyping import Array, Float +from jaxkern.utils import euclidean_distance + + +@pytest.mark.parametrize( + "a, b, distance_to_3dp", + [ + ([1.0], [-4.0], 5.0), + ([1.0, -2.0], [-4.0, 3.0], 7.071), + ([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], 2.236), + ], +) +def test_euclidean_distance( + a: List[float], b: List[float], distance_to_3dp: float +) -> None: + + # Convert lists to JAX arrays: + a: Float[Array, "D"] = jnp.array(a) + b: Float[Array, "D"] = jnp.array(b) + + # Test distance is correct to 3dp: + assert jnp.round(euclidean_distance(a, b), 3) == distance_to_3dp From 
b1ead5cf4c21332bb483052aa112eb125579e459 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 17 Jan 2023 21:33:00 +0000 Subject: [PATCH 2/8] Autoflake to pre-commit --- .pre-commit-config.yaml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4a0bc5..67f3455 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,4 +18,14 @@ repos: language: system types: [python] require_serial: true - + - id: autoflake + name: autoflake + entry: autoflake + language: python + "types": [python] + require_serial: true + args: + - "--in-place" + - "--expand-star-imports" + - "--remove-all-unused-imports" + - "--remove-unused-variables" \ No newline at end of file From e9f75ed43b89eaa3d5a9244e0d9931da1f1ad357 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Tue, 17 Jan 2023 21:38:59 +0000 Subject: [PATCH 3/8] Add black to pre-commit --- .pre-commit-config.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67f3455..8070a24 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,4 +28,8 @@ repos: - "--in-place" - "--expand-star-imports" - "--remove-all-unused-imports" - - "--remove-unused-variables" \ No newline at end of file + - "--remove-unused-variables" + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black \ No newline at end of file From 61ecb7cbc4206bcfafb0a1ba13d9f46505f116bf Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Jan 2023 08:16:17 +0000 Subject: [PATCH 4/8] Remove typed --- jaxkern/py.typed | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 jaxkern/py.typed diff --git a/jaxkern/py.typed b/jaxkern/py.typed deleted file mode 100644 index e69de29..0000000 From ffe7a1466585c34a118d85ba99f65121b3f68395 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Jan 2023 08:45:36 +0000 Subject: [PATCH 5/8] Re-modulise --- jaxkern/__init__.py | 8 +- jaxkern/computations.py | 282 ---------------- jaxkern/computations/__init__.py | 13 + jaxkern/computations/base.py | 100 ++++++ jaxkern/computations/constant_diagonal.py | 68 ++++ jaxkern/computations/dense.py | 38 +++ jaxkern/computations/diagonal.py | 45 +++ jaxkern/computations/eigen.py | 56 ++++ jaxkern/non_euclidean/__init__.py | 0 .../graph.py} | 8 +- jaxkern/non_euclidean/utils.py | 3 + jaxkern/nonstationary.py | 208 ------------ jaxkern/nonstationary/__init__.py | 5 + jaxkern/nonstationary/linear.py | 50 +++ jaxkern/nonstationary/polynomial.py | 55 ++++ jaxkern/nonstationary/white.py | 44 +++ jaxkern/stationary.py | 304 ------------------ jaxkern/stationary/__init__.py | 17 + jaxkern/stationary/matern12.py | 56 ++++ jaxkern/stationary/matern32.py | 62 ++++ jaxkern/stationary/matern52.py | 59 ++++ jaxkern/stationary/periodic.py | 57 ++++ jaxkern/stationary/powered_exponential.py | 57 ++++ jaxkern/stationary/rational_quadratic.py | 52 +++ jaxkern/stationary/rbf.py | 56 ++++ jaxkern/{ => stationary}/utils.py | 5 - setup.cfg | 5 + tests/test_nonstationary.py | 40 +-- tests/test_stationary.py | 48 ++- tests/test_utils.py | 2 +- 30 files changed, 939 insertions(+), 864 deletions(-) delete mode 100644 jaxkern/computations.py create mode 100644 jaxkern/computations/__init__.py create mode 100644 jaxkern/computations/base.py create mode 100644 jaxkern/computations/constant_diagonal.py create mode 100644 jaxkern/computations/dense.py create mode 100644 jaxkern/computations/diagonal.py create mode 
100644 jaxkern/computations/eigen.py create mode 100644 jaxkern/non_euclidean/__init__.py rename jaxkern/{non_euclidean.py => non_euclidean/graph.py} (91%) create mode 100644 jaxkern/non_euclidean/utils.py delete mode 100644 jaxkern/nonstationary.py create mode 100644 jaxkern/nonstationary/__init__.py create mode 100644 jaxkern/nonstationary/linear.py create mode 100644 jaxkern/nonstationary/polynomial.py create mode 100644 jaxkern/nonstationary/white.py delete mode 100644 jaxkern/stationary.py create mode 100644 jaxkern/stationary/__init__.py create mode 100644 jaxkern/stationary/matern12.py create mode 100644 jaxkern/stationary/matern32.py create mode 100644 jaxkern/stationary/matern52.py create mode 100644 jaxkern/stationary/periodic.py create mode 100644 jaxkern/stationary/powered_exponential.py create mode 100644 jaxkern/stationary/rational_quadratic.py create mode 100644 jaxkern/stationary/rbf.py rename jaxkern/{ => stationary}/utils.py (85%) diff --git a/jaxkern/__init__.py b/jaxkern/__init__.py index b1decac..554ff0e 100644 --- a/jaxkern/__init__.py +++ b/jaxkern/__init__.py @@ -6,17 +6,14 @@ DiagonalKernelComputation, EigenKernelComputation, ) -from .nonstationary import ( - Linear, - Periodic, - Polynomial, -) +from .nonstationary import Linear, Polynomial, White from .stationary import ( RBF, Matern12, Matern32, Matern52, RationalQuadratic, + Periodic, PoweredExponential, ) from .non_euclidean import GraphKernel @@ -38,6 +35,7 @@ "PoweredExponential", "Periodic", "RationalQuadratic", + "White", ] from . import _version diff --git a/jaxkern/computations.py b/jaxkern/computations.py deleted file mode 100644 index cfda975..0000000 --- a/jaxkern/computations.py +++ /dev/null @@ -1,282 +0,0 @@ -import abc -from typing import Callable, Dict - -import jax.numpy as jnp -from jax import vmap -from jaxlinop import ( - ConstantDiagonalLinearOperator, - DenseLinearOperator, - DiagonalLinearOperator, - LinearOperator, -) -from jaxtyping import Array, Float -from jaxutils import PyTree - - -class AbstractKernelComputation(PyTree): - """Abstract class for kernel computations.""" - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - self._kernel_fn = kernel_fn - - @property - def kernel_fn( - self, - ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]: - return self._kernel_fn - - @kernel_fn.setter - def kernel_fn( - self, - kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array], - ) -> None: - self._kernel_fn = kernel_fn - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> LinearOperator: - - """Compute Gram covariance operator of the kernel function. - - Args: - kernel (AbstractKernel): The kernel function to be evaluated. - params (Dict): The parameters of the kernel function. - inputs (Float[Array, "N N"]): The inputs to the kernel function. - - Returns: - LinearOperator: Gram covariance operator of the kernel function. - """ - - matrix = self.cross_covariance(params, inputs, inputs) - - return DenseLinearOperator(matrix=matrix) - - @abc.abstractmethod - def cross_covariance( - self, - params: Dict, - x: Float[Array, "N D"], - y: Float[Array, "M D"], - ) -> Float[Array, "N M"]: - """For a given kernel, compute the NxM gram matrix on an a pair - of input matrices with shape NxD and MxD. - - Args: - kernel (AbstractKernel): The kernel for which the cross-covariance - matrix should be computed for. - params (Dict): The kernel's parameter set. 
- x (Float[Array,"N D"]): The first input matrix. - y (Float[Array,"M D"]): The second input matrix. - - Returns: - Float[Array, "N M"]: The computed square Gram matrix. - """ - raise NotImplementedError - - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - LinearOperator: The computed diagonal variance entries. - """ - diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - -class DenseKernelComputation(AbstractKernelComputation): - """Dense kernel computation class. Operations with the kernel assume - a dense gram matrix structure. - """ - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - """For a given kernel, compute the NxM covariance matrix on a pair of input - matrices of shape NxD and MxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram - matrix should be computed for. - params (Dict): The kernel's parameter set. - x (Float[Array,"N D"]): The input matrix. - y (Float[Array,"M D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) - return cross_cov - - -class DiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - raise ValueError("Cross covariance not defined for diagonal kernels.") - - -class ConstantDiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> ConstantDiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. 
- """ - - value = self.kernel_fn(params, inputs[0], inputs[0]) - - return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) - - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - LinearOperator: The computed diagonal variance entries. - """ - - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - raise ValueError("Cross covariance not defined for constant diagonal kernels.") - - -class EigenKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - self._eigenvalues = None - self._eigenvectors = None - self._num_verticies = None - - # Define an eigenvalue setter and getter property - @property - def eigensystem(self) -> Float[Array, "N"]: - return self._eigenvalues, self._eigenvectors, self._num_verticies - - @eigensystem.setter - def eigensystem( - self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] - ) -> None: - self._eigenvalues = eigenvalues - self._eigenvectors = eigenvectors - - @property - def num_vertex(self) -> int: - return self._num_verticies - - @num_vertex.setter - def num_vertex(self, num_vertex: int) -> None: - self._num_verticies = num_vertex - - def _compute_S(self, params): - evals, evecs = self.eigensystem - S = jnp.power( - evals - + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], - -params["smoothness"], - ) - S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) - S = jnp.multiply(S, params["variance"]) - return S - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - S = self._compute_S(params=params) - matrix = self.kernel_fn(params, x, y, S=S) - return matrix diff --git a/jaxkern/computations/__init__.py b/jaxkern/computations/__init__.py new file mode 100644 index 0000000..7bf1605 --- /dev/null +++ b/jaxkern/computations/__init__.py @@ -0,0 +1,13 @@ +from .base import AbstractKernelComputation +from .constant_diagonal import ConstantDiagonalKernelComputation +from .dense import DenseKernelComputation +from .diagonal import DiagonalKernelComputation +from .eigen import EigenKernelComputation + +__all__ = [ + "AbstractKernelComputation", + "ConstantDiagonalKernelComputation", + "DenseKernelComputation", + "DiagonalKernelComputation", + "EigenKernelComputation", +] diff --git a/jaxkern/computations/base.py b/jaxkern/computations/base.py new file mode 100644 index 0000000..9984778 --- /dev/null +++ b/jaxkern/computations/base.py @@ -0,0 +1,100 @@ +import abc +from typing import Callable, Dict + +from jax import vmap +from jaxlinop import ( + DenseLinearOperator, + DiagonalLinearOperator, + LinearOperator, +) +from jaxtyping import Array, Float +from jaxutils import PyTree + + +class AbstractKernelComputation(PyTree): + """Abstract class for kernel computations.""" + + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], 
Float[Array, "1 D"]], Array
+        ] = None,
+    ) -> None:
+        self._kernel_fn = kernel_fn
+
+    @property
+    def kernel_fn(
+        self,
+    ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]:
+        return self._kernel_fn
+
+    @kernel_fn.setter
+    def kernel_fn(
+        self,
+        kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array],
+    ) -> None:
+        self._kernel_fn = kernel_fn
+
+    def gram(
+        self,
+        params: Dict,
+        inputs: Float[Array, "N D"],
+    ) -> LinearOperator:
+
+        """Compute Gram covariance operator of the kernel function.
+
+        Args:
+            kernel (AbstractKernel): The kernel function to be evaluated.
+            params (Dict): The parameters of the kernel function.
+            inputs (Float[Array, "N D"]): The inputs to the kernel function.
+
+        Returns:
+            LinearOperator: Gram covariance operator of the kernel function.
+        """
+
+        matrix = self.cross_covariance(params, inputs, inputs)
+
+        return DenseLinearOperator(matrix=matrix)
+
+    @abc.abstractmethod
+    def cross_covariance(
+        self,
+        params: Dict,
+        x: Float[Array, "N D"],
+        y: Float[Array, "M D"],
+    ) -> Float[Array, "N M"]:
+        """For a given kernel, compute the NxM cross-covariance matrix on a pair
+        of input matrices with shape NxD and MxD.
+
+        Args:
+            kernel (AbstractKernel): The kernel for which the cross-covariance
+                matrix should be computed.
+            params (Dict): The kernel's parameter set.
+            x (Float[Array,"N D"]): The first input matrix.
+            y (Float[Array,"M D"]): The second input matrix.
+
+        Returns:
+            Float[Array, "N M"]: The computed cross-covariance matrix.
+        """
+        raise NotImplementedError
+
+    def diagonal(
+        self,
+        params: Dict,
+        inputs: Float[Array, "N D"],
+    ) -> DiagonalLinearOperator:
+        """For a given kernel, compute the elementwise diagonal of the
+        NxN gram matrix on an input matrix of shape NxD.
+
+        Args:
+            kernel (AbstractKernel): The kernel for which the variance
+                vector should be computed.
+            params (Dict): The kernel's parameter set.
+            inputs (Float[Array, "N D"]): The input matrix.
+
+        Returns:
+            LinearOperator: The computed diagonal variance entries.
+        """
+        diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs)
+
+        return DiagonalLinearOperator(diag=diag)
diff --git a/jaxkern/computations/constant_diagonal.py b/jaxkern/computations/constant_diagonal.py
new file mode 100644
index 0000000..f164565
--- /dev/null
+++ b/jaxkern/computations/constant_diagonal.py
@@ -0,0 +1,68 @@
+from typing import Callable, Dict
+
+from jax import vmap
+from jaxlinop import (
+    ConstantDiagonalLinearOperator,
+    DiagonalLinearOperator,
+)
+from jaxtyping import Array, Float
+from .base import AbstractKernelComputation
+
+
+class ConstantDiagonalKernelComputation(AbstractKernelComputation):
+    def __init__(
+        self,
+        kernel_fn: Callable[
+            [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array
+        ] = None,
+    ) -> None:
+        super().__init__(kernel_fn)
+
+    def gram(
+        self,
+        params: Dict,
+        inputs: Float[Array, "N D"],
+    ) -> ConstantDiagonalLinearOperator:
+        """For a kernel with constant diagonal structure, compute the NxN gram
+        matrix on an input matrix of shape NxD.
+
+        Args:
+            kernel (AbstractKernel): The kernel for which the Gram matrix
+                should be computed.
+            params (Dict): The kernel's parameter set.
+            inputs (Float[Array, "N D"]): The input matrix.
+
+        Returns:
+            CovarianceOperator: The computed square Gram matrix.
+ """ + + value = self.kernel_fn(params, inputs[0], inputs[0]) + + return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) + + def diagonal( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the variance + vector should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + LinearOperator: The computed diagonal variance entries. + """ + + diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + raise ValueError("Cross covariance not defined for constant diagonal kernels.") diff --git a/jaxkern/computations/dense.py b/jaxkern/computations/dense.py new file mode 100644 index 0000000..c1f1a7c --- /dev/null +++ b/jaxkern/computations/dense.py @@ -0,0 +1,38 @@ +from typing import Callable, Dict + +from jax import vmap +from jaxtyping import Array, Float +from .base import AbstractKernelComputation + + +class DenseKernelComputation(AbstractKernelComputation): + """Dense kernel computation class. Operations with the kernel assume + a dense gram matrix structure. + """ + + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + """For a given kernel, compute the NxM covariance matrix on a pair of input + matrices of shape NxD and MxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram + matrix should be computed for. + params (Dict): The kernel's parameter set. + x (Float[Array,"N D"]): The input matrix. + y (Float[Array,"M D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. + """ + cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) + return cross_cov diff --git a/jaxkern/computations/diagonal.py b/jaxkern/computations/diagonal.py new file mode 100644 index 0000000..de83b7f --- /dev/null +++ b/jaxkern/computations/diagonal.py @@ -0,0 +1,45 @@ +from typing import Callable, Dict + +from jax import vmap +from jaxlinop import ( + DiagonalLinearOperator, +) +from jaxtyping import Array, Float +from .base import AbstractKernelComputation + + +class DiagonalKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a kernel with diagonal structure, compute the NxN gram matrix on + an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram matrix + should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. 
+ """ + + diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + raise ValueError("Cross covariance not defined for diagonal kernels.") diff --git a/jaxkern/computations/eigen.py b/jaxkern/computations/eigen.py new file mode 100644 index 0000000..cb8d4f8 --- /dev/null +++ b/jaxkern/computations/eigen.py @@ -0,0 +1,56 @@ +from typing import Callable, Dict + +import jax.numpy as jnp +from jaxtyping import Array, Float +from .base import AbstractKernelComputation + + +class EigenKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + self._eigenvalues = None + self._eigenvectors = None + self._num_verticies = None + + # Define an eigenvalue setter and getter property + @property + def eigensystem(self) -> Float[Array, "N"]: + return self._eigenvalues, self._eigenvectors, self._num_verticies + + @eigensystem.setter + def eigensystem( + self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] + ) -> None: + self._eigenvalues = eigenvalues + self._eigenvectors = eigenvectors + + @property + def num_vertex(self) -> int: + return self._num_verticies + + @num_vertex.setter + def num_vertex(self, num_vertex: int) -> None: + self._num_verticies = num_vertex + + def _compute_S(self, params): + evals, evecs = self.eigensystem + S = jnp.power( + evals + + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], + -params["smoothness"], + ) + S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) + S = jnp.multiply(S, params["variance"]) + return S + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + S = self._compute_S(params=params) + matrix = self.kernel_fn(params, x, y, S=S) + return matrix diff --git a/jaxkern/non_euclidean/__init__.py b/jaxkern/non_euclidean/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jaxkern/non_euclidean.py b/jaxkern/non_euclidean/graph.py similarity index 91% rename from jaxkern/non_euclidean.py rename to jaxkern/non_euclidean/graph.py index ebb8231..3167263 100644 --- a/jaxkern/non_euclidean.py +++ b/jaxkern/non_euclidean/graph.py @@ -4,8 +4,8 @@ from jax.random import KeyArray from jaxtyping import Array, Float -from .computations import EigenKernelComputation -from .nonstationary import AbstractKernel +from ..computations import EigenKernelComputation +from ..base import AbstractKernel from .utils import jax_gather_nd @@ -22,9 +22,7 @@ def __init__( spectral: Optional[bool] = False, name: Optional[str] = "Graph kernel", ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) + super().__init__(compute_engine, active_dims, stationary, spectral, name) self.laplacian = laplacian evals, self.evecs = jnp.linalg.eigh(self.laplacian) self.evals = evals.reshape(-1, 1) diff --git a/jaxkern/non_euclidean/utils.py b/jaxkern/non_euclidean/utils.py new file mode 100644 index 0000000..0e8d16c --- /dev/null +++ b/jaxkern/non_euclidean/utils.py @@ -0,0 +1,3 @@ +def jax_gather_nd(params, indices): + tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) + return params[tuple_indices] diff --git a/jaxkern/nonstationary.py b/jaxkern/nonstationary.py deleted file mode 100644 index 8b3a3f1..0000000 
--- a/jaxkern/nonstationary.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from typing import Dict, List, Optional - -import jax -import jax.numpy as jnp -from jax.random import KeyArray -from jaxtyping import Array, Float - -from .base import AbstractKernel -from .computations import ( - AbstractKernelComputation, - ConstantDiagonalKernelComputation, - DenseKernelComputation, - DiagonalKernelComputation, -) -from .utils import euclidean_distance, squared_distance - - -########################################## -# Euclidean kernels -########################################## -class Linear(AbstractKernel): - """The linear kernel.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Linear", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 x^{T}y - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) - y = self.slice_input(y) - K = params["variance"] * jnp.matmul(x.T, y) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return {"variance": jnp.array([1.0])} - - -class Polynomial(AbstractKernel): - """The Polynomial kernel with variable degree.""" - - def __init__( - self, - degree: int = 1, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Polynomial", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - self.degree = degree - self.name = f"Polynomial Degree: {self.degree}" - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through - - .. math:: - k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - x = self.slice_input(x).squeeze() - y = self.slice_input(y).squeeze() - K = jnp.power( - params["shift"] + jnp.dot(x * params["variance"], y), self.degree - ) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "shift": jnp.array([1.0]), - "variance": jnp.array([1.0] * self.ndims), - } - - -class White(AbstractKernel, ConstantDiagonalKernelComputation): - def __post_init__(self) -> None: - super(White, self).__post_init__() - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\delta(x-y) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - K = jnp.all(jnp.equal(x, y)) * params["variance"] - return K.squeeze() - - def init_params(self, key: Float[Array, "1 D"]) -> Dict: - """Initialise the kernel parameters. - - Args: - key (Float[Array, "1 D"]): The key to initialise the parameters with. - - Returns: - Dict: The initialised parameters. - """ - return {"variance": jnp.array([1.0])} - - -class Periodic(AbstractKernel): - """The periodic kernel. - - Key reference is MacKay 1998 - "Introduction to Gaussian processes". - """ - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Periodic", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. 
- Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) - y = self.slice_input(y) - sine_squared = ( - jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"] - ) ** 2 - K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "period": jnp.array([1.0] * self.ndims), - } - - -__all__ = [ - "Linear", - "Periodic", - "Polynomial", - "White", -] diff --git a/jaxkern/nonstationary/__init__.py b/jaxkern/nonstationary/__init__.py new file mode 100644 index 0000000..20ce75e --- /dev/null +++ b/jaxkern/nonstationary/__init__.py @@ -0,0 +1,5 @@ +from .linear import Linear +from .polynomial import Polynomial +from .white import White + +__all__ = ["Linear", "Polynomial", "White"] diff --git a/jaxkern/nonstationary/linear.py b/jaxkern/nonstationary/linear.py new file mode 100644 index 0000000..3795420 --- /dev/null +++ b/jaxkern/nonstationary/linear.py @@ -0,0 +1,50 @@ +from typing import Dict, List, Optional + +import jax +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array + +from ..base import AbstractKernel +from ..computations import ( + AbstractKernelComputation, + DenseKernelComputation, +) + + +########################################## +# Euclidean kernels +########################################## +class Linear(AbstractKernel): + """The linear kernel.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Linear", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\\sigma` + + .. math:: + k(x, y) = \\sigma^2 x^{T}y + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. 
+ Returns: + Array: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) + y = self.slice_input(y) + K = params["variance"] * jnp.matmul(x.T, y) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return {"variance": jnp.array([1.0])} diff --git a/jaxkern/nonstationary/polynomial.py b/jaxkern/nonstationary/polynomial.py new file mode 100644 index 0000000..18210d9 --- /dev/null +++ b/jaxkern/nonstationary/polynomial.py @@ -0,0 +1,55 @@ +from typing import Dict, List, Optional + +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array, Float + +from ..base import AbstractKernel +from ..computations import ( + AbstractKernelComputation, + DenseKernelComputation, +) + + +class Polynomial(AbstractKernel): + """The Polynomial kernel with variable degree.""" + + def __init__( + self, + degree: int = 1, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Polynomial", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + self.degree = degree + self.name = f"Polynomial Degree: {self.degree}" + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\\sigma^2` through + + .. math:: + k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + x = self.slice_input(x).squeeze() + y = self.slice_input(y).squeeze() + K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "shift": jnp.array([1.0]), + "variance": jnp.array([1.0] * self.ndims), + } diff --git a/jaxkern/nonstationary/white.py b/jaxkern/nonstationary/white.py new file mode 100644 index 0000000..5e0e330 --- /dev/null +++ b/jaxkern/nonstationary/white.py @@ -0,0 +1,44 @@ +from typing import Dict + +import jax.numpy as jnp +from jaxtyping import Array, Float + +from ..base import AbstractKernel +from ..computations import ( + ConstantDiagonalKernelComputation, +) + + +class White(AbstractKernel, ConstantDiagonalKernelComputation): + def __post_init__(self) -> None: + super(White, self).__post_init__() + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\\sigma` + + .. math:: + k(x, y) = \\sigma^2 \\delta(x-y) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + K = jnp.all(jnp.equal(x, y)) * params["variance"] + return K.squeeze() + + def init_params(self, key: Float[Array, "1 D"]) -> Dict: + """Initialise the kernel parameters. + + Args: + key (Float[Array, "1 D"]): The key to initialise the parameters with. 
+ + Returns: + Dict: The initialised parameters. + """ + return {"variance": jnp.array([1.0])} diff --git a/jaxkern/stationary.py b/jaxkern/stationary.py deleted file mode 100644 index c3f6e42..0000000 --- a/jaxkern/stationary.py +++ /dev/null @@ -1,304 +0,0 @@ -from typing import Dict, List, Optional - -import jax -import jax.numpy as jnp -from jax.random import KeyArray -from jaxtyping import Array, Float - -from .base import AbstractKernel -from .computations import ( - AbstractKernelComputation, - DenseKernelComputation, -) -from .utils import euclidean_distance, squared_distance - - -class RBF(AbstractKernel): - """The Radial Basis Function (RBF) kernel.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Radial basis function kernel", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( \\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - params = { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params) - - -class Matern12(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 0.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matérn 1/2 kernel", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. 
- y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -class Matern32(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 1.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matern 3/2", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(3.0) * tau) - * jnp.exp(-jnp.sqrt(3.0) * tau) - ) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -class Matern52(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 2.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matern 5/2", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) - * jnp.exp(-jnp.sqrt(5.0) * tau) - ) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -class PoweredExponential(AbstractKernel): - """The powered exponential family of kernels. - - Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". - - """ - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Powered exponential", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. - - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp( - -euclidean_distance(x, y) ** params["power"] - ) - return K.squeeze() - - def init_params(self, key: KeyArray) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "power": jnp.array([1.0]), - } - - -class RationalQuadratic(AbstractKernel): - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Rational Quadratic", - ) -> None: - super().__init__( - compute_engine, active_dims, stationary, spectral, name - ) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. 
-    Returns:
-        Array: The value of :math:`k(x, y)`
-        """
-        x = self.slice_input(x) / params["lengthscale"]
-        y = self.slice_input(y) / params["lengthscale"]
-        K = params["variance"] * (
-            1 + 0.5 * squared_distance(x, y) / params["alpha"]
-        ) ** (-params["alpha"])
-        return K.squeeze()
-
-    def init_params(self, key: KeyArray) -> dict:
-        return {
-            "lengthscale": jnp.array([1.0] * self.ndims),
-            "variance": jnp.array([1.0]),
-            "alpha": jnp.array([1.0]),
-        }
-
-
-__all__ = [
-    "RBF",
-    "Matern12",
-    "Matern32",
-    "Matern52",
-    "RationalQuadratic",
-    "PoweredExponential",
-]
diff --git a/jaxkern/stationary/__init__.py b/jaxkern/stationary/__init__.py
new file mode 100644
index 0000000..6ba568c
--- /dev/null
+++ b/jaxkern/stationary/__init__.py
@@ -0,0 +1,17 @@
+from .matern12 import Matern12
+from .matern32 import Matern32
+from .matern52 import Matern52
+from .periodic import Periodic
+from .powered_exponential import PoweredExponential
+from .rational_quadratic import RationalQuadratic
+from .rbf import RBF
+
+__all__ = [
+    "Matern12",
+    "Matern32",
+    "Matern52",
+    "Periodic",
+    "PoweredExponential",
+    "RationalQuadratic",
+    "RBF",
+]
diff --git a/jaxkern/stationary/matern12.py b/jaxkern/stationary/matern12.py
new file mode 100644
index 0000000..ec99c05
--- /dev/null
+++ b/jaxkern/stationary/matern12.py
@@ -0,0 +1,56 @@
+from typing import Dict, List, Optional
+
+import jax.numpy as jnp
+from jax.random import KeyArray
+from jaxtyping import Array, Float
+
+from ..base import AbstractKernel
+from ..computations import (
+    AbstractKernelComputation,
+    DenseKernelComputation,
+)
+from .utils import euclidean_distance
+
+
+class Matern12(AbstractKernel):
+    """The Matérn kernel with smoothness parameter fixed at 0.5."""
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Matérn 1/2 kernel",
+    ) -> None:
+        super().__init__(compute_engine, active_dims, stationary, spectral, name)
+
+    def __call__(
+        self,
+        params: Dict,
+        x: Float[Array, "1 D"],
+        y: Float[Array, "1 D"],
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with
+        lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`
+
+        .. math::
+            k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{\\ell} \\Bigg)
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`
+        """
+        x = self.slice_input(x) / params["lengthscale"]
+        y = self.slice_input(y) / params["lengthscale"]
+        K = params["variance"] * jnp.exp(-euclidean_distance(x, y))
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        return {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+        }
diff --git a/jaxkern/stationary/matern32.py b/jaxkern/stationary/matern32.py
new file mode 100644
index 0000000..bfc4aa3
--- /dev/null
+++ b/jaxkern/stationary/matern32.py
@@ -0,0 +1,62 @@
+from typing import Dict, List, Optional
+
+import jax.numpy as jnp
+from jax.random import KeyArray
+from jaxtyping import Array, Float
+
+from ..base import AbstractKernel
+from ..computations import (
+    AbstractKernelComputation,
+    DenseKernelComputation,
+)
+from .utils import euclidean_distance
+
+
+class Matern32(AbstractKernel):
+    """The Matérn kernel with smoothness parameter fixed at 1.5."""
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Matern 3/2",
+    ) -> None:
+        super().__init__(compute_engine, active_dims, stationary, spectral, name)
+
+    def __call__(
+        self,
+        params: Dict,
+        x: Float[Array, "1 D"],
+        y: Float[Array, "1 D"],
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with
+        lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`
+
+        .. math::
+            k(x, y) = \\sigma^2 \\Bigg( 1 + \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell} \\Bigg)
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
+ """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + tau = euclidean_distance(x, y) + K = ( + params["variance"] + * (1.0 + jnp.sqrt(3.0) * tau) + * jnp.exp(-jnp.sqrt(3.0) * tau) + ) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } diff --git a/jaxkern/stationary/matern52.py b/jaxkern/stationary/matern52.py new file mode 100644 index 0000000..73204cf --- /dev/null +++ b/jaxkern/stationary/matern52.py @@ -0,0 +1,59 @@ +from typing import Dict, List, Optional + +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array, Float + +from ..base import AbstractKernel +from ..computations import ( + AbstractKernelComputation, + DenseKernelComputation, +) +from .utils import euclidean_distance + + +class Matern52(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 2.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matern 5/2", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + tau = euclidean_distance(x, y) + K = ( + params["variance"] + * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) + * jnp.exp(-jnp.sqrt(5.0) * tau) + ) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } diff --git a/jaxkern/stationary/periodic.py b/jaxkern/stationary/periodic.py new file mode 100644 index 0000000..4637dc1 --- /dev/null +++ b/jaxkern/stationary/periodic.py @@ -0,0 +1,57 @@ +from typing import Dict, List, Optional + +import jax +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array + +from ..base import AbstractKernel +from ..computations import ( + AbstractKernelComputation, + DenseKernelComputation, +) + + +class Periodic(AbstractKernel): + """The periodic kernel. + + Key reference is MacKay 1998 - "Introduction to Gaussian processes". 
+ """ + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Periodic", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell` and variance :math:`\\sigma` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. + Returns: + Array: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) + y = self.slice_input(y) + sine_squared = ( + jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"] + ) ** 2 + K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) + return K.squeeze() + + def init_params(self, key: KeyArray) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + "period": jnp.array([1.0] * self.ndims), + } diff --git a/jaxkern/stationary/powered_exponential.py b/jaxkern/stationary/powered_exponential.py new file mode 100644 index 0000000..55295d3 --- /dev/null +++ b/jaxkern/stationary/powered_exponential.py @@ -0,0 +1,57 @@ +from typing import Dict, List, Optional + +import jax +import jax.numpy as jnp +from jax.random import KeyArray +from jaxtyping import Array + +from ..base import AbstractKernel +from ..computations import ( + AbstractKernelComputation, + DenseKernelComputation, +) +from .utils import euclidean_distance + + +class PoweredExponential(AbstractKernel): + """The powered exponential family of kernels. + + Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". + + """ + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Powered exponential", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\\ell`, :math:`\\sigma` and power :math:`\\kappa`. + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. 
+
+        Returns:
+            Array: The value of :math:`k(x, y)`.
+        """
+        x = self.slice_input(x) / params["lengthscale"]
+        y = self.slice_input(y) / params["lengthscale"]
+        K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"])
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        return {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+            "power": jnp.array([1.0]),
+        }
diff --git a/jaxkern/stationary/rational_quadratic.py b/jaxkern/stationary/rational_quadratic.py
new file mode 100644
index 0000000..dcd12f0
--- /dev/null
+++ b/jaxkern/stationary/rational_quadratic.py
@@ -0,0 +1,54 @@
+from typing import List, Optional
+
+import jax
+import jax.numpy as jnp
+from jax.random import KeyArray
+from jaxtyping import Array
+
+from ..base import AbstractKernel
+from ..computations import (
+    AbstractKernelComputation,
+    DenseKernelComputation,
+)
+from .utils import squared_distance
+
+
+class RationalQuadratic(AbstractKernel):
+    """The rational quadratic kernel."""
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Rational Quadratic",
+    ) -> None:
+        super().__init__(compute_engine, active_dims, stationary, spectral, name)
+
+    def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale
+        parameter :math:`\\ell`, variance :math:`\\sigma^2` and shape parameter :math:`\\alpha`.
+
+        .. math::
+            k(x, y) = \\sigma^2 \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg)^{-\\alpha}
+
+        Args:
+            x (jax.Array): The left hand argument of the kernel function's call.
+            y (jax.Array): The right hand argument of the kernel function's call.
+            params (dict): Parameter set for which the kernel should be evaluated on.
+
+        Returns:
+            Array: The value of :math:`k(x, y)`.
+        """
+        x = self.slice_input(x) / params["lengthscale"]
+        y = self.slice_input(y) / params["lengthscale"]
+        K = params["variance"] * (
+            1 + 0.5 * squared_distance(x, y) / params["alpha"]
+        ) ** (-params["alpha"])
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> dict:
+        return {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+            "alpha": jnp.array([1.0]),
+        }
diff --git a/jaxkern/stationary/rbf.py b/jaxkern/stationary/rbf.py
new file mode 100644
index 0000000..6e72258
--- /dev/null
+++ b/jaxkern/stationary/rbf.py
@@ -0,0 +1,56 @@
+from typing import Dict, List, Optional
+
+import jax
+import jax.numpy as jnp
+from jax.random import KeyArray
+from jaxtyping import Array, Float
+
+from ..base import AbstractKernel
+from ..computations import (
+    AbstractKernelComputation,
+    DenseKernelComputation,
+)
+from .utils import squared_distance
+
+
+class RBF(AbstractKernel):
+    """The Radial Basis Function (RBF) kernel."""
+
+    def __init__(
+        self,
+        compute_engine: AbstractKernelComputation = DenseKernelComputation,
+        active_dims: Optional[List[int]] = None,
+        stationary: Optional[bool] = False,
+        spectral: Optional[bool] = False,
+        name: Optional[str] = "Radial basis function kernel",
+    ) -> None:
+        super().__init__(compute_engine, active_dims, stationary, spectral, name)
+
+    def __call__(
+        self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
+    ) -> Float[Array, "1"]:
+        """Evaluate the kernel on a pair of inputs :math:`(x, y)` with
+        lengthscale parameter :math:`\\ell` and variance :math:`\\sigma^2`.
+
+        .. math::
+            k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg)
+
+        Args:
+            params (Dict): Parameter set for which the kernel should be evaluated on.
+            x (Float[Array, "1 D"]): The left hand argument of the kernel function's call.
+            y (Float[Array, "1 D"]): The right hand argument of the kernel function's call.
+
+        Returns:
+            Float[Array, "1"]: The value of :math:`k(x, y)`.
+        """
+        x = self.slice_input(x) / params["lengthscale"]
+        y = self.slice_input(y) / params["lengthscale"]
+        K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y))
+        return K.squeeze()
+
+    def init_params(self, key: KeyArray) -> Dict:
+        params = {
+            "lengthscale": jnp.array([1.0] * self.ndims),
+            "variance": jnp.array([1.0]),
+        }
+        return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params)
diff --git a/jaxkern/utils.py b/jaxkern/stationary/utils.py
similarity index 85%
rename from jaxkern/utils.py
rename to jaxkern/stationary/utils.py
index 06fbe6c..8d2a8c0 100644
--- a/jaxkern/utils.py
+++ b/jaxkern/stationary/utils.py
@@ -32,8 +32,3 @@ def euclidean_distance(
     """
     return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36))
-
-
-def jax_gather_nd(params, indices):
-    tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
-    return params[tuple_indices]
diff --git a/setup.cfg b/setup.cfg
index a428b8e..1d85afc 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -6,6 +6,11 @@ versionfile_build = jaxkern/_version.py
 tag_prefix = v
 
 [flake8]
+max-line-length = 88
+max-complexity = 12
+ignore = E501
+select = C,E,F,W,B,B9
+extend-ignore = E203, W503, F722, F821
 exclude =
     versioneer.py
     jaxkern/_version.py
diff --git a/tests/test_nonstationary.py b/tests/test_nonstationary.py
index 50b44c6..833089b 100644
--- a/tests/test_nonstationary.py
+++ b/tests/test_nonstationary.py
@@ -15,24 +15,16 @@
 
 from itertools import permutations
-from typing import Dict, List
 
-import jax
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 from jax.config import config
-from jax.random import KeyArray
 from jaxlinop import LinearOperator, identity
-from jaxtyping import Array, Float
 from jaxutils.parameters import initialise
 
 from jaxkern.base import AbstractKernel
-from jaxkern.nonstationary import (
-    Linear,
-    Periodic,
-    Polynomial,
-)
+from jaxkern.nonstationary import Linear, Polynomial, White
 
 # Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True) @@ -115,40 +107,12 @@ def test_pos_def( assert (eigen_values > 0.0).all() -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize( - "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)] -) -@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_periodic( - dim: int, ell: float, sigma: float, period: float, n: int -) -> None: - kern = Periodic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "period": jnp.array([period]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - @pytest.mark.parametrize( "kernel", [ Linear, Polynomial, - Periodic, + White, ], ) def test_dtype(kernel: AbstractKernel) -> None: diff --git a/tests/test_stationary.py b/tests/test_stationary.py index 676fbde..d5b0788 100644 --- a/tests/test_stationary.py +++ b/tests/test_stationary.py @@ -15,16 +15,13 @@ from itertools import permutations -from typing import Dict, List import jax import jax.numpy as jnp import jax.random as jr import pytest from jax.config import config -from jax.random import KeyArray from jaxlinop import LinearOperator, identity -from jaxtyping import Array, Float from jaxutils.parameters import initialise from jaxkern.base import AbstractKernel @@ -35,8 +32,8 @@ Matern52, PoweredExponential, RationalQuadratic, + Periodic, ) -from jaxkern.utils import euclidean_distance # Enable Float64 for more stable matrix inversions. 
config.update("jax_enable_x64", True) @@ -122,9 +119,7 @@ def test_call(kernel: AbstractKernel, dim: int) -> None: @pytest.mark.parametrize("kern", [RBF, Matern12, Matern32, Matern52]) @pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize( - "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)] -) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) @pytest.mark.parametrize("n", [1, 2, 5]) def test_pos_def( kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int @@ -143,14 +138,10 @@ def test_pos_def( @pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize( - "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)] -) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) @pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_rq( - dim: int, ell: float, sigma: float, alpha: float, n: int -) -> None: +def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: kern = RationalQuadratic(active_dims=list(range(dim))) # Gram constructor static method: kern.gram @@ -171,9 +162,33 @@ def test_pos_def_rq( @pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize( - "ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)] -) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_periodic( + dim: int, ell: float, sigma: float, period: float, n: int +) -> None: + kern = Periodic(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = { + "lengthscale": jnp.array([ell]), + "variance": jnp.array([sigma]), + "period": jnp.array([period]), + } + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) @pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("n", [1, 2, 5]) def test_pos_def_power_exp( @@ -228,6 +243,7 @@ def test_initialisation(kernel: AbstractKernel, dim: int) -> None: Matern32, Matern52, RationalQuadratic, + Periodic, PoweredExponential, ], ) diff --git a/tests/test_utils.py b/tests/test_utils.py index ed316b4..4765368 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import pytest from jaxtyping import Array, Float -from jaxkern.utils import euclidean_distance +from jaxkern.stationary.utils import euclidean_distance @pytest.mark.parametrize( From 56b4e311e966c54b5297ee1a30cf6765b5697783 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Wed, 18 Jan 2023 08:48:10 +0000 Subject: [PATCH 6/8] Fix graph kernel init --- jaxkern/non_euclidean/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/jaxkern/non_euclidean/__init__.py b/jaxkern/non_euclidean/__init__.py index e69de29..f207764 100644 --- a/jaxkern/non_euclidean/__init__.py +++ b/jaxkern/non_euclidean/__init__.py @@ -0,0 +1,3 @@ +from .graph import GraphKernel + +__all__ = ["GraphKernel"] From 1478ab7439b341776b30410b4839e986a839f49f Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 19 Jan 2023 17:35:28 +0000 
Subject: [PATCH 7/8] Resolve comments

---
 jaxkern/non_euclidean/graph.py             | 22 ++++++++++++++++---
 jaxkern/non_euclidean/utils.py             | 16 +++++++++++++-
 jaxkern/nonstationary/__init__.py          |  2 +-
 .../{nonstationary => stationary}/white.py | 14 +++++++++---
 4 files changed, 46 insertions(+), 8 deletions(-)
 rename jaxkern/{nonstationary => stationary}/white.py (73%)

diff --git a/jaxkern/non_euclidean/graph.py b/jaxkern/non_euclidean/graph.py
index 3167263..99f409b 100644
--- a/jaxkern/non_euclidean/graph.py
+++ b/jaxkern/non_euclidean/graph.py
@@ -13,16 +13,26 @@
 # Graph kernels
 ##########################################
 class GraphKernel(AbstractKernel):
+    """A Matérn graph kernel defined on the vertices of a graph. The key reference for this object is Borovitskiy et al. (2020)."""
+
     def __init__(
         self,
         laplacian: Float[Array, "N N"],
         compute_engine: EigenKernelComputation = EigenKernelComputation,
         active_dims: Optional[List[int]] = None,
-        stationary: Optional[bool] = False,
         spectral: Optional[bool] = False,
-        name: Optional[str] = "Graph kernel",
+        name: Optional[str] = "Matérn Graph kernel",
     ) -> None:
-        super().__init__(compute_engine, active_dims, stationary, spectral, name)
+        """Initialize a Matérn graph kernel.
+
+        Args:
+            laplacian (Float[Array]): An N x N matrix representing the Laplacian matrix of a graph.
+            compute_engine (EigenKernelComputation, optional): The compute engine that should be used in the kernel to compute covariance matrices. Defaults to EigenKernelComputation.
+            active_dims (Optional[List[int]], optional): The dimensions of the input data for which the kernel should be evaluated on. Defaults to None.
+            spectral (Optional[bool], optional): Whether the kernel should be treated as a spectral kernel. Defaults to False.
+            name (Optional[str], optional): The name of the kernel. Defaults to "Matérn Graph kernel".
+        """
+        super().__init__(compute_engine, active_dims, True, spectral, name)
         self.laplacian = laplacian
         evals, self.evecs = jnp.linalg.eigh(self.laplacian)
         self.evals = evals.reshape(-1, 1)
@@ -53,6 +63,7 @@ def __call__(
         return Kxx.squeeze()
 
     def init_params(self, key: KeyArray) -> Dict:
+        """Initialise the lengthscale, variance and smoothness parameters of the kernel."""
         return {
             "lengthscale": jnp.array([1.0] * self.ndims),
             "variance": jnp.array([1.0]),
@@ -61,4 +72,9 @@ def init_params(self, key: KeyArray) -> Dict:
 
     @property
     def num_vertex(self) -> int:
+        """The number of vertices within the graph.
+
+        Returns:
+            int: An integer representing the number of vertices within the graph.
+        """
         return self.compute_engine.num_vertex
diff --git a/jaxkern/non_euclidean/utils.py b/jaxkern/non_euclidean/utils.py
index 0e8d16c..a6ad0a3 100644
--- a/jaxkern/non_euclidean/utils.py
+++ b/jaxkern/non_euclidean/utils.py
@@ -1,3 +1,25 @@
-def jax_gather_nd(params, indices):
+from jaxtyping import Num, Array, Int
+
+
+def jax_gather_nd(
+    params: Num[Array, "N ..."], indices: Int[Array, "M"]
+) -> Num[Array, "M ..."]:
+    """Slice a `params` array at a set of `indices`.
+
+    Args:
+        params (Num[Array]): An arbitrary array with leading axes of length `N` upon which we shall slice.
+        indices (Int[Array]): An integer array of length M with values in the range [0, N) whose value at index `i` will be used to slice `params` at index `i`.
+
+    Returns:
+        Num[Array]: An arbitrary array with leading axes of length `M`.
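+
+    A worked sketch (the values are illustrative)::
+
+        import jax.numpy as jnp
+
+        params = jnp.array([10.0, 11.0, 12.0])
+        indices = jnp.array([[2], [0]])
+        jax_gather_nd(params, indices)  # -> Array([12., 10.])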
+ """ tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) return params[tuple_indices] diff --git a/jaxkern/nonstationary/__init__.py b/jaxkern/nonstationary/__init__.py index 20ce75e..4304c1c 100644 --- a/jaxkern/nonstationary/__init__.py +++ b/jaxkern/nonstationary/__init__.py @@ -1,5 +1,5 @@ from .linear import Linear from .polynomial import Polynomial -from .white import White +from ..stationary.white import White __all__ = ["Linear", "Polynomial", "White"] diff --git a/jaxkern/nonstationary/white.py b/jaxkern/stationary/white.py similarity index 73% rename from jaxkern/nonstationary/white.py rename to jaxkern/stationary/white.py index 5e0e330..a338f14 100644 --- a/jaxkern/nonstationary/white.py +++ b/jaxkern/stationary/white.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional, List import jax.numpy as jnp from jaxtyping import Array, Float @@ -6,12 +6,20 @@ from ..base import AbstractKernel from ..computations import ( ConstantDiagonalKernelComputation, + AbstractKernelComputation, ) class White(AbstractKernel, ConstantDiagonalKernelComputation): - def __post_init__(self) -> None: - super(White, self).__post_init__() + def __init__( + self, + compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "White Noise Kernel", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) def __call__( self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] From 2ccf0f240909067004aff478aa2d99c8238ae44b Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 19 Jan 2023 18:15:12 +0000 Subject: [PATCH 8/8] Address White noise issues, import locations, testing and add disclaimers to files. --- jaxkern/__init__.py | 18 +++++++++++- jaxkern/base.py | 15 ++++++++++ jaxkern/computations/__init__.py | 15 ++++++++++ jaxkern/computations/base.py | 15 ++++++++++ jaxkern/computations/constant_diagonal.py | 35 +++++++++++++++++++++-- jaxkern/computations/dense.py | 15 ++++++++++ jaxkern/computations/diagonal.py | 32 ++++++++++++++++++++- jaxkern/computations/eigen.py | 15 ++++++++++ jaxkern/non_euclidean/__init__.py | 15 ++++++++++ jaxkern/non_euclidean/graph.py | 15 ++++++++++ jaxkern/non_euclidean/utils.py | 15 ++++++++++ jaxkern/nonstationary/__init__.py | 18 ++++++++++-- jaxkern/nonstationary/linear.py | 15 ++++++++++ jaxkern/nonstationary/polynomial.py | 15 ++++++++++ jaxkern/stationary/__init__.py | 17 +++++++++++ jaxkern/stationary/matern12.py | 15 ++++++++++ jaxkern/stationary/matern32.py | 15 ++++++++++ jaxkern/stationary/matern52.py | 15 ++++++++++ jaxkern/stationary/periodic.py | 15 ++++++++++ jaxkern/stationary/powered_exponential.py | 15 ++++++++++ jaxkern/stationary/rational_quadratic.py | 15 ++++++++++ jaxkern/stationary/rbf.py | 15 ++++++++++ jaxkern/stationary/utils.py | 15 ++++++++++ jaxkern/stationary/white.py | 19 ++++++++++-- tests/__init__.py | 17 ++++++++++- tests/test_base.py | 15 ++++++++++ tests/test_non_euclidean.py | 15 ++++++++++ tests/test_nonstationary.py | 6 ++-- tests/test_stationary.py | 7 +++-- tests/test_utils.py | 15 ++++++++++ 30 files changed, 469 insertions(+), 15 deletions(-) diff --git a/jaxkern/__init__.py b/jaxkern/__init__.py index 554ff0e..e3c0207 100644 --- a/jaxkern/__init__.py +++ b/jaxkern/__init__.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + """JaxKern.""" from .base import ProductKernel, SumKernel from .computations import ( @@ -6,7 +21,7 @@ DiagonalKernelComputation, EigenKernelComputation, ) -from .nonstationary import Linear, Polynomial, White +from .nonstationary import Linear, Polynomial from .stationary import ( RBF, Matern12, @@ -15,6 +30,7 @@ RationalQuadratic, Periodic, PoweredExponential, + White, ) from .non_euclidean import GraphKernel diff --git a/jaxkern/base.py b/jaxkern/base.py index 79a215d..20b2fbd 100644 --- a/jaxkern/base.py +++ b/jaxkern/base.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import abc from typing import Callable, Dict, List, Optional, Sequence diff --git a/jaxkern/computations/__init__.py b/jaxkern/computations/__init__.py index 7bf1605..7d44edc 100644 --- a/jaxkern/computations/__init__.py +++ b/jaxkern/computations/__init__.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from .base import AbstractKernelComputation from .constant_diagonal import ConstantDiagonalKernelComputation from .dense import DenseKernelComputation diff --git a/jaxkern/computations/base.py b/jaxkern/computations/base.py index 9984778..4112665 100644 --- a/jaxkern/computations/base.py +++ b/jaxkern/computations/base.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 import abc
 from typing import Callable, Dict
diff --git a/jaxkern/computations/constant_diagonal.py b/jaxkern/computations/constant_diagonal.py
index f164565..7304571 100644
--- a/jaxkern/computations/constant_diagonal.py
+++ b/jaxkern/computations/constant_diagonal.py
@@ -1,4 +1,20 @@
+# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Callable, Dict
+import jax.numpy as jnp
 from jax import vmap
 
 from jaxlinop import (
@@ -38,7 +54,7 @@ def gram(
 
         value = self.kernel_fn(params, inputs[0], inputs[0])
 
-        return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0])
+        return ConstantDiagonalLinearOperator(value=jnp.atleast_1d(value), size=inputs.shape[0])
 
     def diagonal(
         self,
@@ -65,4 +81,17 @@ def diagonal(
     def cross_covariance(
         self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]
     ) -> Float[Array, "N M"]:
-        raise ValueError("Cross covariance not defined for constant diagonal kernels.")
+        """Compute the NxM cross-covariance matrix for a pair of input
+        matrices of shape NxD and MxD.
+
+        Args:
+            params (Dict): The kernel's parameter set.
+            x (Float[Array, "N D"]): The first input matrix.
+            y (Float[Array, "M D"]): The second input matrix.
+
+        Returns:
+            Float[Array, "N M"]: The computed cross-covariance matrix.
+        """
+        # TODO: This is currently a dense implementation. We should implement a sparse LinearOperator for non-square cross-covariance matrices.
+        cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x)
+        return cross_cov
diff --git a/jaxkern/computations/dense.py b/jaxkern/computations/dense.py
index c1f1a7c..7fb74b5 100644
--- a/jaxkern/computations/dense.py
+++ b/jaxkern/computations/dense.py
@@ -1,3 +1,18 @@
+# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Callable, Dict
 
 from jax import vmap
diff --git a/jaxkern/computations/diagonal.py b/jaxkern/computations/diagonal.py
index de83b7f..88cab6f 100644
--- a/jaxkern/computations/diagonal.py
+++ b/jaxkern/computations/diagonal.py
@@ -1,3 +1,18 @@
+# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Callable, Dict
 
 from jax import vmap
@@ -42,4 +57,17 @@ def gram(
     def cross_covariance(
         self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]
     ) -> Float[Array, "N M"]:
-        raise ValueError("Cross covariance not defined for diagonal kernels.")
+        """Compute the NxM cross-covariance matrix for a pair of input
+        matrices of shape NxD and MxD.
+
+        Args:
+            params (Dict): The kernel's parameter set.
+            x (Float[Array, "N D"]): The first input matrix.
+            y (Float[Array, "M D"]): The second input matrix.
+
+        Returns:
+            Float[Array, "N M"]: The computed cross-covariance matrix.
+        """
+        # TODO: This is currently a dense implementation. We should implement a sparse LinearOperator for non-square cross-covariance matrices.
+        cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x)
+        return cross_cov
diff --git a/jaxkern/computations/eigen.py b/jaxkern/computations/eigen.py
index cb8d4f8..094fc8d 100644
--- a/jaxkern/computations/eigen.py
+++ b/jaxkern/computations/eigen.py
@@ -1,3 +1,18 @@
+# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 from typing import Callable, Dict
 
 import jax.numpy as jnp
diff --git a/jaxkern/non_euclidean/__init__.py b/jaxkern/non_euclidean/__init__.py
index f207764..c696f39 100644
--- a/jaxkern/non_euclidean/__init__.py
+++ b/jaxkern/non_euclidean/__init__.py
@@ -1,3 +1,18 @@
+# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from .graph import GraphKernel __all__ = ["GraphKernel"] diff --git a/jaxkern/non_euclidean/graph.py b/jaxkern/non_euclidean/graph.py index 99f409b..cee9968 100644 --- a/jaxkern/non_euclidean/graph.py +++ b/jaxkern/non_euclidean/graph.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax.numpy as jnp diff --git a/jaxkern/non_euclidean/utils.py b/jaxkern/non_euclidean/utils.py index a6ad0a3..b7c28e9 100644 --- a/jaxkern/non_euclidean/utils.py +++ b/jaxkern/non_euclidean/utils.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from jaxtyping import Num, Array, Int diff --git a/jaxkern/nonstationary/__init__.py b/jaxkern/nonstationary/__init__.py index 4304c1c..eeaa990 100644 --- a/jaxkern/nonstationary/__init__.py +++ b/jaxkern/nonstationary/__init__.py @@ -1,5 +1,19 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + from .linear import Linear from .polynomial import Polynomial -from ..stationary.white import White -__all__ = ["Linear", "Polynomial", "White"] +__all__ = ["Linear", "Polynomial"] diff --git a/jaxkern/nonstationary/linear.py b/jaxkern/nonstationary/linear.py index 3795420..ed2270b 100644 --- a/jaxkern/nonstationary/linear.py +++ b/jaxkern/nonstationary/linear.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax diff --git a/jaxkern/nonstationary/polynomial.py b/jaxkern/nonstationary/polynomial.py index 18210d9..3ca6241 100644 --- a/jaxkern/nonstationary/polynomial.py +++ b/jaxkern/nonstationary/polynomial.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax.numpy as jnp diff --git a/jaxkern/stationary/__init__.py b/jaxkern/stationary/__init__.py index 6ba568c..fe48786 100644 --- a/jaxkern/stationary/__init__.py +++ b/jaxkern/stationary/__init__.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + from .matern12 import Matern12 from .matern32 import Matern32 from .matern52 import Matern52 @@ -5,6 +20,7 @@ from .powered_exponential import PoweredExponential from .rational_quadratic import RationalQuadratic from .rbf import RBF +from .white import White __all__ = [ "Matern12", @@ -14,4 +30,5 @@ "PoweredExponential", "RationalQuadratic", "RBF", + "White", ] diff --git a/jaxkern/stationary/matern12.py b/jaxkern/stationary/matern12.py index ec99c05..7244e74 100644 --- a/jaxkern/stationary/matern12.py +++ b/jaxkern/stationary/matern12.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax.numpy as jnp diff --git a/jaxkern/stationary/matern32.py b/jaxkern/stationary/matern32.py index bfc4aa3..d927fbf 100644 --- a/jaxkern/stationary/matern32.py +++ b/jaxkern/stationary/matern32.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax.numpy as jnp diff --git a/jaxkern/stationary/matern52.py b/jaxkern/stationary/matern52.py index 73204cf..c130335 100644 --- a/jaxkern/stationary/matern52.py +++ b/jaxkern/stationary/matern52.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + from typing import Dict, List, Optional import jax.numpy as jnp diff --git a/jaxkern/stationary/periodic.py b/jaxkern/stationary/periodic.py index 4637dc1..b09b08d 100644 --- a/jaxkern/stationary/periodic.py +++ b/jaxkern/stationary/periodic.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax diff --git a/jaxkern/stationary/powered_exponential.py b/jaxkern/stationary/powered_exponential.py index 55295d3..d006e43 100644 --- a/jaxkern/stationary/powered_exponential.py +++ b/jaxkern/stationary/powered_exponential.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax diff --git a/jaxkern/stationary/rational_quadratic.py b/jaxkern/stationary/rational_quadratic.py index dcd12f0..19f67b7 100644 --- a/jaxkern/stationary/rational_quadratic.py +++ b/jaxkern/stationary/rational_quadratic.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import List, Optional import jax diff --git a/jaxkern/stationary/rbf.py b/jaxkern/stationary/rbf.py index 6e72258..70f8acd 100644 --- a/jaxkern/stationary/rbf.py +++ b/jaxkern/stationary/rbf.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, List, Optional import jax diff --git a/jaxkern/stationary/utils.py b/jaxkern/stationary/utils.py index 8d2a8c0..a9a0804 100644 --- a/jaxkern/stationary/utils.py +++ b/jaxkern/stationary/utils.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import jax.numpy as jnp from jaxtyping import Array, Float diff --git a/jaxkern/stationary/white.py b/jaxkern/stationary/white.py index a338f14..9c7e8ed 100644 --- a/jaxkern/stationary/white.py +++ b/jaxkern/stationary/white.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Dict, Optional, List import jax.numpy as jnp @@ -10,12 +25,12 @@ ) -class White(AbstractKernel, ConstantDiagonalKernelComputation): +class White(AbstractKernel): def __init__( self, compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation, active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, + stationary: Optional[bool] = True, spectral: Optional[bool] = False, name: Optional[str] = "White Noise Kernel", ) -> None: diff --git a/tests/__init__.py b/tests/__init__.py index 35f311b..ae57fb6 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1,16 @@ -"""Test suite for the jaxkern package.""" +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test suite for the jaxkern package.""" \ No newline at end of file diff --git a/tests/test_base.py b/tests/test_base.py index 5e8c67e..d543c6c 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import jax.numpy as jnp import jax.random as jr import pytest diff --git a/tests/test_non_euclidean.py b/tests/test_non_euclidean.py index cc3df37..8f893b3 100644 --- a/tests/test_non_euclidean.py +++ b/tests/test_non_euclidean.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import jax.numpy as jnp import jax.random as jr import networkx as nx diff --git a/tests/test_nonstationary.py b/tests/test_nonstationary.py index 833089b..efcae4b 100644 --- a/tests/test_nonstationary.py +++ b/tests/test_nonstationary.py @@ -1,4 +1,4 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,6 @@ # limitations under the License. # ============================================================================== - from itertools import permutations import jax.numpy as jnp @@ -24,7 +23,7 @@ from jaxutils.parameters import initialise from jaxkern.base import AbstractKernel -from jaxkern.nonstationary import Linear, Polynomial, White +from jaxkern.nonstationary import Linear, Polynomial # Enable Float64 for more stable matrix inversions. 
config.update("jax_enable_x64", True) @@ -112,7 +111,6 @@ def test_pos_def( [ Linear, Polynomial, - White, ], ) def test_dtype(kernel: AbstractKernel) -> None: diff --git a/tests/test_stationary.py b/tests/test_stationary.py index d5b0788..2f1a07e 100644 --- a/tests/test_stationary.py +++ b/tests/test_stationary.py @@ -1,4 +1,4 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -33,6 +33,7 @@ PoweredExponential, RationalQuadratic, Periodic, + White, ) # Enable Float64 for more stable matrix inversions. @@ -49,6 +50,7 @@ Matern32(), Matern52(), RationalQuadratic(), + White(), ], ) @pytest.mark.parametrize("dim", [1, 2, 5]) @@ -78,6 +80,7 @@ def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: Matern32(), Matern52(), RationalQuadratic(), + White(), ], ) @pytest.mark.parametrize("num_a", [1, 2, 5]) @@ -99,7 +102,7 @@ def test_cross_covariance( assert Kab.shape == (num_a, num_b) -@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) +@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52(), White()]) @pytest.mark.parametrize("dim", [1, 2, 5]) def test_call(kernel: AbstractKernel, dim: int) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 4765368..fe564cf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,18 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import List import jax.numpy as jnp