Skip to content

Commit

Permalink
add option to train models for inactive objectives
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Jul 8, 2024
1 parent 4e24db4 commit e2f2b92
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
digestion_kwargs: dict = {},
verbose: bool = False,
enforce_all_objectives_valid: bool = True,
model_inactive_objectives: bool = False,
tolerate_acquisition_errors: bool = False,
sample_center_on_init: bool = False,
trigger_delay: float = 0,
Expand Down Expand Up @@ -137,8 +138,8 @@ def __init__(

self.verbose = verbose

self.model_inactive_objectives = model_inactive_objectives
self.tolerate_acquisition_errors = tolerate_acquisition_errors

self.enforce_all_objectives_valid = enforce_all_objectives_valid

self.train_every = train_every
Expand Down Expand Up @@ -361,7 +362,8 @@ def tell(
self._table.index = np.arange(len(self._table))

if update_models:
for obj in self.objectives(active=True):
objectives_to_model = self.objectives if self.model_inactive_objectives else self.objectives(active=True)
for obj in objectives_to_model:
t0 = ttime.monotonic()

cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None
Expand Down Expand Up @@ -737,13 +739,15 @@ def _construct_model(self, obj, skew_dims=None):

def _construct_all_models(self):
"""Construct a model for each objective."""
for obj in self.objectives(active=True):
objectives_to_construct = self.objectives if self.model_inactive_objectives else self.objectives(active=True)
for obj in objectives_to_construct:
self._construct_model(obj)

def _train_all_models(self, **kwargs):
"""Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`."""
t0 = ttime.monotonic()
for obj in self.objectives(active=True):
objectives_to_train = self.objectives if self.model_inactive_objectives else self.objectives(active=True)
for obj in objectives_to_train:
self._train_model(obj.model)
if obj.validity_conjugate_model is not None:
self._train_model(obj.validity_conjugate_model)
Expand Down

0 comments on commit e2f2b92

Please sign in to comment.