Skip to content

Commit

Permalink
better test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Apr 24, 2024
1 parent e0f312c commit a43f7f3
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 385 deletions.
45 changes: 28 additions & 17 deletions src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k
# and is in the transformed model space
candidates = self.dofs(active=True).untransform(candidates).numpy()

p = self.posterior(candidates) if hasattr(self, "model") else None
# p = self.posterior(candidates) if hasattr(self, "model") else None

active_dofs = self.dofs(active=True)

Expand Down Expand Up @@ -304,7 +304,7 @@ def ask(self, acqf="qei", n=1, route=True, sequential=True, upsample=1, **acqf_k
"sequential": sequential,
"upsample": upsample,
"read_only_values": read_only_values,
"posterior": p,
# "posterior": p,
}

return res
Expand Down Expand Up @@ -426,9 +426,11 @@ def learn(
new_table = yield from self.acquire(res["points"])
new_table.loc[:, "acqf"] = res["acqf_name"]

x = {key: new_table.pop(key).tolist() for key in self.dofs.names}
y = {key: new_table.pop(key).tolist() for key in self.objectives.names}
metadata = new_table.to_dict(orient="list")
x = {key: new_table.loc[:, key].tolist() for key in self.dofs.names}
y = {key: new_table.loc[:, key].tolist() for key in self.objectives.names}
metadata = {
key: new_table.loc[:, key].tolist() for key in new_table.columns if (key not in x) and (key not in y)
}
self.tell(x=x, y=y, metadata=metadata, append=append, train=train)

def view(self, item: str = "mean", cmap: str = "turbo", max_inputs: int = 2**16):
Expand Down Expand Up @@ -510,12 +512,10 @@ def acquire(self, acquisition_inputs):

return products

def load_data(self, data_file, append=True, train=True):
def load_data(self, data_file, append=True):
new_table = pd.read_hdf(data_file, key="table")
x = {key: new_table.pop(key).tolist() for key in self.dofs.names}
y = {key: new_table.pop(key).tolist() for key in self.objectives.names}
metadata = new_table.to_dict(orient="list")
self.tell(x=x, y=y, metadata=metadata, append=append, train=train)
self.table = pd.concat([self.table, new_table]) if append else new_table
self.refresh()

def reset(self):
"""Reset the agent."""
Expand Down Expand Up @@ -557,7 +557,9 @@ def benchmark(
def model(self):
"""A model encompassing all the fitnesses and constraints."""
active_objs = self.objectives(active=True)
return ModelListGP(*[obj.model for obj in active_objs]) if len(active_objs) > 1 else active_objs[0].model
if all(hasattr(obj, "model") for obj in active_objs):
return ModelListGP(*[obj.model for obj in active_objs]) if len(active_objs) > 1 else active_objs[0].model
raise ValueError("Not all active objectives have models.")

def posterior(self, x):
"""A model encompassing all the objectives. A single GP in the single-objective case, or a model list."""
Expand All @@ -567,7 +569,7 @@ def posterior(self, x):
def fitness_model(self):
active_fitness_models = self.objectives(active=True, kind="fitness")
if len(active_fitness_models) == 0:
raise ValueError("Having no fitness objectives is unhandled.")
return GenericDeterministicModel(f=lambda x: torch.ones(x.shape[:-1]).unsqueeze(-1))
if len(active_fitness_models) == 1:
return active_fitness_models[0].model
return ModelListGP(*[obj.model for obj in active_fitness_models])
Expand All @@ -594,12 +596,21 @@ def fitness_scalarization(self, weights="default"):
return ScalarizedPosteriorTransform(weights=weights)

def scalarized_fitnesses(self, weights="default", constrained=True):
f = self.fitness_scalarization(weights=weights).evaluate(self.train_targets(active=True, kind="fitness"))
"""
Return the scalar fitness for each sample, scalarized by the weighting scheme.
If constrained=True, the points that satisfy the most constraints are automatically better than the others.
"""
fitness_objs = self.objectives(kind="fitness")
if len(fitness_objs) >= 1:
f = self.fitness_scalarization(weights=weights).evaluate(self.train_targets(active=True, kind="fitness"))
else:
f = torch.zeros(len(self.table), dtype=torch.double)
if constrained:
c = self.evaluated_constraints.all(axis=-1)
if not c.sum():
raise ValueError("There are no valid points that satisfy the constraints!")
return torch.where(c, f, -np.inf)
# how many constraints are satisfied?
c = self.evaluated_constraints.sum(axis=-1)
f = torch.where(c < c.max(), -np.inf, f)
return f

def argmax_best_f(self, weights="default"):
return int(self.scalarized_fitnesses(weights=weights, constrained=True).argmax())
Expand Down
33 changes: 4 additions & 29 deletions src/blop/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
DEFAULT_MAX_NOISE_LEVEL = 1e0

OBJ_FIELD_TYPES = {
"name": "str",
"description": "object",
"type": "str",
"target": "object",
Expand All @@ -35,14 +36,6 @@ class DuplicateNameError(ValueError):
...


def _validate_objs(objs):
names = [obj.name for obj in objs]
unique_names, counts = np.unique(names, return_counts=True)
duplicate_names = unique_names[counts > 1]
if len(duplicate_names) > 0:
raise DuplicateNameError(f"Duplicate name(s) in supplied objectives: {duplicate_names}")


domains = {"log"}


Expand Down Expand Up @@ -238,22 +231,6 @@ def targeting_constraint(self, x: torch.Tensor) -> torch.Tensor:
0.5 * (approximate_erf((b - m) / (np.sqrt(2) * sish)) - approximate_erf((a - m) / (np.sqrt(2) * sish)))[..., -1]
)

# def fitness_forward(self, y):
# f = y
# if self.log:
# f = np.log(f)
# if self.target == "min":
# f = -f
# return f

def fitness_inverse(self, f):
y = f
if self.target == "min":
y = -y
if self.log:
y = np.exp(y)
return y

@property
def is_fitness(self):
return self.target in ["min", "max"]
Expand All @@ -279,7 +256,6 @@ def fitness_prediction(self, X):

class ObjectiveList(Sequence):
def __init__(self, objectives: list = []):
_validate_objs(objectives)
self.objectives = objectives

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -319,11 +295,11 @@ def __len__(self):

@property
def summary(self) -> pd.DataFrame:
table = pd.DataFrame(columns=list(OBJ_FIELD_TYPES.keys()), index=self.names)
table = pd.DataFrame(columns=list(OBJ_FIELD_TYPES.keys()), index=np.arange(len(self)))

for obj in self.objectives:
for index, obj in enumerate(self.objectives):
for attr, value in obj.summary.items():
table.at[obj.name, attr] = value
table.at[index, attr] = value

for attr, dtype in OBJ_FIELD_TYPES.items():
table[attr] = table[attr].astype(dtype)
Expand All @@ -337,7 +313,6 @@ def _repr_html_(self):
return self.summary.T._repr_html_()

def add(self, objective):
_validate_objs([*self.objectives, objective])
self.objectives.append(objective)

@staticmethod
Expand Down
Loading

0 comments on commit a43f7f3

Please sign in to comment.