Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
Address White noise issues, import locations, testing and add disclai…
Browse files Browse the repository at this point in the history
…mers to files.
  • Loading branch information
daniel-dodd committed Jan 19, 2023
1 parent 1478ab7 commit 2ccf0f2
Show file tree
Hide file tree
Showing 30 changed files with 469 additions and 15 deletions.
18 changes: 17 additions & 1 deletion jaxkern/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""JaxKern."""
from .base import ProductKernel, SumKernel
from .computations import (
Expand All @@ -6,7 +21,7 @@
DiagonalKernelComputation,
EigenKernelComputation,
)
from .nonstationary import Linear, Polynomial, White
from .nonstationary import Linear, Polynomial
from .stationary import (
RBF,
Matern12,
Expand All @@ -15,6 +30,7 @@
RationalQuadratic,
Periodic,
PoweredExponential,
White,
)
from .non_euclidean import GraphKernel

Expand Down
15 changes: 15 additions & 0 deletions jaxkern/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import abc
from typing import Callable, Dict, List, Optional, Sequence

Expand Down
15 changes: 15 additions & 0 deletions jaxkern/computations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .base import AbstractKernelComputation
from .constant_diagonal import ConstantDiagonalKernelComputation
from .dense import DenseKernelComputation
Expand Down
15 changes: 15 additions & 0 deletions jaxkern/computations/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import abc
from typing import Callable, Dict

Expand Down
35 changes: 33 additions & 2 deletions jaxkern/computations/constant_diagonal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict
import jax.numpy as jnp

from jax import vmap
from jaxlinop import (
Expand Down Expand Up @@ -38,7 +54,7 @@ def gram(

value = self.kernel_fn(params, inputs[0], inputs[0])

return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0])
return ConstantDiagonalLinearOperator(value = jnp.atleast_1d(value), size=inputs.shape[0])

def diagonal(
self,
Expand All @@ -65,4 +81,19 @@ def diagonal(
def cross_covariance(
self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]
) -> Float[Array, "N M"]:
raise ValueError("Cross covariance not defined for constant diagonal kernels.")
"""For a given kernel, compute the NxM covariance matrix on a pair of input
matrices of shape NxD and MxD.
Args:
kernel (AbstractKernel): The kernel for which the Gram
matrix should be computed for.
params (Dict): The kernel's parameter set.
x (Float[Array,"N D"]): The input matrix.
y (Float[Array,"M D"]): The input matrix.
Returns:
CovarianceOperator: The computed square Gram matrix.
"""
# TODO: This is currently a dense implementation. We should implement a sparse LinearOperator for non-square cross-covariance matrices.
cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x)
return cross_cov
15 changes: 15 additions & 0 deletions jaxkern/computations/dense.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict

from jax import vmap
Expand Down
32 changes: 31 additions & 1 deletion jaxkern/computations/diagonal.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict

from jax import vmap
Expand Down Expand Up @@ -42,4 +57,19 @@ def gram(
def cross_covariance(
self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"]
) -> Float[Array, "N M"]:
raise ValueError("Cross covariance not defined for diagonal kernels.")
"""For a given kernel, compute the NxM covariance matrix on a pair of input
matrices of shape NxD and MxD.
Args:
kernel (AbstractKernel): The kernel for which the Gram
matrix should be computed for.
params (Dict): The kernel's parameter set.
x (Float[Array,"N D"]): The input matrix.
y (Float[Array,"M D"]): The input matrix.
Returns:
CovarianceOperator: The computed square Gram matrix.
"""
# TODO: This is currently a dense implementation. We should implement a sparse LinearOperator for non-square cross-covariance matrices.
cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x)
return cross_cov
15 changes: 15 additions & 0 deletions jaxkern/computations/eigen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict

import jax.numpy as jnp
Expand Down
15 changes: 15 additions & 0 deletions jaxkern/non_euclidean/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .graph import GraphKernel

__all__ = ["GraphKernel"]
15 changes: 15 additions & 0 deletions jaxkern/non_euclidean/graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Optional

import jax.numpy as jnp
Expand Down
15 changes: 15 additions & 0 deletions jaxkern/non_euclidean/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from jaxtyping import Num, Array, Int


Expand Down
18 changes: 16 additions & 2 deletions jaxkern/nonstationary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .linear import Linear
from .polynomial import Polynomial
from ..stationary.white import White

__all__ = ["Linear", "Polynomial", "White"]
__all__ = ["Linear", "Polynomial"]
15 changes: 15 additions & 0 deletions jaxkern/nonstationary/linear.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Optional

import jax
Expand Down
15 changes: 15 additions & 0 deletions jaxkern/nonstationary/polynomial.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Optional

import jax.numpy as jnp
Expand Down
17 changes: 17 additions & 0 deletions jaxkern/stationary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,26 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from .matern12 import Matern12
from .matern32 import Matern32
from .matern52 import Matern52
from .periodic import Periodic
from .powered_exponential import PoweredExponential
from .rational_quadratic import RationalQuadratic
from .rbf import RBF
from .white import White

__all__ = [
"Matern12",
Expand All @@ -14,4 +30,5 @@
"PoweredExponential",
"RationalQuadratic",
"RBF",
"White",
]
15 changes: 15 additions & 0 deletions jaxkern/stationary/matern12.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Dict, List, Optional

import jax.numpy as jnp
Expand Down
Loading

0 comments on commit 2ccf0f2

Please sign in to comment.