Skip to content

Commit

Permalink
work at xpd
Browse files Browse the repository at this point in the history
  • Loading branch information
XPD Operator committed Oct 31, 2023
1 parent b89848e commit 15fa731
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
23 changes: 15 additions & 8 deletions bloptools/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,17 @@ def __init__(

self.table = pd.DataFrame()

self.initialized = False
self._train_models = True
self.a_priori_hypers = None

self.plots = {"objectives": {}}


@property
def has_models(self):
return all([hasattr(obj, "model") for obj in self.objectives])


def tell(self, new_table=None, append=True, train=True, **kwargs):
"""
Inform the agent about new inputs and targets for the model.
Expand All @@ -112,9 +117,12 @@ def tell(self, new_table=None, append=True, train=True, **kwargs):
self.table = pd.concat([self.table, new_table]) if append else new_table
self.table.index = np.arange(len(self.table))

if len(self.table) < 2:
return

skew_dims = self.latent_dim_tuples

if self.initialized:
if self.has_models:
cached_hypers = self.hypers

inputs = self.table.loc[:, self.dofs.subset(active=True).device_names].values.astype(float)
Expand Down Expand Up @@ -173,7 +181,7 @@ def tell(self, new_table=None, append=True, train=True, **kwargs):
try:
self.train_models()
except botorch.exceptions.errors.ModelFittingError:
if self.initialized:
if self.has_models:
self._set_hypers(cached_hypers)
else:
raise RuntimeError("Could not fit model on initialization!")
Expand All @@ -197,7 +205,7 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, **acq
print(f'finding points with acquisition function "{acq_func_name}" ...')

if acq_func_type in ["analytic", "monte_carlo"]:
if not self.initialized:
if not self.has_models:
raise RuntimeError(
f'Can\'t construct non-trivial acquisition function "{acq_func_identifier}"'
f" (the agent is not initialized!)"
Expand Down Expand Up @@ -330,7 +338,7 @@ def learn(
else:
self.tell(new_table=data)

if self.sample_center_on_init and not self.initialized:
if self.sample_center_on_init and not self.has_models:
new_table = yield from self.acquire(self.dofs.subset(active=True, read_only=False).limits.mean(axis=1))
new_table.loc[:, "acq_func"] = "sample_center_on_init"
self.tell(new_table=new_table, train=False)
Expand All @@ -344,8 +352,6 @@ def learn(
new_table.loc[:, "acq_func"] = acq_func_meta["name"]
self.tell(new_table=new_table, train=train)

self.initialized = True

def get_acquisition_function(self, acq_func_identifier, return_metadata=False):
return acquisition.get_acquisition_function(
self, acq_func_identifier=acq_func_identifier, return_metadata=return_metadata
Expand All @@ -356,7 +362,8 @@ def reset(self):
Reset the agent.
"""
self.table = pd.DataFrame()
self.initialized = False
for obj in self.objectives:
del obj.model

def benchmark(
self, output_dir="./", runs=16, n_init=64, learning_kwargs_list=[{"acq_func": "qei", "n": 4, "iterations": 16}]
Expand Down
2 changes: 1 addition & 1 deletion bloptools/bayesian/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _validate_objectives(objectives):
class Objective:
key: str
name: str = None
target: float | str = "max"
target: Union[float, str] = "max"
log: bool = False
weight: numeric = 1.0
limits: Tuple[numeric, numeric] = None
Expand Down

0 comments on commit 15fa731

Please sign in to comment.