Skip to content

Commit

Permalink
make sure DOF bounds are cast to floats
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Oct 12, 2023
1 parent 1bbb801 commit 39a579f
Show file tree
Hide file tree
Showing 15 changed files with 248 additions and 199 deletions.
4 changes: 2 additions & 2 deletions bloptools/bayesian/acquisition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ def get_acquisition_function(agent, acq_func_identifier="qei", return_metadata=T
acq_func_meta = {"name": acq_func_name, "args": {"beta": beta}}

elif acq_func_name == "expected_mean":
acq_func, _ = get_acquisition_function(agent, acq_func_identifier="ucb", beta=0, return_metadata=False)
acq_func = get_acquisition_function(agent, acq_func_identifier="ucb", beta=0, return_metadata=False)
acq_func_meta = {"name": acq_func_name, "args": {}}

elif acq_func_name == "monte_carlo_expected_mean":
acq_func, _ = get_acquisition_function(agent, acq_func_identifier="qucb", beta=0, return_metadata=False)
acq_func = get_acquisition_function(agent, acq_func_identifier="qucb", beta=0, return_metadata=False)
acq_func_meta = {"name": acq_func_name, "args": {}}

return (acq_func, acq_func_meta) if return_metadata else acq_func
95 changes: 22 additions & 73 deletions bloptools/bayesian/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@

MAX_TEST_INPUTS = 2**11

os.environ["KMP_DUPLICATE_LIB_OK"] = "True"


class Agent:
def __init__(
Expand Down Expand Up @@ -126,8 +124,8 @@ def tell(self, new_table=None, append=True, train=True, **kwargs):
if not train_index.sum() >= 2:
raise ValueError("There must be at least two valid data points per objective!")

train_inputs = torch.tensor(inputs[train_index]).double()
train_targets = torch.tensor(targets[train_index]).double().unsqueeze(-1) # .unsqueeze(0)
train_inputs = torch.tensor(inputs[train_index], dtype=torch.double)
train_targets = torch.tensor(targets[train_index], dtype=torch.double).unsqueeze(-1) # .unsqueeze(0)

# for constructing the log normal noise prior
# target_snr = 2e2
Expand Down Expand Up @@ -212,7 +210,7 @@ def ask(self, acq_func_identifier="qei", n=1, route=True, sequential=True, **acq

candidates, _ = botorch.optim.optimize_acqf(
acq_function=acq_func,
bounds=self._active_bounds_torch,
bounds=self.acquisition_function_bounds,
q=n,
sequential=sequential,
num_restarts=NUM_RESTARTS,
Expand Down Expand Up @@ -392,22 +390,6 @@ def _get_objective_targets(self, i):

return targets

# @property
# def acquisition_dofs(self):
# """
# Returns the acquisition DOFs, which are the DOFs to optimize over (that is, active and not read-only).
# """
# return self.dofs.subset(active=True, read_only=False)

# @property
# def acquisition_dof_limits(self):
# """
# Returns the acquisition limits, which are the ranges optimize over (that is, active and not read-only).
# This has shape (n_acq_dofs, 2).
# """
# acq_dofs = self.dofs.subset(active=True, read_only=False)
# return np.c_[acq_dofs.summary.lower_limit.values, acq_dofs.summary.upper_limit.values]

@property
def n_objs(self):
"""
Expand Down Expand Up @@ -441,17 +423,19 @@ def target_names(self):

def test_inputs_grid(self, max_inputs=MAX_TEST_INPUTS):
"""
Returns a (n_side, ..., n_side, 1, n_active_dof) grid of test_inputs
Returns a (n_side, ..., n_side, 1, n_active_dof) grid of test_inputs.
n_side is 1 if a dof is read-only
"""
n_acq_dofs = len(self.dofs.subset(active=True, read_only=False))
n_side = int(np.power(max_inputs, n_acq_dofs**-1))
n_settable_acq_func_dofs = len(self.dofs.subset(active=True, read_only=False))
n_side_settable = int(np.power(max_inputs, n_settable_acq_func_dofs**-1))
n_sides = [1 if dof.read_only else n_side_settable for dof in self.dofs.subset(active=True)]
return torch.cat(
[
tensor.unsqueeze(-1)
for tensor in torch.meshgrid(
*[
torch.linspace(dof.lower_limit, dof.upper_limit, n_side) if not dof.read_only else dof.readback
for dof in self.dofs.subset(active=True)
torch.linspace(lower_limit, upper_limit, n_side)
for (lower_limit, upper_limit), n_side in zip(self.dofs.subset(active=True).limits, n_sides)
],
indexing="ij",
)
Expand All @@ -463,10 +447,18 @@ def test_inputs(self, n=MAX_TEST_INPUTS):
"""
Returns a (n, 1, n_active_dof) grid of test_inputs
"""
return utils.sobol_sampler(self._active_bounds_torch, n=n)
return utils.sobol_sampler(self.acquisition_function_bounds, n=n)

@property
def _active_bounds_torch(self):
def acquisition_function_bounds(self):
"""
Returns a (2, n_active_dof) array of bounds for the acquisition function
"""
acq_func_lower_bounds = [dof.lower_limit if not dof.read_only else dof.readback for dof in self.dofs]
acq_func_upper_bounds = [dof.upper_limit if not dof.read_only else dof.readback for dof in self.dofs]

return torch.tensor(np.vstack([acq_func_lower_bounds, acq_func_upper_bounds]), dtype=torch.double)

return torch.tensor(
[dof.limits if not dof.read_only else tuple(2 * [dof.readback]) for dof in self.dofs.subset(active=True)]
).T
Expand Down Expand Up @@ -644,48 +636,5 @@ def plot_validity(self, **kwargs):
else:
plotting._plot_valid_many_dofs(self, **kwargs)

# def plot_history(self, x_key="index", show_all_objs=False):
# x = getattr(self.table, x_key).values

# num_obj_plots = 1
# if show_all_objs:
# num_obj_plots = self.n_objs + 1

# self.n_objs + 1 if self.n_objs > 1 else 1

# hist_fig, hist_axes = plt.subplots(
# num_obj_plots, 1, figsize=(6, 4 * num_obj_plots), sharex=True, constrained_layout=True, dpi=200
# )
# hist_axes = np.atleast_1d(hist_axes)

# unique_strategies, acq_func_index, acq_func_inverse = np.unique(
# self.table.acq_func, return_index=True, return_inverse=True
# )

# sample_colors = np.array(DEFAULT_COLOR_LIST)[acq_func_inverse]

# if show_all_objs:
# for obj_index, obj in enumerate(self.objectives):
# y = self.table.loc[:, f"{obj.key}_fitness"].values
# hist_axes[obj_index].scatter(x, y, c=sample_colors)
# hist_axes[obj_index].plot(x, y, lw=5e-1, c="k")
# hist_axes[obj_index].set_ylabel(obj.key)

# y = self.scalarized_objectives

# cummax_y = np.array([np.nanmax(y[: i + 1]) for i in range(len(y))])

# hist_axes[-1].scatter(x, y, c=sample_colors)
# hist_axes[-1].plot(x, y, lw=5e-1, c="k")

# hist_axes[-1].plot(x, cummax_y, lw=5e-1, c="k", ls=":")

# hist_axes[-1].set_ylabel("total_fitness")
# hist_axes[-1].set_xlabel(x_key)

# handles = []
# for i_acq_func, acq_func in enumerate(unique_strategies):
# # i_acq_func = np.argsort(acq_func_index)[i_handle]
# handles.append(Patch(color=DEFAULT_COLOR_LIST[i_acq_func], label=acq_func))
# legend = hist_axes[0].legend(handles=handles, fontsize=8)
# legend.set_title("acquisition function")
def plot_history(self, **kwargs):
plotting._plot_history(self, **kwargs)
6 changes: 3 additions & 3 deletions bloptools/bayesian/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DOF:
def __init__(
self,
device: Signal = None,
limits: Tuple[numeric, numeric] = (-10.0, 10.0),
limits: Tuple[float, float] = (-10.0, 10.0),
name: str = None,
units: str = None,
read_only: bool = None,
Expand All @@ -53,11 +53,11 @@ def __init__(

@property
def lower_limit(self):
return self.limits[0]
return float(self.limits[0])

@property
def upper_limit(self):
return self.limits[1]
return float(self.limits[1])

@property
def readback(self):
Expand Down
6 changes: 4 additions & 2 deletions bloptools/bayesian/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

numeric = Union[float, int]

DEFAULT_MINIMUM_SNR = 1e1
DEFAULT_MINIMUM_SNR = 2e1
OBJ_FIELDS = ["name", "key", "limits", "weight", "minimize", "log"]


Expand Down Expand Up @@ -59,7 +59,7 @@ def summary(self):
return series

def __repr__(self):
return self.params.__repr__()
return self.summary.__repr__()

@property
def has_model(self):
Expand Down Expand Up @@ -88,6 +88,8 @@ def summary(self):
for attr in ["minimize", "log"]:
summary[attr] = summary[attr].astype(bool)

return summary

def __repr__(self):
return self.summary.__repr__()

Expand Down
91 changes: 49 additions & 42 deletions bloptools/bayesian/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Patch

from . import acquisition

Expand Down Expand Up @@ -194,7 +195,7 @@ def _plot_acq_one_dof(agent, acq_funcs, lw=1e0, **kwargs):
acq_func, acq_func_meta = acquisition.get_acquisition_function(agent, acq_func_identifier)
test_acqf = acq_func(test_inputs).detach().numpy()

agent.acq_axes[iacq_func].plot(test_inputs.squeeze(), test_acqf, lw=lw, color=color)
agent.acq_axes[iacq_func].plot(test_inputs.squeeze(-2), test_acqf, lw=lw, color=color)

agent.acq_axes[iacq_func].set_xlim(*x_dof.limits)
agent.acq_axes[iacq_func].set_xlabel(x_dof.label)
Expand Down Expand Up @@ -269,7 +270,7 @@ def _plot_valid_one_dof(agent, size=16, lw=1e0):
constraint = agent.constraint(test_inputs)[..., 0]

agent.valid_ax.scatter(x_values, agent.all_objectives_valid, s=size)
agent.valid_ax.plot(test_inputs.squeeze(), constraint, lw=lw)
agent.valid_ax.plot(test_inputs.squeeze(-2), constraint, lw=lw)
agent.valid_ax.set_xlim(*x_dof.limits)


Expand Down Expand Up @@ -316,45 +317,51 @@ def _plot_valid_many_dofs(agent, axes=[0, 1], shading="nearest", cmap=DEFAULT_CO
ax.set_xlim(*x_dof.limits)
ax.set_ylim(*y_dof.limits)

# data_ax = agent.valid_axes[0].scatter(
# *agent.acquisition_inputs.values.T[:2],
# c=agent.all_objectives_valid,
# s=size,
# vmin=0,
# vmax=1,
# cmap=cmap,
# )

# x = agent.test_inputs_grid().squeeze() if gridded else agent.test_inputs(n=MAX_TEST_INPUTS)
# *input_shape, input_dim = x.shape
# constraint = agent.classifier.probabilities(x.reshape(-1, 1, input_dim))[..., -1].reshape(input_shape)

# if gridded:
# agent.valid_axes[1].pcolormesh(
# x[..., 0].detach().numpy(),
# x[..., 1].detach().numpy(),
# constraint.detach().numpy(),
# shading=shading,
# cmap=cmap,
# vmin=0,
# vmax=1,
# )

# # agent.acq_fig.colorbar(obj_ax, ax=agent.valid_axes[iacq_func], location="bottom", aspect=32, shrink=0.8)

# else:
# # agent.valid_axes.set_title(acq_func_meta["name"])
# agent.valid_axes[1].scatter(
# x.detach().numpy()[..., axes[0]],
# x.detach().numpy()[..., axes[1]],
# c=constraint.detach().numpy(),
# )

# agent.valid_fig.colorbar(data_ax, ax=agent.valid_axes[:2], location="bottom", aspect=32, shrink=0.8)

# for ax in agent.valid_axes.ravel():
# ax.set_xlim(*agent.dofs.subset(active=True, read_only=False)[axes[0]].limits)
# ax.set_ylim(*agent.dofs.subset(active=True, read_only=False)[axes[1]].limits)

def _plot_history(agent, x_key="index", show_all_objs=False):
x = getattr(agent.table, x_key).values

num_obj_plots = 1
if show_all_objs:
num_obj_plots = agent.n_objs + 1

agent.n_objs + 1 if agent.n_objs > 1 else 1

hist_fig, hist_axes = plt.subplots(
num_obj_plots, 1, figsize=(6, 4 * num_obj_plots), sharex=True, constrained_layout=True, dpi=200
)
hist_axes = np.atleast_1d(hist_axes)

unique_strategies, acq_func_index, acq_func_inverse = np.unique(
agent.table.acq_func, return_index=True, return_inverse=True
)

sample_colors = np.array(DEFAULT_COLOR_LIST)[acq_func_inverse]

if show_all_objs:
for obj_index, obj in enumerate(agent.objectives):
y = agent.table.loc[:, f"{obj.key}_fitness"].values
hist_axes[obj_index].scatter(x, y, c=sample_colors)
hist_axes[obj_index].plot(x, y, lw=5e-1, c="k")
hist_axes[obj_index].set_ylabel(obj.key)

y = agent.scalarized_objectives

cummax_y = np.array([np.nanmax(y[: i + 1]) for i in range(len(y))])

hist_axes[-1].scatter(x, y, c=sample_colors)
hist_axes[-1].plot(x, y, lw=5e-1, c="k")

hist_axes[-1].plot(x, cummax_y, lw=5e-1, c="k", ls=":")

hist_axes[-1].set_ylabel("total_fitness")
hist_axes[-1].set_xlabel(x_key)

handles = []
for i_acq_func, acq_func in enumerate(unique_strategies):
handles.append(Patch(color=DEFAULT_COLOR_LIST[i_acq_func], label=acq_func))
legend = hist_axes[0].legend(handles=handles, fontsize=8)
legend.set_title("acquisition function")


def inspect_beam(agent, index, border=None):
Expand Down
25 changes: 0 additions & 25 deletions bloptools/tasks.py

This file was deleted.

4 changes: 2 additions & 2 deletions bloptools/tests/test_passive_dofs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_passive_dofs(RE, db):
DOF(name="x1", limits=(-5.0, 5.0)),
DOF(name="x2", limits=(-5.0, 5.0)),
DOF(BrownianMotion(name="brownian1"), read_only=True),
DOF(BrownianMotion(name="brownian1"), read_only=True),
DOF(BrownianMotion(name="brownian2"), read_only=True),
]

objectives = [
Expand All @@ -28,6 +28,6 @@ def test_passive_dofs(RE, db):

RE(agent.learn("qr", n=32))

agent.plot_tasks()
agent.plot_objectives()
agent.plot_acquisition()
agent.plot_validity()
2 changes: 1 addition & 1 deletion bloptools/tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def test_plots(RE, agent):
RE(agent.learn("qr", n=32))

agent.plot_tasks()
agent.plot_objectives()
agent.plot_acquisition()
agent.plot_validity()
agent.plot_history()
Loading

0 comments on commit 39a579f

Please sign in to comment.