Skip to content

Commit

Permalink
add discrete DOFs
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Apr 3, 2024
1 parent 372292a commit bf32d09
Show file tree
Hide file tree
Showing 14 changed files with 391 additions and 103 deletions.
6 changes: 3 additions & 3 deletions docs/source/agent.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ The blop ``Agent`` takes care of the entire optimization loop, from data acquisi
from blop import DOF, Objective, Agent
dofs = [
DOF(name="x1", description="the first DOF", search_bounds=(-10, 10))
DOF(name="x2", description="another DOF", search_bounds=(-5, 5))
DOF(name="x3", description="ayet nother DOF", search_bounds=(0, 1))
DOF(name="x1", description="the first DOF", search_domain=(-10, 10))
DOF(name="x2", description="another DOF", search_domain=(-5, 5))
DOF(name="x3", description="ayet nother DOF", search_domain=(0, 1))
]
objective = [
Expand Down
4 changes: 2 additions & 2 deletions docs/source/dofs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ A degree of freedom is a variable that affects our optimization objective. We ca
from blop import DOF
dof = DOF(name="x1", description="my first DOF", search_bounds=(lower, upper))
dof = DOF(name="x1", description="my first DOF", search_domain=(lower, upper))
This will instantiate a bunch of stuff under the hood, so that our agent knows how to move things and where to search.
Typically, this will correspond to a real, physical device available in Python. In that case, we can pass the DOF an ophyd device in place of a name
Expand All @@ -16,7 +16,7 @@ Typically, this will correspond to a real, physical device available in Python.
from blop import DOF
dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_bounds=(lower, upper))
dof = DOF(device=my_ophyd_device, description="a real piece of hardware", search_domain=(lower, upper))
In this case, the agent will control the device as it sees fit, moving it between the search bounds.

Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/himmelblau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
"from blop import DOF\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x1\", search_domain=(-6, 6)),\n",
" DOF(name=\"x2\", search_domain=(-6, 6)),\n",
"]"
]
},
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/hyperparameters.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@
"from blop import DOF, Objective, Agent\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x2\", search_bounds=(-6, 6)),\n",
" DOF(name=\"x1\", search_domain=(-6, 6)),\n",
" DOF(name=\"x2\", search_domain=(-6, 6)),\n",
"]\n",
"\n",
"objectives = [\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorials/passive-dofs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
"\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_bounds=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_bounds=(-5.0, 5.0), active=False),\n",
" DOF(name=\"x1\", search_domain=(-5.0, 5.0)),\n",
" DOF(name=\"x2\", search_domain=(-5.0, 5.0)),\n",
" DOF(name=\"x3\", search_domain=(-5.0, 5.0), active=False),\n",
" DOF(device=BrownianMotion(name=\"brownian1\"), read_only=True),\n",
" DOF(device=BrownianMotion(name=\"brownian2\"), read_only=True, active=False),\n",
"]\n",
Expand Down
194 changes: 194 additions & 0 deletions scripts/gui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import asyncio

import databroker
import matplotlib as mpl
import numpy as np
from bluesky.callbacks import best_effort
from bluesky.run_engine import RunEngine
from databroker import Broker
from nicegui import ui

from blop import DOF, Agent, Objective
from blop.utils import functions

# MongoDB backend:
db = Broker.named("temp") # mongodb backend
try:
databroker.assets.utils.install_sentinels(db.reg.config, version=1)
except Exception:
pass

loop = asyncio.new_event_loop()
loop.set_debug(True)
RE = RunEngine({}, loop=loop)
RE.subscribe(db.insert)

bec = best_effort.BestEffortCallback()
RE.subscribe(bec)

bec.disable_baseline()
bec.disable_heading()
bec.disable_table()
bec.disable_plots()


dofs = [
DOF(name="x1", description="x1", search_domain=(-5.0, 5.0)),
DOF(name="x2", description="x2", search_domain=(-5.0, 5.0)),
]

objectives = [Objective(name="himmelblau", target="min")]

agent = Agent(
dofs=dofs,
objectives=objectives,
digestion=functions.himmelblau_digestion,
db=db,
verbose=True,
tolerate_acquisition_errors=False,
)

agent.acqf_index = 0

agent.acqf_number = 2


with ui.pyplot(figsize=(10, 4), dpi=160) as obj_plt:
extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain]

ax1 = obj_plt.fig.add_subplot(131)
ax1.set_title("Samples")
im1 = ax1.scatter([], [], cmap="magma")

ax2 = obj_plt.fig.add_subplot(132, sharex=ax1, sharey=ax1)
ax2.set_title("Posterior mean")
im2 = ax2.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma")

ax3 = obj_plt.fig.add_subplot(133, sharex=ax1, sharey=ax1)
ax3.set_title("Posterior error")
im3 = ax3.imshow(np.random.standard_normal(size=(32, 32)), extent=extent, cmap="magma")

