-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1224975
commit bdc66ad
Showing
15 changed files
with
686 additions
and
440 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,275 +1,148 @@ | ||
import math | ||
|
||
import gpytorch | ||
import numpy as np | ||
import torch | ||
|
||
|
||
class MultiOutputLatentKernel(gpytorch.kernels.Kernel): | ||
class LatentKernel(gpytorch.kernels.Kernel): | ||
is_stationary = True | ||
|
||
num_outputs = 1 | ||
|
||
def __init__( | ||
self, | ||
num_inputs=1, | ||
num_outputs=1, | ||
off_diag=False, | ||
diag_prior=False, | ||
off_diag=True, | ||
diag_prior=True, | ||
scale_kernel=True, | ||
**kwargs, | ||
): | ||
super(MultiOutputLatentKernel, self).__init__() | ||
super(LatentKernel, self).__init__() | ||
|
||
self.num_inputs = num_inputs | ||
self.num_outputs = num_outputs | ||
self.n_off_diag = int(num_inputs * (num_inputs - 1) / 2) | ||
|
||
self.off_diag = off_diag | ||
self.scale_kernel = scale_kernel | ||
|
||
self.nu = kwargs.get("nu", 1.5) | ||
|
||
# self.batch_shape = torch.Size([num_outputs]) | ||
# kernel_scale_constraint = gpytorch.constraints.Positive() | ||
diag_entries_constraint = gpytorch.constraints.Positive() # gpytorch.constraints.Interval(5e-1, 1e2) | ||
skew_entries_constraint = gpytorch.constraints.Interval(-1e0, 1e0) | ||
|
||
# output_scale_constraint = gpytorch.constraints.Positive() | ||
diag_params_constraint = gpytorch.constraints.Interval(1e0, 1e2) | ||
skew_params_constraint = gpytorch.constraints.Interval(-1e0, 1e0) | ||
|
||
diag_params_initial = np.sqrt(diag_params_constraint.lower_bound * diag_params_constraint.upper_bound) | ||
raw_diag_params_initial = diag_params_constraint.inverse_transform(diag_params_initial) | ||
# diag_entries_initial = np.ones() | ||
# np.sqrt(diag_entries_constraint.lower_bound * diag_entries_constraint.upper_bound) | ||
raw_diag_entries_initial = diag_entries_constraint.inverse_transform(torch.tensor(2)) | ||
|
||
self.register_parameter( | ||
name="raw_diag_params", | ||
name="raw_diag_entries", | ||
parameter=torch.nn.Parameter( | ||
raw_diag_params_initial * torch.ones(self.num_outputs, self.num_inputs).double() | ||
raw_diag_entries_initial * torch.ones(self.num_outputs, self.num_inputs).double() | ||
), | ||
) | ||
|
||
# self.register_constraint("raw_output_scale", output_scale_constraint) | ||
self.register_constraint("raw_diag_params", diag_params_constraint) | ||
self.register_constraint("raw_diag_entries", constraint=diag_entries_constraint) | ||
|
||
if diag_prior: | ||
self.register_prior( | ||
name="diag_params_prior", | ||
prior=gpytorch.priors.GammaPrior(concentration=0.5, rate=0.2), | ||
param_or_closure=lambda m: m.diag_params, | ||
setting_closure=lambda m, v: m._set_diag_params(v), | ||
name="diag_entries_prior", | ||
prior=gpytorch.priors.GammaPrior(concentration=2, rate=1), | ||
param_or_closure=lambda m: m.diag_entries, | ||
setting_closure=lambda m, v: m._set_diag_entries(v), | ||
) | ||
|
||
if self.off_diag: | ||
self.register_parameter( | ||
name="raw_skew_params", | ||
name="raw_skew_entries", | ||
parameter=torch.nn.Parameter(torch.zeros(self.num_outputs, self.n_off_diag).double()), | ||
) | ||
self.register_constraint("raw_skew_params", skew_params_constraint) | ||
|
||
@property | ||
def diag_params(self): | ||
return self.raw_diag_params_constraint.transform(self.raw_diag_params) | ||
|
||
@property | ||
def skew_params(self): | ||
return self.raw_skew_params_constraint.transform(self.raw_skew_params) | ||
|
||
@diag_params.setter | ||
def diag_params(self, value): | ||
self._set_diag_params(value) | ||
|
||
@skew_params.setter | ||
def skew_params(self, value): | ||
self._set_skew_params(value) | ||
|
||
def _set_skew_params(self, value): | ||
if not torch.is_tensor(value): | ||
value = torch.as_tensor(value).to(self.raw_skew_params) | ||
self.initialize(raw_skew_params=self.raw_skew_params_constraint.inverse_transform(value)) | ||
|
||
def _set_diag_params(self, value): | ||
if not torch.is_tensor(value): | ||
value = torch.as_tensor(value).to(self.raw_diag_params) | ||
self.initialize(raw_diag_params=self.raw_diag_params_constraint.inverse_transform(value)) | ||
|
||
@property | ||
def dimension_transform(self): | ||
# no rotations | ||
if not self.off_diag: | ||
T = torch.eye(self.num_inputs, dtype=torch.float64) | ||
|
||
# construct an orthogonal matrix. fun fact: exp(skew(N)) is the generator of SO(N) | ||
else: | ||
A = torch.zeros((self.num_outputs, self.num_inputs, self.num_inputs), dtype=torch.float64) | ||
upper_indices = np.triu_indices(self.num_inputs, k=1) | ||
for output_index in range(self.num_outputs): | ||
A[(output_index, *upper_indices)] = self.skew_params[output_index] | ||
A += -A.transpose(-1, -2) | ||
T = torch.linalg.matrix_exp(A) | ||
|
||
diagonal_transform = torch.cat([torch.diag(_values).unsqueeze(0) for _values in self.diag_params], dim=0) | ||
T = torch.matmul(diagonal_transform, T) | ||
|
||
return T | ||
|
||
def forward(self, x1, x2, diag=False, **params): | ||
# adapted from the Matern kernel | ||
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] | ||
|
||
trans_x1 = torch.matmul(self.dimension_transform.unsqueeze(1), (x1 - mean).unsqueeze(-1)).squeeze(-1) | ||
trans_x2 = torch.matmul(self.dimension_transform.unsqueeze(1), (x2 - mean).unsqueeze(-1)).squeeze(-1) | ||
|
||
distance = self.covar_dist(trans_x1, trans_x2, diag=diag, **params) | ||
|
||
# if distance.shape[0] == 1: | ||
# distance = distance.squeeze(0) # this is extremely necessary | ||
|
||
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance) | ||
|
||
if self.nu == 0.5: | ||
constant_component = 1 | ||
elif self.nu == 1.5: | ||
constant_component = (math.sqrt(3) * distance).add(1) | ||
elif self.nu == 2.5: | ||
constant_component = (math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2) | ||
self.register_constraint("raw_skew_entries", skew_entries_constraint) | ||
|
||
return constant_component * exp_component | ||
if self.scale_kernel: | ||
kernel_scale_constraint = gpytorch.constraints.Positive() | ||
kernel_scale_prior = gpytorch.priors.GammaPrior(concentration=2, rate=0.15) | ||
|
||
self.register_parameter( | ||
name="raw_kernel_scale", | ||
parameter=torch.nn.Parameter(torch.ones(1)), | ||
) | ||
|
||
class LatentMaternKernel(gpytorch.kernels.Kernel): | ||
def __init__( | ||
self, | ||
n_dim, | ||
off_diag=False, | ||
diagonal_prior=False, | ||
**kwargs, | ||
): | ||
super(LatentMaternKernel, self).__init__() | ||
|
||
self.n_dim = n_dim | ||
self.n_off_diag = int(n_dim * (n_dim - 1) / 2) | ||
self.off_diag = off_diag | ||
|
||
# output_scale_constraint = gpytorch.constraints.Positive() | ||
diag_params_constraint = gpytorch.constraints.Interval(1e-1, 1e2) | ||
skew_params_constraint = gpytorch.constraints.Interval(-1e0, 1e0) | ||
|
||
diag_params_initial = np.sqrt(diag_params_constraint.lower_bound * diag_params_constraint.upper_bound) | ||
raw_diag_params_initial = diag_params_constraint.inverse_transform(diag_params_initial) | ||
|
||
# self.register_parameter( | ||
# name="raw_output_scale", parameter=torch.nn.Parameter(torch.ones(*self.batch_shape, 1).double()) | ||
# ) | ||
self.register_parameter( | ||
name="raw_diag_params", | ||
parameter=torch.nn.Parameter( | ||
raw_diag_params_initial * torch.ones(*self.batch_shape, self.n_dim).double() | ||
), | ||
) | ||
|
||
# self.register_constraint("raw_output_scale", output_scale_constraint) | ||
self.register_constraint("raw_diag_params", diag_params_constraint) | ||
self.register_constraint("raw_kernel_scale", constraint=kernel_scale_constraint) | ||
|
||
if diagonal_prior: | ||
self.register_prior( | ||
name="diag_params_prior", | ||
prior=gpytorch.priors.GammaPrior(concentration=0.5, rate=0.2), | ||
param_or_closure=lambda m: m.diag_params, | ||
setting_closure=lambda m, v: m._set_diag_params(v), | ||
) | ||
|
||
if self.off_diag: | ||
self.register_parameter( | ||
name="raw_skew_params", | ||
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, self.n_off_diag).double()), | ||
name="kernel_scale_prior", | ||
prior=kernel_scale_prior, | ||
param_or_closure=lambda m: m.kernel_scale, | ||
setting_closure=lambda m, v: m._set_kernel_scale(v), | ||
) | ||
self.register_constraint("raw_skew_params", skew_params_constraint) | ||
|
||
# @property | ||
# def output_scale(self): | ||
# return self.raw_output_scale_constraint.transform(self.raw_output_scale) | ||
@property | ||
def diag_entries(self): | ||
return self.raw_diag_entries_constraint.transform(self.raw_diag_entries) | ||
|
||
@property | ||
def diag_params(self): | ||
return self.raw_diag_params_constraint.transform(self.raw_diag_params) | ||
def skew_entries(self): | ||
return self.raw_skew_entries_constraint.transform(self.raw_skew_entries) | ||
|
||
@property | ||
def skew_params(self): | ||
return self.raw_skew_params_constraint.transform(self.raw_skew_params) | ||
def kernel_scale(self): | ||
return self.raw_kernel_scale_constraint.transform(self.raw_kernel_scale) | ||
|
||
# @output_scale.setter | ||
# def output_scale(self, value): | ||
# self._set_output_scale(value) | ||
@diag_entries.setter | ||
def diag_entries(self, value): | ||
self._set_diag_entries(value) | ||
|
||
@diag_params.setter | ||
def diag_params(self, value): | ||
self._set_diag_params(value) | ||
@skew_entries.setter | ||
def skew_entries(self, value): | ||
self._set_skew_entries(value) | ||
|
||
@skew_params.setter | ||
def skew_params(self, value): | ||
self._set_skew_params(value) | ||
@kernel_scale.setter | ||
def kernel_scale(self, value): | ||
self._set_kernel_scale(value) | ||
|
||
def _set_skew_params(self, value): | ||
def _set_diag_entries(self, value): | ||
if not torch.is_tensor(value): | ||
value = torch.as_tensor(value).to(self.raw_skew_params) | ||
self.initialize(raw_skew_params=self.raw_skew_params_constraint.inverse_transform(value)) | ||
value = torch.as_tensor(value).to(self.raw_diag_entries) | ||
self.initialize(raw_diag_entries=self.raw_diag_entries_constraint.inverse_transform(value)) | ||
|
||
# def _set_output_scale(self, value): | ||
# if not torch.is_tensor(value): | ||
# value = torch.as_tensor(value).to(self.raw_output_scale) | ||
# self.initialize(raw_output_scale=self.raw_output_scale_constraint.inverse_transform(value)) | ||
def _set_skew_entries(self, value): | ||
if not torch.is_tensor(value): | ||
value = torch.as_tensor(value).to(self.raw_skew_entries) | ||
self.initialize(raw_skew_entries=self.raw_skew_entries_constraint.inverse_transform(value)) | ||
|
||
def _set_diag_params(self, value): | ||
def _set_kernel_scale(self, value): | ||
if not torch.is_tensor(value): | ||
value = torch.as_tensor(value).to(self.raw_diag_params) | ||
self.initialize(raw_diag_params=self.raw_diag_params_constraint.inverse_transform(value)) | ||
value = torch.as_tensor(value).to(self.raw_kernel_scale) | ||
self.initialize(raw_kernel_scale=self.raw_kernel_scale_constraint.inverse_transform(value)) | ||
|
||
@property | ||
def trans_matrix(self): | ||
def output_scale(self): | ||
return self.kernel_scale.sqrt() | ||
|
||
@property | ||
def latent_dimensions(self): | ||
# no rotations | ||
if not self.off_diag: | ||
T = torch.eye(self.n_dim).double() | ||
T = torch.eye(self.num_inputs, dtype=torch.float64) | ||
|
||
# construct an orthogonal matrix. fun fact: exp(skew(N)) is the generator of SO(N) | ||
else: | ||
A = torch.zeros((self.n_dim, self.n_dim)).double() | ||
A[np.triu_indices(self.n_dim, k=1)] = self.skew_params | ||
A += -A.T | ||
A = torch.zeros((self.num_inputs, self.num_inputs)).double() | ||
A[np.triu_indices(self.num_inputs, k=1)] = self.skew_entries | ||
A += -A.transpose(-1, -2) | ||
T = torch.linalg.matrix_exp(A) | ||
|
||
T = torch.matmul(torch.diag(self.diag_params), T) | ||
diagonal_transform = torch.cat([torch.diag(entries).unsqueeze(0) for entries in self.diag_entries], dim=0) | ||
T = torch.matmul(diagonal_transform, T) | ||
|
||
return T | ||
|
||
def forward(self, x1, x2=None, diag=False, auto=False, last_dim_is_batch=False, **params): | ||
# returns the homoskedastic diagonal | ||
if diag: | ||
# return torch.square(self.output_scale[0]) * torch.ones((*self.batch_shape, *x1.shape[:-1])) | ||
return torch.ones((*self.batch_shape, *x1.shape[:-1])) | ||
|
||
# computes the autocovariance of the process at the parameters | ||
if auto: | ||
x2 = x1 | ||
|
||
# print(x1, x2) | ||
|
||
# x1 and x2 are arrays of shape (..., n_1, n_dim) and (..., n_2, n_dim) | ||
_x1, _x2 = torch.as_tensor(x1).double(), torch.as_tensor(x2).double() | ||
|
||
# dx has shape (..., n_1, n_2, n_dim) | ||
dx = _x1.unsqueeze(-2) - _x2.unsqueeze(-3) | ||
|
||
# transform coordinates with hyperparameters (this applies lengthscale and rotations) | ||
trans_dx = torch.matmul(self.trans_matrix, dx.unsqueeze(-1)) | ||
|
||
# total transformed distance. D has shape (..., n_1, n_2) | ||
d_eff = torch.sqrt(torch.matmul(trans_dx.transpose(-1, -2), trans_dx).sum((-1, -2)) + 1e-12) | ||
|
||
# Matern covariance of effective order nu=3/2. | ||
# nu=3/2 is a special case and has a concise closed-form expression | ||
# In general, this is something between an exponential (n=1/2) and a Gaussian (n=infinity) | ||
# https://en.wikipedia.org/wiki/Matern_covariance_function | ||
|
||
# C = torch.exp(-d_eff) # Matern_0.5 (exponential) | ||
C = (1 + d_eff) * torch.exp(-d_eff) # Matern_1.5 | ||
# C = (1 + d_eff + 1 / 3 * torch.square(d_eff)) * torch.exp(-d_eff) # Matern_2.5 | ||
# C = torch.exp(-0.5 * np.square(d_eff)) # Matern_infinity (RBF) | ||
def forward(self, x1, x2, diag=False, **params): | ||
# adapted from the Matern kernel | ||
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)] | ||
|
||
# C = torch.square(self.output_scale[0]) * torch.exp(-torch.square(d_eff)) | ||
trans_x1 = torch.matmul(self.latent_dimensions.unsqueeze(1), (x1 - mean).unsqueeze(-1)).squeeze(-1) | ||
trans_x2 = torch.matmul(self.latent_dimensions.unsqueeze(1), (x2 - mean).unsqueeze(-1)).squeeze(-1) | ||
|
||
# print(f'{diag = } {C.shape = }') | ||
distance = self.covar_dist(trans_x1, trans_x2, diag=diag, **params) | ||
|
||
return C | ||
return self.kernel_scale * (1 + distance) * torch.exp(-distance) |
Oops, something went wrong.