Skip to content

Commit

Permalink
Merge branch 'fix-multi-acqfs' into targeting
Browse files Browse the repository at this point in the history
  • Loading branch information
Thomas Morris committed Oct 28, 2023
2 parents 02468ae + f0bc991 commit c43b52e
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 92 deletions.
5 changes: 3 additions & 2 deletions bloptools/bayesian/acquisition/analytic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ def __init__(self, constraint, *args, **kwargs):

def forward(self, x):
*input_shape, _, _ = x.shape

mean, sigma = self._mean_and_sigma(x)
transformed_posterior = self.posterior_transform(self.model.posterior(x))
mean = transformed_posterior.mean.reshape(input_shape)
sigma = transformed_posterior.variance.sqrt().reshape(input_shape)

p_eff = (
0.5
Expand Down
14 changes: 3 additions & 11 deletions bloptools/bayesian/acquisition/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,16 @@ def __init__(self, constraint, beta=4, *args, **kwargs):
def forward(self, x):
*input_shape, _, _ = x.shape

posterior = self.model.posterior(x)
mean, sigma = posterior.mean, posterior.variance.sqrt()
transformed_posterior = self.posterior_transform(self.model.posterior(x))
mean = transformed_posterior.mean.reshape(input_shape)
sigma = transformed_posterior.variance.sqrt().reshape(input_shape)

p_eff = (
0.5
* (1 + torch.special.erf(self.beta.sqrt() / math.sqrt(2)))
* torch.clamp(self.constraint(x).reshape(input_shape), min=1e-6)
)

return mean.reshape(*input_shape) + sigma.reshape(*input_shape) * np.sqrt(2) * torch.special.erfinv(
2 * p_eff.reshape(*input_shape) - 1
)

posterior = self.model.posterior(x)
mean, sigma = posterior.mean, posterior.variance.sqrt()

p_eff = 0.5 * (1 + torch.special.erf(self.beta.sqrt() / math.sqrt(2))) * torch.clamp(self.constraint(x), min=1e-6)

return mean + sigma * np.sqrt(2) * torch.special.erfinv(2 * p_eff - 1)


Expand Down
2 changes: 1 addition & 1 deletion bloptools/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def agent(db):


@pytest.fixture(scope="function")
def multitask_agent(db):
def multi_agent(db):
"""
A simple agent minimizing two Styblinski-Tang functions
"""
Expand Down
20 changes: 16 additions & 4 deletions bloptools/tests/test_acq_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,24 @@


@pytest.mark.parametrize("acq_func", ["ei", "pi", "em", "ucb"])
def test_analytic_acq_funcs_single_task(agent, RE, acq_func):
RE(agent.learn("qr", n=32))
def test_analytic_acq_funcs_single_objective(agent, RE, acq_func):
RE(agent.learn("qr", n=16))
RE(agent.learn(acq_func, n=1))


@pytest.mark.parametrize("acq_func", ["qei", "qpi", "qem", "qucb"])
def test_monte_carlo_acq_funcs_single_task(agent, RE, acq_func):
RE(agent.learn("qr", n=32))
def test_monte_carlo_acq_funcs_single_objective(agent, RE, acq_func):
RE(agent.learn("qr", n=16))
RE(agent.learn(acq_func, n=4))


@pytest.mark.parametrize("acq_func", ["ei", "pi", "em", "ucb"])
def test_analytic_acq_funcs_multi_objective(multi_agent, RE, acq_func):
RE(multi_agent.learn("qr", n=16))
RE(multi_agent.learn(acq_func, n=1))


@pytest.mark.parametrize("acq_func", ["qei", "qpi", "qem", "qucb"])
def test_monte_carlo_acq_funcs_multi_objective(multi_agent, RE, acq_func):
RE(multi_agent.learn("qr", n=16))
RE(multi_agent.learn(acq_func, n=4))
7 changes: 7 additions & 0 deletions bloptools/utils/prepare_re_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def register_handlers(db, handlers):
help="Type of RE environment.",
)

parser.add_argument(
"-f",
"--file",
dest="file",
default="",
)

args = parser.parse_args()
kwargs_re = dict(db_type=args.db_type, root_dir=args.root_dir)
ret = re_env(**kwargs_re)
Expand Down
76 changes: 2 additions & 74 deletions docs/source/tutorials/himmelblau.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,11 @@
"metadata": {},
"outputs": [],
"source": [
"from bloptools.bayesian import DOF, BrownianMotion\n",
"from bloptools.bayesian import DOF\n",
"\n",
"dofs = [\n",
" DOF(name=\"x1\", limits=(-6, 6)),\n",
" DOF(name=\"x2\", limits=(-6, 6)),\n",
" DOF(BrownianMotion(name=\"brownian1\"), read_only=True),\n",
" DOF(BrownianMotion(name=\"brownian2\"), read_only=True),\n",
"]"
]
},
Expand Down Expand Up @@ -153,16 +151,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e964d5a5-2a4a-4403-8c06-4ad17b00cecf",
"metadata": {},
"outputs": [],
"source": [
"agent.test_inputs_grid().shape"
]
},
{
"cell_type": "markdown",
"id": "27685849",
Expand Down Expand Up @@ -219,40 +207,6 @@
"agent.plot_acquisition(acq_funcs=[\"qei\", \"pi\", \"qucb\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c0b42dc-2df3-4ba6-b02f-569dab48db80",
"metadata": {},
"outputs": [],
"source": [
"agent.dofs.limits"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "62b24d9b-740f-45a6-9617-47796c260273",
"metadata": {},
"outputs": [],
"source": [
"self = agent\n",
"import torch\n",
"\n",
"acq_func_lower_bounds = [dof.lower_limit if not dof.read_only else dof.readback for dof in self.dofs]\n",
"acq_func_upper_bounds = [dof.upper_limit if not dof.read_only else dof.readback for dof in self.dofs]\n",
"\n",
"torch.tensor(np.vstack([acq_func_lower_bounds, acq_func_upper_bounds]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "217158c0-aa65-409c-a7ab-63bb924723ad",
"metadata": {},
"outputs": [],
"source": []
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -262,24 +216,6 @@
"To decide where to go, the agent will find the inputs that maximize a given acquisition function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16ec3c97-211b-49df-9e45-fcdd61ae98eb",
"metadata": {},
"outputs": [],
"source": [
"agent.acquisition_function_bounds"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a066da53-0cdc-429b-a588-ce22b4a599b5",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -353,16 +289,8 @@
"outputs": [],
"source": [
"agent.plot_objectives()\n",
"# print(agent.best_inputs)"
"print(agent.best_inputs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b6453b87-f864-40af-ba70-9a42960f54b9",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit c43b52e

Please sign in to comment.