data_cbar = obj_plt.fig.colorbar(mappable=im1, ax=[ax1, ax2], location="bottom", aspect=32)
err_cbar = obj_plt.fig.colorbar(mappable=im3, ax=[ax3], location="bottom", aspect=16)

for ax in [ax1, ax2, ax3]:
ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)


acqf_configs = {
0: {"name": "qr", "long_name": r"quasi-random sampling"},
1: {"name": "qei", "long_name": r"$q$-expected improvement"},
2: {"name": "qpi", "long_name": r"$q$-probability of improvement"},
3: {"name": "qucb", "long_name": r"$q$-upper confidence bound"},
}

with ui.pyplot(figsize=(10, 3), dpi=160) as acq_plt:
extent = [*agent.dofs[0].search_domain, *agent.dofs[1].search_domain]

acqf_plt_objs = {}

for iax, config in acqf_configs.items():
if iax == 0:
continue

acqf = config["name"]

acqf_plt_objs[acqf] = {}

acqf_plt_objs[acqf]["ax"] = ax = acq_plt.fig.add_subplot(1, len(acqf_configs) - 1, iax)

ax.set_title(config["long_name"])
acqf_plt_objs[acqf]["im"] = ax.imshow([[]], extent=extent, cmap="gray_r")
acqf_plt_objs[acqf]["hist"] = ax.scatter([], [])
acqf_plt_objs[acqf]["best"] = ax.scatter([], [])

ax.set_xlabel(agent.dofs[0].label)
ax.set_ylabel(agent.dofs[1].label)


acqf_button_options = {index: config["name"] for index, config in acqf_configs.items()}

v = ui.checkbox("visible", value=True)
with ui.column().bind_visibility_from(v, "value"):
ui.toggle(acqf_button_options).bind_value(agent, "acqf_index")
ui.number().bind_value(agent, "acqf_number")


def reset():
agent.reset()

print(agent.table)


def learn():
acqf_config = acqf_configs[agent.acqf_index]

acqf = acqf_config["name"]

n = int(agent.acqf_number) if acqf != "qr" else 16

ui.notify(f"sampling {n} points with acquisition function \"{acqf_config['long_name']}\"")

RE(agent.learn(acqf, n=n))

with obj_plt:
obj = agent.objectives[0]

x_samples = agent.train_inputs().detach().numpy()
y_samples = agent.train_targets(obj.name).detach().numpy()[..., 0]

x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
p = obj.model.posterior(x)

m = p.mean.squeeze(-1, -2).detach().numpy()
e = p.variance.sqrt().squeeze(-1, -2).detach().numpy()

im1.set_offsets(x_samples)
im1.set_array(y_samples)
im1.set_cmap("magma")

im2.set_data(m.T[::-1])
im3.set_data(e.T[::-1])

obj_norm = mpl.colors.Normalize(vmin=np.nanmin(y_samples), vmax=np.nanmax(y_samples))
err_norm = mpl.colors.LogNorm(vmin=np.nanmin(e), vmax=np.nanmax(e))

im1.set_norm(obj_norm)
im2.set_norm(obj_norm)
im3.set_norm(err_norm)

for ax in [ax1, ax2, ax3]:
ax.set_xlim(*agent.dofs[0].search_domain)
ax.set_ylim(*agent.dofs[1].search_domain)

with acq_plt:
x = agent.sample(method="grid", n=20000) # (n, n, 1, d)
x_samples = agent.train_inputs().detach().numpy()

for acqf in acqf_plt_objs.keys():
ax = acqf_plt_objs[acqf]["ax"]

acqf_obj = getattr(agent, acqf)(x).detach().numpy()

acqf_norm = mpl.colors.Normalize(vmin=np.nanmin(acqf_obj), vmax=np.nanmax(acqf_obj))
acqf_plt_objs[acqf]["im"].set_data(acqf_obj.T[::-1])
acqf_plt_objs[acqf]["im"].set_norm(acqf_norm)

res = agent.ask(acqf, n=int(agent.acqf_number))

acqf_plt_objs[acqf]["hist"].remove()
acqf_plt_objs[acqf]["hist"] = ax.scatter(*x_samples.T, ec="b", fc="none", marker="o")

acqf_plt_objs[acqf]["best"].remove()
acqf_plt_objs[acqf]["best"] = ax.scatter(*res["points"].T, c="r", marker="x", s=64)

ax.set_xlim(*agent.dofs[0].search_domain)
ax.set_ylim(*agent.dofs[1].search_domain)


ui.button("Learn", on_click=learn)

ui.button("Reset", on_click=reset)

ui.run(port=8004)
4 changes: 2 additions & 2 deletions src/blop/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = "0.6.2.dev0"
__version_tuple__ = version_tuple = (0, 6, 2, "dev0")
__version__ = version = "0.5.1.dev48"
__version_tuple__ = version_tuple = (0, 5, 1, "dev48")
Loading

0 comments on commit bf32d09

Please sign in to comment.