From e2f2b92b1d37c2762b5033b1e8338ac2b7fc6a55 Mon Sep 17 00:00:00 2001 From: Thomas Morris Date: Mon, 8 Jul 2024 13:45:07 -0400 Subject: [PATCH] add option to train models for inactive objectives --- src/blop/agent.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/blop/agent.py b/src/blop/agent.py index c6413e6..8f32fb9 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -72,6 +72,7 @@ def __init__( digestion_kwargs: dict = {}, verbose: bool = False, enforce_all_objectives_valid: bool = True, + model_inactive_objectives: bool = False, tolerate_acquisition_errors: bool = False, sample_center_on_init: bool = False, trigger_delay: float = 0, @@ -137,8 +138,8 @@ def __init__( self.verbose = verbose + self.model_inactive_objectives = model_inactive_objectives self.tolerate_acquisition_errors = tolerate_acquisition_errors - self.enforce_all_objectives_valid = enforce_all_objectives_valid self.train_every = train_every @@ -361,7 +362,8 @@ def tell( self._table.index = np.arange(len(self._table)) if update_models: - for obj in self.objectives(active=True): + objectives_to_model = self.objectives if self.model_inactive_objectives else self.objectives(active=True) + for obj in objectives_to_model: t0 = ttime.monotonic() cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None @@ -737,13 +739,15 @@ def _construct_model(self, obj, skew_dims=None): def _construct_all_models(self): """Construct a model for each objective.""" - for obj in self.objectives(active=True): + objectives_to_construct = self.objectives if self.model_inactive_objectives else self.objectives(active=True) + for obj in objectives_to_construct: self._construct_model(obj) def _train_all_models(self, **kwargs): """Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`.""" t0 = ttime.monotonic() - for obj in self.objectives(active=True): + objectives_to_train = self.objectives if self.model_inactive_objectives else self.objectives(active=True) + for obj in objectives_to_train: self._train_model(obj.model) if obj.validity_conjugate_model is not None: self._train_model(obj.validity_conjugate_model)