From 76ddd5a64db4a7eddd99c5d1bd3a92823b2ccb23 Mon Sep 17 00:00:00 2001 From: Thomas Morris Date: Wed, 31 Jul 2024 12:38:36 -0400 Subject: [PATCH] make objective models read-only --- src/blop/agent.py | 14 +++++++------- src/blop/objectives.py | 4 ++++ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/blop/agent.py b/src/blop/agent.py index 93f8bfd..4626794 100644 --- a/src/blop/agent.py +++ b/src/blop/agent.py @@ -248,7 +248,7 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k else: # check that all the objectives have models - if not all(hasattr(obj, "model") for obj in active_objs): + if not all(hasattr(obj, "_model") for obj in active_objs): raise RuntimeError( f"Can't construct non-trivial acquisition function '{acqf}' as the agent is not initialized." ) @@ -367,7 +367,7 @@ def tell( for obj in objectives_to_model: t0 = ttime.monotonic() - cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None + cached_hypers = obj.model.state_dict() if hasattr(obj, "_model") else None n_before_tell = obj.n_valid self._construct_model(obj) n_after_tell = obj.n_valid @@ -538,8 +538,8 @@ def reset(self): self._table = pd.DataFrame() for obj in self.objectives(active=True): - if hasattr(obj, "model"): - del obj.model + if hasattr(obj, "_model"): + del obj._model self.n_last_trained = 0 @@ -573,7 +573,7 @@ def benchmark( def model(self): """A model encompassing all the fitnesses and constraints.""" active_objs = self.objectives(active=True) - if all(hasattr(obj, "model") for obj in active_objs): + if all(hasattr(obj, "_model") for obj in active_objs): return ModelListGP(*[obj.model for obj in active_objs]) if len(active_objs) > 1 else active_objs[0].model raise ValueError("Not all active objectives have models.") @@ -689,7 +689,7 @@ def _construct_model(self, obj, skew_dims=None): trusted = inputs_are_trusted & targets_are_trusted - obj.model = construct_single_task_model( + obj._model = construct_single_task_model( X=train_inputs[trusted], y=train_targets[trusted], min_noise=obj.min_noise, @@ -731,7 +731,7 @@ def _train_all_models(self, **kwargs): t0 = ttime.monotonic() objectives_to_train = self.objectives if self.model_inactive_objectives else self.objectives(active=True) for obj in objectives_to_train: - train_model(obj.model) + train_model(obj._model) if obj.validity_conjugate_model is not None: train_model(obj.validity_conjugate_model) diff --git a/src/blop/objectives.py b/src/blop/objectives.py index fb783bf..54c8bd1 100644 --- a/src/blop/objectives.py +++ b/src/blop/objectives.py @@ -254,6 +254,10 @@ def fitness_prediction(self, X): if isinstance(self.target, tuple): return self.targeting_constraint(X).log().clamp(min=-16) + @property + def model(self): + return self._model.eval() + class ObjectiveList(Sequence): def __init__(self, objectives: list = []):