Merge pull request NSLS-II#41 from thomaswmorris/reuse-hypers
Reuse hyperparameters
mrakitin authored Aug 8, 2023
2 parents a48ec10 + 82db973 commit be84a7b
Showing 20 changed files with 592 additions and 360 deletions.
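The theme of this PR is carrying trained hyperparameters over between model fits. The bloptools implementation lives in bloptools/bayesian/__init__.py, whose diff is not rendered below; purely for orientation, here is a minimal sketch of the generic GPyTorch/BoTorch warm-start pattern. This is not necessarily how bloptools does it, and all data and names here are made up.

```python
# Hypothetical sketch of "reuse hyperparameters": warm-starting a refit by
# copying the trained state of a previous GP into a fresh one.
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = torch.rand(16, 2, dtype=torch.float64)
train_y = torch.sin(train_x.sum(-1, keepdim=True))

old_model = SingleTaskGP(train_x, train_y)
fit_gpytorch_mll(ExactMarginalLogLikelihood(old_model.likelihood, old_model))

# New data arrives; rebuild the model, but start from the old hyperparameters
# instead of refitting from scratch.
new_x = torch.cat([train_x, torch.rand(4, 2, dtype=torch.float64)])
new_y = torch.sin(new_x.sum(-1, keepdim=True))
new_model = SingleTaskGP(new_x, new_y)
new_model.load_state_dict(old_model.state_dict(), strict=False)
```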
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -11,9 +11,9 @@ repos:
rev: 23.1.0
hooks:
- id: black
language_version: python3.10
language_version: python3
- id: black-jupyter
language_version: python3.10
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
458 changes: 306 additions & 152 deletions bloptools/bayesian/__init__.py

Large diffs are not rendered by default.

149 changes: 89 additions & 60 deletions bloptools/bayesian/kernels.py
@@ -5,40 +5,73 @@

class LatentKernel(gpytorch.kernels.Kernel):
is_stationary = True

num_outputs = 1
batch_inverse_lengthscale = 1e6

def __init__(
self,
num_inputs=1,
off_diag=True,
skew_dims=True,
diag_prior=True,
scale_kernel=True,
scale_output=True,
**kwargs,
):
super(LatentKernel, self).__init__()

self.num_inputs = num_inputs
self.n_off_diag = int(num_inputs * (num_inputs - 1) / 2)

self.off_diag = off_diag
self.scale_kernel = scale_kernel
self.scale_output = scale_output

self.nu = kwargs.get("nu", 1.5)
self.batch_dimension = kwargs.get("batch_dimension", None)

if type(skew_dims) is bool:
if skew_dims:
self.skew_dims = [torch.arange(self.num_inputs)]
else:
self.skew_dims = [torch.arange(0)]
elif hasattr(skew_dims, "__iter__"):
self.skew_dims = [torch.tensor(np.atleast_1d(skew_group)) for skew_group in skew_dims]
else:
raise ValueError('arg "skew_dims" must be True, False, or an iterable of tuples of ints.')

# if not all([len(skew_group) >= 2 for skew_group in self.skew_dims]):
# raise ValueError("must have at least two dims per skew group")
skewed_dims = [dim for skew_group in self.skew_dims for dim in skew_group]
if not len(set(skewed_dims)) == len(skewed_dims):
raise ValueError("values in skew_dims must be unique")
if not max(skewed_dims) < self.num_inputs:
raise ValueError("invalud dimension index in skew_dims")

skew_group_submatrix_indices = []
for dim in range(self.num_outputs):
for skew_group in self.skew_dims:
j, k = skew_group[torch.triu_indices(len(skew_group), len(skew_group), 1)].unsqueeze(1)
i = dim * torch.ones(j.shape).long()
skew_group_submatrix_indices.append(torch.cat((i, j, k), dim=0))

self.diag_matrix_indices = tuple(
[
torch.kron(torch.arange(self.num_outputs), torch.ones(self.num_inputs)).long(),
*2 * [torch.arange(self.num_inputs).repeat(self.num_outputs)],
]
)

# kernel_scale_constraint = gpytorch.constraints.Positive()
diag_entries_constraint = gpytorch.constraints.Positive() # gpytorch.constraints.Interval(5e-1, 1e2)
skew_entries_constraint = gpytorch.constraints.Interval(-1e0, 1e0)
self.skew_matrix_indices = (
tuple(torch.cat(skew_group_submatrix_indices, dim=1))
if len(skew_group_submatrix_indices) > 0
else tuple([[], []])
)

# diag_entries_initial = np.ones()
# np.sqrt(diag_entries_constraint.lower_bound * diag_entries_constraint.upper_bound)
raw_diag_entries_initial = diag_entries_constraint.inverse_transform(torch.tensor(2))
self.n_skew_entries = len(self.skew_matrix_indices[0])

self.register_parameter(
name="raw_diag_entries",
parameter=torch.nn.Parameter(raw_diag_entries_initial * torch.ones(self.num_outputs, self.num_inputs).double()),
diag_entries_constraint = gpytorch.constraints.Positive()
raw_diag_entries_initial = (
diag_entries_constraint.inverse_transform(torch.tensor(1e-1))
* torch.ones(self.num_outputs, self.num_inputs).double()
)
self.register_constraint("raw_diag_entries", constraint=diag_entries_constraint)

self.register_parameter(name="raw_diag_entries", parameter=torch.nn.Parameter(raw_diag_entries_initial))
self.register_constraint(param_name="raw_diag_entries", constraint=diag_entries_constraint)

if diag_prior:
self.register_prior(
@@ -48,29 +81,28 @@ def __init__(
setting_closure=lambda m, v: m._set_diag_entries(v),
)

if self.off_diag:
self.register_parameter(
name="raw_skew_entries",
parameter=torch.nn.Parameter(torch.zeros(self.num_outputs, self.n_off_diag).double()),
)
self.register_constraint("raw_skew_entries", skew_entries_constraint)
if self.n_skew_entries > 0:
skew_entries_constraint = gpytorch.constraints.Interval(-1e0, 1e0)
skew_entries_initial = torch.zeros((self.num_outputs, self.n_skew_entries), dtype=torch.float64)
self.register_parameter(name="raw_skew_entries", parameter=torch.nn.Parameter(skew_entries_initial))
self.register_constraint(param_name="raw_skew_entries", constraint=skew_entries_constraint)

if self.scale_kernel:
kernel_scale_constraint = gpytorch.constraints.Positive()
kernel_scale_prior = gpytorch.priors.GammaPrior(concentration=2, rate=0.15)
if self.scale_output:
outputscale_constraint = gpytorch.constraints.Positive()
outputscale_prior = gpytorch.priors.GammaPrior(concentration=2, rate=0.15)

self.register_parameter(
name="raw_kernel_scale",
name="raw_outputscale",
parameter=torch.nn.Parameter(torch.ones(1)),
)

self.register_constraint("raw_kernel_scale", constraint=kernel_scale_constraint)
self.register_constraint("raw_outputscale", constraint=outputscale_constraint)

self.register_prior(
name="kernel_scale_prior",
prior=kernel_scale_prior,
param_or_closure=lambda m: m.kernel_scale,
setting_closure=lambda m, v: m._set_kernel_scale(v),
name="outputscale_prior",
prior=outputscale_prior,
param_or_closure=lambda m: m.outputscale,
setting_closure=lambda m, v: m._set_outputscale(v),
)

@property
@@ -82,8 +114,8 @@ def skew_entries(self):
return self.raw_skew_entries_constraint.transform(self.raw_skew_entries)

@property
def kernel_scale(self):
return self.raw_kernel_scale_constraint.transform(self.raw_kernel_scale)
def outputscale(self):
return self.raw_outputscale_constraint.transform(self.raw_outputscale)

@diag_entries.setter
def diag_entries(self, value):
@@ -93,9 +125,9 @@ def diag_entries(self, value):
def skew_entries(self, value):
self._set_skew_entries(value)

@kernel_scale.setter
def kernel_scale(self, value):
self._set_kernel_scale(value)
@outputscale.setter
def outputscale(self, value):
self._set_outputscale(value)

def _set_diag_entries(self, value):
if not torch.is_tensor(value):
@@ -107,40 +139,37 @@ def _set_skew_entries(self, value):
value = torch.as_tensor(value).to(self.raw_skew_entries)
self.initialize(raw_skew_entries=self.raw_skew_entries_constraint.inverse_transform(value))

def _set_kernel_scale(self, value):
def _set_outputscale(self, value):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_kernel_scale)
self.initialize(raw_kernel_scale=self.raw_kernel_scale_constraint.inverse_transform(value))
value = torch.as_tensor(value).to(self.raw_outputscale)
self.initialize(raw_outputscale=self.raw_outputscale_constraint.inverse_transform(value))

@property
def output_scale(self):
return self.kernel_scale.sqrt()
def skew_matrix(self):
S = torch.zeros((self.num_outputs, self.num_inputs, self.num_inputs), dtype=torch.float64)
if self.n_skew_entries > 0:
            # construct an orthogonal matrix. fun fact: skew-symmetric matrices generate SO(N) via the exponential map
S[self.skew_matrix_indices] = self.skew_entries
S += -S.transpose(-1, -2)
return torch.linalg.matrix_exp(S)

@property
def latent_dimensions(self):
# no rotations
if not self.off_diag:
T = torch.eye(self.num_inputs, dtype=torch.float64)

# construct an orthogonal matrix. fun fact: exp(skew(N)) is the generator of SO(N)
else:
A = torch.zeros((self.num_inputs, self.num_inputs)).double()
A[np.triu_indices(self.num_inputs, k=1)] = self.skew_entries
A += -A.transpose(-1, -2)
T = torch.linalg.matrix_exp(A)

diagonal_transform = torch.cat([torch.diag(entries).unsqueeze(0) for entries in self.diag_entries], dim=0)
T = torch.matmul(diagonal_transform, T)
def diag_matrix(self):
D = torch.zeros((self.num_outputs, self.num_inputs, self.num_inputs), dtype=torch.float64)
D[self.diag_matrix_indices] = self.diag_entries.ravel()
return D

return T
@property
def latent_transform(self):
return torch.matmul(self.diag_matrix, self.skew_matrix)

def forward(self, x1, x2, diag=False, **params):
# adapted from the Matern kernel
mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]

trans_x1 = torch.matmul(self.latent_dimensions.unsqueeze(1), (x1 - mean).unsqueeze(-1)).squeeze(-1)
trans_x2 = torch.matmul(self.latent_dimensions.unsqueeze(1), (x2 - mean).unsqueeze(-1)).squeeze(-1)
trans_x1 = torch.matmul(self.latent_transform.unsqueeze(1), (x1 - mean).unsqueeze(-1)).squeeze(-1)
trans_x2 = torch.matmul(self.latent_transform.unsqueeze(1), (x2 - mean).unsqueeze(-1)).squeeze(-1)

distance = self.covar_dist(trans_x1, trans_x2, diag=diag, **params)

return self.kernel_scale * (1 + distance) * torch.exp(-distance)
return (self.outputscale if self.scale_output else 1.0) * (1 + distance) * torch.exp(-distance)
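For intuition on the `skew_matrix` property above: the exponential of a skew-symmetric matrix is always orthogonal with determinant one, so `latent_transform` composes per-output diagonal scalings with a rotation. The `forward` method then evaluates a Matérn-ν=3/2 kernel, k(d) ∝ (1 + d)·exp(−d), on distances in that transformed space, with the usual √3 factor absorbed into the learned scalings. A standalone numeric check of the rotation fact, independent of bloptools:

```python
# Check that exp of a skew-symmetric matrix is a rotation -- the trick used
# by LatentKernel.skew_matrix above.
import torch

n = 3
A = torch.randn(n, n, dtype=torch.float64)
S = A - A.T                      # skew-symmetric: S.T == -S
Q = torch.linalg.matrix_exp(S)

eye = torch.eye(n, dtype=torch.float64)
print(torch.allclose(Q @ Q.T, eye, atol=1e-10))  # True: Q is orthogonal
print(torch.linalg.det(Q))                       # ~1.0: Q is in SO(n)
```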
82 changes: 12 additions & 70 deletions bloptools/bayesian/models.py
@@ -1,99 +1,41 @@
import botorch
import gpytorch
import torch
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.models import ExactGP

from . import kernels


class LatentDirichletClassifier(botorch.models.gp_regression.SingleTaskGP):
def __init__(self, train_inputs, train_targets, *args, **kwargs):
class LatentGP(botorch.models.gp_regression.SingleTaskGP):
def __init__(self, train_inputs, train_targets, skew_dims=True, *args, **kwargs):
super().__init__(train_inputs, train_targets, *args, **kwargs)

self.mean_module = gpytorch.means.ConstantMean()
self.mean_module = gpytorch.means.ConstantMean(constant_prior=gpytorch.priors.NormalPrior(loc=0, scale=1))

self.covar_module = kernels.LatentKernel(
num_inputs=train_inputs.shape[-1],
num_outputs=train_targets.shape[-1],
off_diag=True,
skew_dims=skew_dims,
diag_prior=True,
scale_output=True,
scale=True,
**kwargs
)

def log_prob(self, x, n_samples=256):
*input_shape, n_dim = x.shape
samples = self.posterior(x.reshape(-1, n_dim)).sample(torch.Size((n_samples,))).exp()
return torch.log((samples / samples.sum(-1, keepdim=True)).mean(0)[:, 1]).reshape(*input_shape, 1)


class LatentGP(botorch.models.gp_regression.SingleTaskGP):
def __init__(self, train_inputs, train_targets, *args, **kwargs):
class LatentDirichletClassifier(botorch.models.gp_regression.SingleTaskGP):
def __init__(self, train_inputs, train_targets, skew_dims=True, *args, **kwargs):
super().__init__(train_inputs, train_targets, *args, **kwargs)

self.mean_module = gpytorch.means.ConstantMean(constant_prior=gpytorch.priors.NormalPrior(loc=0, scale=1))

self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = kernels.LatentKernel(
num_inputs=train_inputs.shape[-1],
num_outputs=train_targets.shape[-1],
off_diag=True,
skew_dims=skew_dims,
diag_prior=True,
scale_output=True,
scale=True,
**kwargs
)


class OldBoTorchSingleTaskGP(ExactGP, GPyTorchModel):
def __init__(self, train_inputs, train_targets, likelihood):
super(OldBoTorchSingleTaskGP, self).__init__(train_inputs, train_targets, likelihood)
self.mean_module = gpytorch.means.ConstantMean()
self.covar_module = gpytorch.kernels.ScaleKernel(
kernels.LatentMaternKernel(n_dim=train_inputs.shape[-1], off_diag=True, diagonal_prior=True)
)

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)

return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class BoTorchMultiTaskGP(ExactGP, GPyTorchModel):
_num_outputs = 1 # to inform GPyTorchModel API

def __init__(self, train_inputs, train_targets, likelihood):
self._num_outputs = train_targets.shape[-1]

super(BoTorchMultiTaskGP, self).__init__(train_inputs, train_targets, likelihood)
self.mean_module = gpytorch.means.MultitaskMean(gpytorch.means.ConstantMean(), num_tasks=self._num_outputs)
self.covar_module = gpytorch.kernels.MultitaskKernel(
kernels.LatentMaternKernel(n_dim=train_inputs.shape[-1], off_diag=True, diagonal_prior=True),
num_tasks=self._num_outputs,
rank=1,
)

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


class OldBoTorchDirichletClassifier(gpytorch.models.ExactGP, botorch.models.gpytorch.GPyTorchModel):
_num_outputs = 1 # to inform GPyTorchModel API

def __init__(self, train_inputs, train_targets, likelihood):
super(OldBoTorchDirichletClassifier, self).__init__(train_inputs, train_targets, likelihood)
self.mean_module = gpytorch.means.ConstantMean(batch_shape=len(train_targets.unique()))
self.covar_module = gpytorch.kernels.ScaleKernel(
kernels.LatentMaternKernel(n_dim=train_inputs.shape[-1], off_diag=False, diagonal_prior=False)
)

def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def log_prob(self, x, n_samples=256):
*input_shape, n_dim = x.shape
samples = self.posterior(x.reshape(-1, n_dim)).sample(torch.Size((n_samples,))).exp()
return torch.log((samples / samples.sum(-3, keepdim=True)).mean(0)[1]).reshape(*input_shape)
return torch.log((samples / samples.sum(-1, keepdim=True)).mean(0)[:, 1]).reshape(*input_shape, 1)
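The `log_prob` method retained above is a Monte-Carlo estimate: it draws samples of per-class logits from the posterior, exponentiates and normalizes them into class probabilities (a softmax), averages over the samples, and only then takes the log. A self-contained sketch of that computation with fabricated samples; the shapes and names are illustrative, not bloptools API:

```python
import torch

def mean_class1_log_prob(logit_samples: torch.Tensor) -> torch.Tensor:
    """Estimate log p(class=1 | x) from sampled logits of shape
    (n_samples, n_points, n_classes)."""
    probs = logit_samples.exp()
    probs = probs / probs.sum(-1, keepdim=True)  # softmax over the class axis
    return probs.mean(0)[:, 1].log()             # average samples, take class 1

fake_samples = torch.randn(256, 10, 2)           # 256 draws, 10 points, 2 classes
print(mean_class1_log_prob(fake_samples).shape)  # torch.Size([10])
```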
24 changes: 23 additions & 1 deletion bloptools/devices.py
@@ -4,8 +4,12 @@
from ophyd import Device, Signal, SignalRO


def dummy_dof(name):
return Signal(name=name, value=0.0)


def dummy_dofs(n=2):
return [Signal(name=f"x{i+1}", value=0) for i in range(n)]
return [dummy_dof(name=f"x{i+1}") for i in range(n)]


def get_dummy_device(name="dofs", n=2):
@@ -22,8 +26,26 @@ def get_dummy_device(name="dofs", n=2):


class TimeReadback(SignalRO):
"""
Returns the current timestamp.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get(self):
return ttime.time()


class ConstantReadback(SignalRO):
"""
Returns a constant every time you read it (more useful than you'd think).
"""

def __init__(self, constant=1, *args, **kwargs):
super().__init__(*args, **kwargs)

self.constant = constant

def get(self):
return self.constant
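Assuming these classes are importable from `bloptools.devices` as the file path suggests, a hedged usage sketch of the new read-only signals (printed values are illustrative):

```python
# Requires an environment with ophyd installed; import path is implied by the
# file above.
from bloptools.devices import ConstantReadback, TimeReadback

t = TimeReadback(name="timestamp")
print(t.get())   # current UNIX time, e.g. 1691473920.51

c = ConstantReadback(constant=42, name="answer")
print(c.get())   # 42
```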