Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
Resolve comments
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaspinder committed Jan 19, 2023
1 parent 56b4e31 commit 1478ab7
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 8 deletions.
22 changes: 19 additions & 3 deletions jaxkern/non_euclidean/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,26 @@
# Graph kernels
##########################################
class GraphKernel(AbstractKernel):
"""A Matérn graph kernel defined on the vertices of a graph. The key reference for this object is borovitskiy et. al., (2020)."""

def __init__(
self,
laplacian: Float[Array, "N N"],
compute_engine: EigenKernelComputation = EigenKernelComputation,
active_dims: Optional[List[int]] = None,
stationary: Optional[bool] = False,
spectral: Optional[bool] = False,
name: Optional[str] = "Graph kernel",
name: Optional[str] = "Matérn Graph kernel",
) -> None:
super().__init__(compute_engine, active_dims, stationary, spectral, name)
"""Initialize a Matérn graph kernel.
Args:
laplacian (Float[Array]): An N x N matrix representing the Laplacian matrix of a graph.
compute_engine (EigenKernelComputation, optional): The compute engine that should be used in the kernel to compute covariance matrices. Defaults to EigenKernelComputation.
active_dims (Optional[List[int]], optional): The dimensions of the input data for which the kernel should be evaluated on. Defaults to None.
stationary (Optional[bool], optional): _description_. Defaults to False.
name (Optional[str], optional): _description_. Defaults to "Graph kernel".
"""
super().__init__(compute_engine, active_dims, True, spectral, name)
self.laplacian = laplacian
evals, self.evecs = jnp.linalg.eigh(self.laplacian)
self.evals = evals.reshape(-1, 1)
Expand Down Expand Up @@ -53,6 +63,7 @@ def __call__(
return Kxx.squeeze()

def init_params(self, key: KeyArray) -> Dict:
"""Initialise the lengthscale, variance and smoothness parameters of the kernel"""
return {
"lengthscale": jnp.array([1.0] * self.ndims),
"variance": jnp.array([1.0]),
Expand All @@ -61,4 +72,9 @@ def init_params(self, key: KeyArray) -> Dict:

@property
def num_vertex(self) -> int:
"""The number of vertices within the graph.
Returns:
int: An integer representing the number of vertices within the graph.
"""
return self.compute_engine.num_vertex
16 changes: 15 additions & 1 deletion jaxkern/non_euclidean/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
def jax_gather_nd(params, indices):
from jaxtyping import Num, Array, Int


def jax_gather_nd(
params: Num[Array, "N ..."], indices: Int[Array, "M"]
) -> Num[Array, "M ..."]:
"""Slice a `params` array at a set of `indices`.
Args:
params (Num[Array]): An arbitrary array with leading axes of length `N` upon which we shall slice.
indices (Float[Int]): An integer array of length M with values in the range [0, N) whose value at index `i` will be used to slice `params` at index `i`.
Returns:
Num[Array: An arbitrary array with leading axes of length `M`.
"""
tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
return params[tuple_indices]
2 changes: 1 addition & 1 deletion jaxkern/nonstationary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .linear import Linear
from .polynomial import Polynomial
from .white import White
from ..stationary.white import White

__all__ = ["Linear", "Polynomial", "White"]
14 changes: 11 additions & 3 deletions jaxkern/nonstationary/white.py → jaxkern/stationary/white.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from typing import Dict
from typing import Dict, Optional, List

import jax.numpy as jnp
from jaxtyping import Array, Float

from ..base import AbstractKernel
from ..computations import (
ConstantDiagonalKernelComputation,
AbstractKernelComputation,
)


class White(AbstractKernel, ConstantDiagonalKernelComputation):
def __post_init__(self) -> None:
super(White, self).__post_init__()
def __init__(
self,
compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation,
active_dims: Optional[List[int]] = None,
stationary: Optional[bool] = False,
spectral: Optional[bool] = False,
name: Optional[str] = "White Noise Kernel",
) -> None:
super().__init__(compute_engine, active_dims, stationary, spectral, name)

def __call__(
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
Expand Down

0 comments on commit 1478ab7

Please sign in to comment.