Skip to content

Commit

Permalink
construct models independently of the agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Jul 16, 2024
1 parent e2f2b92 commit 6adc90e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 29 deletions.
41 changes: 12 additions & 29 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from . import plotting, utils
from .bayesian import acquisition, models
from .bayesian.acquisition import _construct_acqf, parse_acqf_identifier
from .bayesian.models import construct_single_task_model, train_model

# from .bayesian.transforms import TargetingPosteriorTransform
from .digestion import default_digestion_function
Expand Down Expand Up @@ -256,7 +257,7 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k
for obj in active_objs:
if obj.model_dofs != set(active_dofs.names):
self._construct_model(obj)
self._train_model(obj.model)
train_model(obj.model)

if acqf_config["type"] == "analytic" and n > 1:
raise ValueError("Can't generate multiple design points for analytic acquisition functions.")
Expand Down Expand Up @@ -377,12 +378,12 @@ def tell(
if len(obj.model.train_targets) >= 4:
if train:
t0 = ttime.monotonic()
self._train_model(obj.model)
train_model(obj.model)
if self.verbose:
print(f"trained model '{obj.name}' in {1e3*(ttime.monotonic() - t0):.00f} ms")

else:
self._train_model(obj.model, hypers=cached_hypers)
train_model(obj.model, hypers=cached_hypers)

def learn(
self,
Expand Down Expand Up @@ -673,30 +674,13 @@ def all_objectives_valid(self):
"""A mask of whether all objectives are valid for each data point."""
return ~torch.isnan(self.scalarized_fitnesses())

def _train_model(self, model, hypers=None, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True

def _construct_model(self, obj, skew_dims=None):
"""
Construct an untrained model for an objective.
"""

skew_dims = skew_dims if skew_dims is not None else self._latent_dim_tuples(obj.name)

likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_constraint=gpytorch.constraints.Interval(
torch.tensor(obj.min_noise),
torch.tensor(obj.max_noise),
),
)

outcome_transform = botorch.models.transforms.outcome.Standardize(m=1) # , batch_shape=torch.Size((1,)))

train_inputs = self.train_inputs(active=True)
train_targets = self.train_targets()[obj.name].unsqueeze(-1)

Expand All @@ -705,13 +689,12 @@ def _construct_model(self, obj, skew_dims=None):

trusted = inputs_are_trusted & targets_are_trusted

obj.model = models.LatentGP(
train_inputs=train_inputs[trusted],
train_targets=train_targets[trusted],
likelihood=likelihood,
skew_dims=skew_dims,
input_transform=self.input_normalization,
outcome_transform=outcome_transform,
obj.model = construct_single_task_model(
X=train_inputs[trusted],
y=train_targets[trusted],
min_noise=obj.min_noise,
max_noise=obj.max_noise,
skew_dims=self._latent_dim_tuples()[obj.name],
)

obj.model_dofs = set(self.dofs(active=True).names) # if these change, retrain the model on self.ask()
Expand Down Expand Up @@ -748,9 +731,9 @@ def _train_all_models(self, **kwargs):
t0 = ttime.monotonic()
objectives_to_train = self.objectives if self.model_inactive_objectives else self.objectives(active=True)
for obj in objectives_to_train:
self._train_model(obj.model)
train_model(obj.model)
if obj.validity_conjugate_model is not None:
self._train_model(obj.validity_conjugate_model)
train_model(obj.validity_conjugate_model)

if self.verbose:
print(f"trained models in {ttime.monotonic() - t0:.01f} seconds")
Expand Down
40 changes: 40 additions & 0 deletions src/blop/bayesian/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,50 @@
import botorch
import gpytorch
import torch
from botorch.models.gp_regression import SingleTaskGP

from . import kernels


def train_model(model, hypers=None, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
if hypers is not None:
model.load_state_dict(hypers)
else:
botorch.fit.fit_gpytorch_mll(gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model), **kwargs)
model.trained = True


def construct_single_task_model(X, y, skew_dims=None, min_noise=1e-6, max_noise=1e0):
"""
Construct an untrained model for an objective.
"""

likelihood = gpytorch.likelihoods.GaussianLikelihood(
noise_constraint=gpytorch.constraints.Interval(
torch.tensor(min_noise),
torch.tensor(max_noise),
),
)

input_transform = botorch.models.transforms.input.Normalize(d=X.shape[-1])
outcome_transform = botorch.models.transforms.outcome.Standardize(m=1) # , batch_shape=torch.Size((1,)))

if not X.isfinite().all():
raise ValueError("'X' must not contain points that are inf or NaN.")
if not y.isfinite().all():
raise ValueError("'y' must not contain points that are inf or NaN.")

return LatentGP(
train_inputs=X,
train_targets=y,
likelihood=likelihood,
skew_dims=skew_dims,
input_transform=input_transform,
outcome_transform=outcome_transform,
)


class LatentGP(SingleTaskGP):
def __init__(self, train_inputs, train_targets, skew_dims=True, *args, **kwargs):
super().__init__(train_inputs, train_targets, *args, **kwargs)
Expand Down

0 comments on commit 6adc90e

Please sign in to comment.