Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/nsls-ii/blop into prune
Browse files Browse the repository at this point in the history
  • Loading branch information
megha-narayanan committed Jul 31, 2024
2 parents 233575f + 76ddd5a commit 7734f59
Show file tree
Hide file tree
Showing 2 changed files with 363 additions and 363 deletions.
10 changes: 5 additions & 5 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k

else:
# check that all the objectives have models
if not all(hasattr(obj, "model") for obj in active_objs):
if not all(hasattr(obj, "_model") for obj in active_objs):
raise RuntimeError(
f"Can't construct non-trivial acquisition function '{acqf}' as the agent is not initialized."
)
Expand Down Expand Up @@ -369,7 +369,7 @@ def tell(
for obj in objectives_to_model:
t0 = ttime.monotonic()

cached_hypers = obj.model.state_dict() if hasattr(obj, "model") else None
cached_hypers = obj.model.state_dict() if hasattr(obj, "_model") else None
n_before_tell = obj.n_valid
self._construct_model(obj)
n_after_tell = obj.n_valid
Expand Down Expand Up @@ -540,7 +540,7 @@ def reset(self):
self._table = pd.DataFrame()

for obj in self.objectives(active=True):
if hasattr(obj, "model"):
if hasattr(obj, "_model"):
del obj._model

self.n_last_trained = 0
Expand Down Expand Up @@ -575,7 +575,7 @@ def benchmark(
def model(self):
"""A model encompassing all the fitnesses and constraints."""
active_objs = self.objectives(active=True)
if all(hasattr(obj, "model") for obj in active_objs):
if all(hasattr(obj, "_model") for obj in active_objs):
return ModelListGP(*[obj.model for obj in active_objs]) if len(active_objs) > 1 else active_objs[0].model
raise ValueError("Not all active objectives have models.")

Expand Down Expand Up @@ -691,7 +691,7 @@ def _construct_model(self, obj, skew_dims=None):

trusted = inputs_are_trusted & targets_are_trusted & ~self.pruned_mask()

obj.model = construct_single_task_model(
obj._model = construct_single_task_model(
X=train_inputs[trusted],
y=train_targets[trusted],
min_noise=obj.min_noise,
Expand Down
Loading

0 comments on commit 7734f59

Please sign in to comment.