diff --git a/src/blop/agent.py b/src/blop/agent.py index c6413e6..8f32fb9 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -72,6 +72,7 @@ def __init__( digestion_kwargs: dict = {}, verbose: bool = False, enforce_all_objectives_valid: bool = True, + model_inactive_objectives: bool = False, tolerate_acquisition_errors: bool = False, sample_center_on_init: bool = False, trigger_delay: float = 0, @@ -137,8 +138,8 @@ def __init__( self.verbose = verbose + self.model_inactive_objectives = model_inactive_objectives self.tolerate_acquisition_errors = tolerate_acquisition_errors - self.enforce_all_objectives_valid = enforce_all_objectives_valid self.train_every = train_every @@ -361,7 +362,8 @@ def tell( self._table.index = np.arange(len(self._table)) if update_models: - for obj in self.objectives(active=True): + objectives_to_model = self.objectives if self.model_inactive_objectives else self.objectives(active=True) + for obj in objectives_to_model: t0 = ttime.monotonic() cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None @@ -737,13 +739,15 @@ def _construct_model(self, obj, skew_dims=None): def _construct_all_models(self): """Construct a model for each objective.""" - for obj in self.objectives(active=True): + objectives_to_construct = self.objectives if self.model_inactive_objectives else self.objectives(active=True) + for obj in objectives_to_construct: self._construct_model(obj) def _train_all_models(self, **kwargs): """Fit all of the agent's models. All kwargs are passed to `botorch.fit.fit_gpytorch_mll`.""" t0 = ttime.monotonic() - for obj in self.objectives(active=True): + objectives_to_train = self.objectives if self.model_inactive_objectives else self.objectives(active=True) + for obj in objectives_to_train: self._train_model(obj.model) if obj.validity_conjugate_model is not None: self._train_model(obj.validity_conjugate_model)