Commits (29)
9676189
allow encoder to return running configs
dengdifan Oct 23, 2024
6b86808
add options for batch sampling
dengdifan Oct 23, 2024
06719d2
maint constant liar with nan values
dengdifan Oct 29, 2024
8b748db
Merge branch 'development' into batch_sampling_improvement
dengdifan Dec 2, 2024
2487bc3
add docs
dengdifan Dec 2, 2024
d3f4f11
tests for config selectors
dengdifan Dec 2, 2024
3c2196a
solve conflict
dengdifan Dec 19, 2024
5f3ae8e
maint doc
dengdifan Jan 8, 2025
a1f7c32
style(config_selector)
benjamc Jan 13, 2025
f68fee5
style(abstract_encoder)
benjamc Jan 13, 2025
691b1a9
Update CHANGELOG.md
benjamc Jan 13, 2025
eafef88
Merge branch 'development' into batch_sampling_improvement
benjamc Jan 13, 2025
3cf1748
refactor(config_selector): pass all args in the facades
benjamc Jan 13, 2025
62a9588
refactor(abstract_facade): fix default, add warning in docstring
benjamc Jan 13, 2025
522c671
fix(fantasize): check whether model has been trained
benjamc Jan 13, 2025
99443b4
feat(fantasize_example): add
benjamc Jan 13, 2025
6658b88
fix(config_selector): properly check whether model is trained
benjamc Jan 13, 2025
5416a8d
Merge branch 'development' into batch_sampling_improvement
benjamc Feb 27, 2025
3ff68c3
Merge remote-tracking branch 'origin/batch_sampling_improvement' into…
daphne12345 Apr 9, 2025
198d978
batch expected improvement
daphne12345 Jun 4, 2025
dce1f54
created an example
daphne12345 Jun 4, 2025
e2c81f1
adjusted the config selector to work with q_ei
daphne12345 Jun 4, 2025
b01a1f8
Merge branch 'development' into feature/batch_bo_issue_1229
daphne12345 Jun 12, 2025
07487d0
added kwargs to the maximize function
daphne12345 Jun 12, 2025
f284798
removed qei from batch_sampling again
daphne12345 Jun 12, 2025
6d4a808
Sampling from surrogate model
daphne12345 Jun 12, 2025
4b8e801
Merge branch 'feature/batch_bo_issue_1229' of https://github.com/auto…
daphne12345 Jun 12, 2025
e1c13eb
bug fix
daphne12345 Jun 12, 2025
14eb30b
Added tests for qei
daphne12345 Jun 13, 2025
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -25,6 +25,9 @@

# 2.3.0

## Features
- Improved batch sampling: Fantasize points in batch/parallel mode (#1154).

## Documentation
- Update windows install guide (#952)
- Correct intensifier for Algorithm Configuration Facade (#1162, #1165)
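For quick reference, the new feature is enabled through the config selector, mirroring the example files added below. A minimal sketch with a toy quadratic target; everything else follows the PR's own examples:

from ConfigSpace import Configuration, ConfigurationSpace, Float

from smac import BlackBoxFacade, Scenario


def train(config: Configuration, seed: int = 0) -> float:
    return (config["x"] - 1.0) ** 2  # toy quadratic with its minimum at x = 1


cs = ConfigurationSpace(seed=0)
cs.add([Float("x", (-5.0, 5.0), default=0)])

scenario = Scenario(configspace=cs, n_trials=50, n_workers=4, seed=0)
smac = BlackBoxFacade(
    scenario=scenario,
    target_function=train,
    config_selector=BlackBoxFacade.get_config_selector(
        scenario=scenario,
        batch_sampling_estimation_strategy="kriging_believer",
    ),
    overwrite=True,
)
incumbent = smac.optimize()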
112 changes: 112 additions & 0 deletions examples/1_basics/7_0_parallelization_fantasize.py
@@ -0,0 +1,112 @@
"""Example of using SMAC with parallelization and fantasization vs. no estimation for pending evaluations.

This example will take some time because the target function is artificially slowed down to demonstrate the effect of
fantasization. The example will plot the incumbent found by SMAC with and without fantasization.
"""
from __future__ import annotations

import numpy as np
from ConfigSpace import Configuration, ConfigurationSpace, Float

from matplotlib import pyplot as plt

from smac import BlackBoxFacade, Scenario
from smac.facade import AbstractFacade

import time

def plot_trajectory(facades: list[AbstractFacade], names: list[str]) -> None:
# Plot incumbent
cmap = plt.get_cmap("tab10")

fig = plt.figure()
axes = fig.subplots(1, 2)

for ax_i, x_axis in zip(axes, ["walltime", "trial"]):
for i, facade in enumerate(facades):
X, Y = [], []
for item in facade.intensifier.trajectory:
# Single-objective optimization
assert len(item.config_ids) == 1
assert len(item.costs) == 1

y = item.costs[0]
x = getattr(item, x_axis)

X.append(x)
Y.append(y)

ax_i.plot(X, Y, label=names[i], color=cmap(i))
ax_i.scatter(X, Y, marker="x", color=cmap(i))
ax_i.set_xlabel(x_axis)
ax_i.set_ylabel(facades[0].scenario.objectives)
ax_i.set_yscale("log")
ax_i.legend()

plt.show()

class Branin:
@property
def configspace(self) -> ConfigurationSpace:
# Build Configuration Space which defines all parameters and their ranges
cs = ConfigurationSpace(seed=0)

# First we create our hyperparameters
x1 = Float("x1", (-5, 10), default=0)
x2 = Float("x2", (0, 15), default=0)

# Add hyperparameters and conditions to our configspace
cs.add([x1, x2])

return cs

def train(self, config: Configuration, seed: int) -> float:
x1 = config["x1"]
x2 = config["x2"]
a = 1.0
b = 5.1 / (4.0 * np.pi**2)
c = 5.0 / np.pi
r = 6.0
s = 10.0
t = 1.0 / (8.0 * np.pi)

cost = a * (x2 - b * x1**2 + c * x1 - r) ** 2 + s * (1 - t) * np.cos(x1) + s
regret = cost - 0.397887

# Artificially slow down the target function so that several evaluations are
# pending at once; this is the situation the fantasization strategies estimate.
time.sleep(10)

return regret

if __name__ == "__main__":
seed = 345455
scenario = Scenario(n_trials=100, configspace=Branin().configspace, n_workers=4, seed=seed)
facade = BlackBoxFacade

smac_noestimation = facade(
scenario=scenario,
target_function=Branin().train,
overwrite=True,
)
smac_fantasize = facade(
scenario=scenario,
target_function=Branin().train,
config_selector=facade.get_config_selector(
scenario=scenario,
batch_sampling_estimation_strategy="kriging_believer"
),
overwrite=True,
logging_level=0
)

incumbent_noestimation = smac_noestimation.optimize()
incumbent_fantasize = smac_fantasize.optimize()

plot_trajectory(facades=[
smac_noestimation,
smac_fantasize,
], names=["No Estimation", "Fantasize"])

del smac_noestimation
del smac_fantasize
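For readers unfamiliar with the strategy name: "kriging believer" fantasizes each pending configuration with the surrogate's posterior mean, so the model can be refit before the next batch member is proposed. A minimal numpy sketch of the idea (illustrative only; `model.predict` is a stand-in interface here, not SMAC's internal API):

import numpy as np

def kriging_believer_augment(model, X_obs, y_obs, X_pending):
    """Augment the training data with fantasized costs for pending points."""
    # "Believe" the posterior mean for every pending point ...
    mean, _ = model.predict(X_pending)
    # ... and append it to the observed data, so a refit model keeps the
    # acquisition function from re-proposing the same region.
    X_aug = np.vstack([X_obs, X_pending])
    y_aug = np.concatenate([y_obs, np.asarray(mean).ravel()])
    return X_aug, y_aug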
126 changes: 126 additions & 0 deletions examples/1_basics/7_1_parallelization_q_ei.py
@@ -0,0 +1,126 @@
"""Example of using SMAC with parallelization and fantasization vs. no estimation for pending evaluations.

This example will take some time because the target function is artificially slowed down to demonstrate the effect of
fantasization. The example will plot the incumbent found by SMAC with and without fantasization.
"""
from __future__ import annotations

import numpy as np
from ConfigSpace import Configuration, ConfigurationSpace, Float

from matplotlib import pyplot as plt

from smac import BlackBoxFacade, Scenario
from smac.facade import AbstractFacade
from smac.acquisition.function.expected_improvement import QExpectedImprovement, EI
from smac.acquisition.maximizer.random_search import RandomSearch

import time

def plot_trajectory(facades: list[AbstractFacade], names: list[str]) -> None:
# Plot incumbent
cmap = plt.get_cmap("tab10")

fig = plt.figure()
axes = fig.subplots(1, 2)

for ax_i, x_axis in zip(axes, ["walltime", "trial"]):
for i, facade in enumerate(facades):
X, Y = [], []
for item in facade.intensifier.trajectory:
# Single-objective optimization
assert len(item.config_ids) == 1
assert len(item.costs) == 1

y = item.costs[0]
x = getattr(item, x_axis)

X.append(x)
Y.append(y)

ax_i.plot(X, Y, label=names[i], color=cmap(i))
ax_i.scatter(X, Y, marker="x", color=cmap(i))
ax_i.set_xlabel(x_axis)
ax_i.set_ylabel(facades[0].scenario.objectives)
ax_i.set_yscale("log")
ax_i.legend()

plt.show()

class Branin:
@property
def configspace(self) -> ConfigurationSpace:
# Build Configuration Space which defines all parameters and their ranges
cs = ConfigurationSpace(seed=0)

# First we create our hyperparameters
x1 = Float("x1", (-5, 10), default=0)
x2 = Float("x2", (0, 15), default=0)

# Add hyperparameters and conditions to our configspace
cs.add([x1, x2])

return cs

def train(self, config: Configuration, seed: int) -> float:
x1 = config["x1"]
x2 = config["x2"]
a = 1.0
b = 5.1 / (4.0 * np.pi**2)
c = 5.0 / np.pi
r = 6.0
s = 10.0
t = 1.0 / (8.0 * np.pi)

cost = a * (x2 - b * x1**2 + c * x1 - r) ** 2 + s * (1 - t) * np.cos(x1) + s
regret = cost - 0.397887

# Artificially slow down the target function so that several evaluations are
# pending at once; this is the situation batch sampling has to handle.
time.sleep(10)

return regret

if __name__ == "__main__":
seed = 345455
scenario = Scenario(n_trials=100, configspace=Branin().configspace, n_workers=4, seed=seed)
facade = BlackBoxFacade

acq_function = EI()
acq_maximizer = RandomSearch(scenario.configspace, acq_function)

smac_noestimation = facade(
scenario=scenario,
target_function=Branin().train,
overwrite=True,
acquisition_function=acq_function,
acquisition_maximizer=acq_maximizer
)

acq_function_qei = QExpectedImprovement()
acq_maximizer_qei = RandomSearch(scenario.configspace, acquisition_function=acq_function_qei)


smac_q_ei = facade(
scenario=scenario,
target_function=Branin().train,
config_selector=facade.get_config_selector(
scenario=scenario,
batch_sampling_estimation_strategy="q_ei"
),
acquisition_function=acq_function_qei,
acquisition_maximizer=acq_maximizer_qei,
overwrite=True,
logging_level=0
)

incumbent_noestimation = smac_noestimation.optimize()
incumbent_q_ei = smac_q_ei.optimize()

plot_trajectory(facades=[
smac_noestimation,
smac_q_ei,
], names=["No Estimation", "QEI"])

del smac_noestimation
del smac_q_ei
96 changes: 96 additions & 0 deletions smac/acquisition/function/expected_improvement.py
@@ -286,3 +286,99 @@ def calculate_f() -> np.ndarray:
raise ValueError("Expected Improvement per Second is smaller than 0 " "for at least one sample.")

return f.reshape((-1, 1))


class QExpectedImprovement(EI):
r"""
Monte Carlo approximation of q-Expected Improvement.
Approximates joint distribution with independent normals.

:math:`EI(X) := \mathbb{E}\left[ \max\{0, f(\mathbf{X^+}) - f_{t+1}(\mathbf{X}) - \xi \} \right]`,
with :math:`f(X^+)` as the best location.

Reference for q-EI


Parameters
----------
xi : float, defaults to 0.0
Controls the balance between exploration and exploitation of the
acquisition function.
log : bool, defaults to False
Whether the function values are in log-space.


Attributes
----------
_xi : float
Exploration-exloitation trade-off parameter.
_log: bool
Function values in log-space or not.
_eta : float
Current incumbent function value (best value observed so far).

"""

def __init__(self, xi: float = 0.0, n_samples: int = 128) -> None:
super(QExpectedImprovement, self).__init__(xi=xi)
self.n_samples = n_samples

@property
def name(self) -> str: # noqa: D102
return "Batch Expected Improvement"

def _compute(self, X: np.ndarray) -> np.ndarray:
"""
Compute q-EI acquisition value using Monte Carlo approximation.

Parameters
----------
X : np.ndarray [N, D]
The batch of input points to evaluate.

Returns
-------
np.ndarray [1, 1]
The q-EI value for the batch as a whole.
"""
assert self._model is not None
assert self._xi is not None

if self._eta is None:
raise ValueError(
"No current best specified. Call update(eta=<float>) to inform the acquisition function "
"about the current best value."
)

if len(X.shape) == 1:
X = X[np.newaxis, :]

m, var = self._model.predict_marginalized(X)
std = np.sqrt(var)

if np.any(std == 0.0):
logger.warning("Predicted std is 0.0 for at least one sample.")
std_copy = np.copy(std)
std[std_copy == 0.0] = 1.0 # prevent division by zero

# Monte Carlo sampling from independent normal marginals (one column per batch point)
normal_samples = np.random.normal(loc=m.T, scale=std.T, size=(self.n_samples, X.shape[0]))

if not self._log:
f_samples = normal_samples # in original (normal) space
f_min_sample = np.min(f_samples, axis=1)
improvement = np.maximum(self._eta - self._xi - f_min_sample, 0.0)
Contributor:
Not quite sure here, but if the surrogate model is a GP, we can directly sample from the joint posterior distribution, as shown in BoTorch: https://botorch.readthedocs.io/en/stable/acquisition.html#botorch.acquisition.monte_carlo.qExpectedImprovement

Contributor (author):
I replaced it with the surrogate model, if it is a GP.
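For context, a minimal numpy sketch of the distinction discussed in this thread, with hypothetical numbers (SMAC's `predict_marginalized` returns only marginal variances, so the full covariance here is an assumption):

import numpy as np

rng = np.random.default_rng(0)
mean = np.array([0.2, 0.3])    # posterior means for a batch of two points
cov = np.array([[0.04, 0.03],  # full posterior covariance: the points are
                [0.03, 0.04]]) # positively correlated
eta, xi, n_samples = 0.1, 0.0, 4096

# Joint sampling (the BoTorch-style approach) respects the correlation ...
joint = rng.multivariate_normal(mean, cov, size=n_samples)
# ... while the independent approximation samples each marginal on its own.
indep = rng.normal(mean, np.sqrt(np.diag(cov)), size=(n_samples, 2))

qei_joint = np.mean(np.maximum(eta - xi - joint.min(axis=1), 0.0))
qei_indep = np.mean(np.maximum(eta - xi - indep.min(axis=1), 0.0))
print(qei_joint, qei_indep)  # for correlated points, the independent
                             # approximation tends to overestimate q-EI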

else:
# In log-space, the *actual values* are exp(samples)
f_samples = np.exp(normal_samples)
f_min_sample = np.min(f_samples, axis=1)

# eta is already in log-space, so we compare to exp(eta - xi)
improvement = np.maximum(np.exp(self._eta - self._xi) - f_min_sample, 0.0)

qei = np.mean(improvement)

if qei < 0:
raise ValueError("q-Expected Improvement is smaller than 0. Should not happen.")

return np.array([[qei]])
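As a usage-style sanity check, the estimator above reduces to a few lines of numpy for the non-log branch (hypothetical numbers): the whole batch receives one scalar score, matching the `[1, 1]`-shaped array `_compute` returns.

import numpy as np

m = np.array([[0.5], [0.7], [0.4]])    # predicted means for a batch of three
std = np.array([[0.1], [0.2], [0.1]])  # predicted marginal stddevs
eta, xi = 0.45, 0.0                    # incumbent cost, exploration offset

samples = np.random.normal(loc=m.T, scale=std.T, size=(128, 3))
f_min = samples.min(axis=1)            # best sampled value within the batch
qei = np.mean(np.maximum(eta - xi - f_min, 0.0))
print(round(qei, 3))                   # one scalar for the whole batch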
@@ -74,6 +74,7 @@ def maximize(
previous_configs: list[Configuration],
n_points: int | None = None,
random_design: AbstractRandomDesign | None = None,
**kwargs: Any,
) -> Iterator[Configuration]:
"""Maximize acquisition function using `_maximize`, implemented by a subclass.
