Skip to content

Commit

Permalink
Merge pull request #460 from Thomas-Christie/expected-improvement
Browse files Browse the repository at this point in the history
Add expected improvement utility function
  • Loading branch information
thomaspinder authored Jul 16, 2024
2 parents fc45b9f + 8fb0f9a commit b69be96
Show file tree
Hide file tree
Showing 25 changed files with 385 additions and 206 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ repos:
types: [python]
- id: ruff
name: ruff
entry: ruff
entry: ruff check
args: ["--exit-non-zero-on-fix"]
require_serial: true
language: system
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ help: ## Display this help

##@ Formatting
black: ## Format code in-place using black.
black ${PKGROOT}/ tests/ -l 79 .
black ${PKGROOT}/ tests/ -l 88 .

isort: ## Format imports in-place using isort.
isort ${PKGROOT}/ tests/

format: ## Code styling - black, isort
black ${PKGROOT}/ tests/ -l 100 .
black ${PKGROOT}/ tests/ -l 88 .
@printf "\033[1;34mBlack passes!\033[0m\n\n"
isort ${PKGROOT}/ tests/
@printf "\033[1;34misort passes!\033[0m\n\n"
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/bayesian_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def obtain_log_regret_statistics(
#
# - **Expected Improvement (EI)** ([Močkus, 1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55)) - EI goes beyond PI by not only considering the
# probability of improving on the current best observed point, but also taking into
# account the \textit{magnitude} of improvement. Mathematically, this is defined as
# account the *magnitude* of improvement. Mathematically, this is defined as
# follows:
# $$
# \begin{aligned}
Expand Down
4 changes: 2 additions & 2 deletions docs/examples/decision_making.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
# %% [markdown]

# It is worth noting that `ThompsonSampling` is not the only utility function we could use,
# since our module also provides e.g. `ProbabilityOfImprovement`,
# which was briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).
# since our module also provides e.g. `ProbabilityOfImprovement` and `ExpectedImprovement`,
# which were briefly discussed in [our previous introduction to Bayesian optimisation](https://docs.jaxgaussianprocesses.com/examples/bayesian_optimisation/).


# %% [markdown]
Expand Down
26 changes: 19 additions & 7 deletions gpjax/decision_making/test_functions/continuous_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from abc import (
ABC,
abstractmethod,
)
from abc import abstractmethod
from dataclasses import dataclass

import jax.numpy as jnp
from jaxtyping import (
Array,
Float,
Num,
)
import tensorflow_probability.substrates.jax as tfp

from gpjax.dataset import Dataset
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.gps import AbstractMeanFunction
from gpjax.typing import KeyArray


class AbstractContinuousTestFunction(ABC):
class AbstractContinuousTestFunction(AbstractMeanFunction):
"""
Abstract base class for continuous test functions.
Expand All @@ -43,19 +43,28 @@ class AbstractContinuousTestFunction(ABC):
minimizer: Float[Array, "1 D"]
minimum: Float[Array, "1 1"]

def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
def generate_dataset(
self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
) -> Dataset:
"""
Generate a toy dataset from the test function.
Args:
num_points (int): Number of points to sample.
key (KeyArray): JAX PRNG key.
obs_stddev (float): (Optional) standard deviation of Gaussian distributed
noise added to observations.
Returns:
Dataset: Dataset of points sampled from the test function.
"""
X = self.search_space.sample(num_points=num_points, key=key)
y = self.evaluate(X)
gaussian_noise = tfp.distributions.Normal(
jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
)
y = self.evaluate(X) + jnp.transpose(
gaussian_noise.sample(sample_shape=[1], seed=key)
)
return Dataset(X=X, y=y)

def generate_test_points(
Expand All @@ -73,6 +82,9 @@ def generate_test_points(
"""
return self.search_space.sample(num_points=num_points, key=key)

def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
return self.evaluate(x)

@abstractmethod
def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
"""
Expand Down
16 changes: 8 additions & 8 deletions gpjax/decision_making/test_functions/non_conjugate_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@

import jax.numpy as jnp
import jax.random as jr
from jaxtyping import (
Array,
Float,
Integer,
)

from gpjax.dataset import Dataset
from gpjax.decision_making.search_space import ContinuousSearchSpace
from gpjax.typing import KeyArray
from gpjax.typing import (
Array,
Float,
Int,
KeyArray,
)


@dataclass
Expand Down Expand Up @@ -74,7 +74,7 @@ def generate_test_points(
return self.search_space.sample(num_points=num_points, key=key)

@abstractmethod
def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
"""
Evaluate the test function at a set of points. Function taken from
https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
Expand All @@ -83,7 +83,7 @@ def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
x (Float[Array, 'N 1']): Points to evaluate the test function at.
Returns:
Integer[Array, 'N 1']: Values of the test function at the points.
Int[Array, 'N 1']: Values of the test function at the points.
"""
key = jr.key(42)
f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
Expand Down
4 changes: 4 additions & 0 deletions gpjax/decision_making/utility_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
SinglePointUtilityFunction,
UtilityFunction,
)
from gpjax.decision_making.utility_functions.expected_improvement import (
ExpectedImprovement,
)
from gpjax.decision_making.utility_functions.probability_of_improvement import (
ProbabilityOfImprovement,
)
Expand All @@ -27,6 +30,7 @@
"UtilityFunction",
"AbstractUtilityFunctionBuilder",
"AbstractSinglePointUtilityFunctionBuilder",
"ExpectedImprovement",
"SinglePointUtilityFunction",
"ThompsonSampling",
"ProbabilityOfImprovement",
Expand Down
112 changes: 112 additions & 0 deletions gpjax/decision_making/utility_functions/expected_improvement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2024 The JaxGaussianProcesses Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from functools import partial

from beartype.typing import Mapping
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp

from gpjax.dataset import Dataset
from gpjax.decision_making.utility_functions.base import (
AbstractSinglePointUtilityFunctionBuilder,
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import (
OBJECTIVE,
get_best_latent_observation_val,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import (
Array,
Float,
KeyArray,
)


@dataclass
class ExpectedImprovement(AbstractSinglePointUtilityFunctionBuilder):
    """
    Builder for the Expected Improvement (EI) acquisition function introduced by
    [Močkus, 1974](https://link.springer.com/chapter/10.1007/3-540-07165-2_55).

    The "best" incumbent value is defined as the lowest posterior mean value
    evaluated at the previously observed points. This enables the acquisition
    function to be utilised with noisy observations.
    """

    def build_utility_function(
        self,
        posteriors: Mapping[str, ConjugatePosterior],
        datasets: Mapping[str, Dataset],
        key: KeyArray,
    ) -> SinglePointUtilityFunction:
        r"""
        Build the Expected Improvement acquisition function. This computes the
        expected improvement over the "best" of the previously observed points,
        utilising the posterior distribution of the surrogate model. For posterior
        distribution $`f(\cdot)`$, and best incumbent value $`\eta`$, this is
        defined as:
        ```math
        \alpha_{\text{EI}}(\mathbf{x}) = \mathbb{E}\left[\max(0, \eta - f(\mathbf{x}))\right]
        ```

        Args:
            posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors
                to be used to form the utility function. One of the posteriors must
                correspond to the `OBJECTIVE` key, as we utilise the objective
                posterior to form the utility function.
            datasets (Mapping[str, Dataset]): Dictionary of datasets used to form
                the utility function. Keys in `datasets` should correspond to keys
                in `posteriors`. One of the datasets must correspond to the
                `OBJECTIVE` key.
            key (KeyArray): JAX PRNG key. It is accepted for interface
                compatibility and is not read by this builder.

        Returns:
            SinglePointUtilityFunction: The Expected Improvement acquisition
                function, to be *maximised* in order to decide which point to query
                next.

        Raises:
            ValueError: If the objective posterior is not a `ConjugatePosterior`, or
                if the objective dataset has no inputs, no outputs, or zero points.
        """
        self.check_objective_present(posteriors, datasets)
        posterior = posteriors[OBJECTIVE]
        dataset = datasets[OBJECTIVE]

        if not isinstance(posterior, ConjugatePosterior):
            raise ValueError(
                "Objective posterior must be a ConjugatePosterior to compute the Expected Improvement."
            )

        # `X is None` is tested first so that `n` is only read on a populated dataset.
        dataset_is_empty = dataset.X is None or dataset.n == 0 or dataset.y is None
        if dataset_is_empty:
            raise ValueError("Objective dataset must contain at least one item")

        # The incumbent is computed once here and bound into the returned callable,
        # so it is not recomputed on every evaluation of the utility function.
        incumbent = get_best_latent_observation_val(posterior, dataset)
        return partial(_expected_improvement, posterior, dataset, incumbent)


def _expected_improvement(
    objective_posterior: ConjugatePosterior,
    objective_dataset: Dataset,
    eta: Float[Array, ""],
    x: Float[Array, "N D"],
) -> Float[Array, "N 1"]:
    r"""
    Evaluate the closed-form Expected Improvement at the query points `x`.

    The posterior is called on `x` conditioned on `objective_dataset`, and only its
    marginal mean and variance are used. With $`\mu`$ and $`\sigma^2`$ denoting
    those marginals, the value returned is
    $`(\eta - \mu) F(\eta) + \sigma^2 p(\eta)`$, where $`F`$ and $`p`$ are the CDF
    and PDF of $`\mathcal{N}(\mu, \sigma^2)`$.

    Args:
        objective_posterior (ConjugatePosterior): Posterior over the objective.
        objective_dataset (Dataset): Dataset the posterior is conditioned on.
        eta (Float[Array, ""]): Best incumbent value to improve upon.
        x (Float[Array, "N D"]): Points at which to evaluate the utility.

    Returns:
        Float[Array, "N 1"]: Expected improvement at each query point, with a
            trailing singleton axis appended.
    """
    posterior_at_x = objective_posterior(x, objective_dataset)
    mu = posterior_at_x.mean()
    sigma_sq = posterior_at_x.variance()
    marginal = tfp.distributions.Normal(mu, jnp.sqrt(sigma_sq))

    # `marginal.prob` is the density of N(mu, sigma^2), which already carries a
    # 1/sigma factor; hence it is scaled by the variance, not the standard deviation.
    improvement = (eta - mu) * marginal.cdf(eta) + sigma_sq * marginal.prob(eta)
    return jnp.expand_dims(improvement, -1)
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
AbstractSinglePointUtilityFunctionBuilder,
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.decision_making.utils import (
OBJECTIVE,
get_best_latent_observation_val,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import (
Array,
Expand Down Expand Up @@ -107,14 +110,9 @@ def build_utility_function(
)

def probability_of_improvement(x_test: Num[Array, "N D"]):
# Computing the posterior mean for the training dataset
# for computing the best_y value (as the minimum
# posterior mean of the objective function)
predictive_dist_for_training = objective_posterior.predict(
objective_dataset.X, objective_dataset
best_y = get_best_latent_observation_val(
objective_posterior, objective_dataset
)
best_y = predictive_dist_for_training.mean().min()

predictive_dist = objective_posterior.predict(x_test, objective_dataset)

normal_dist = tfp.distributions.Normal(
Expand Down
11 changes: 4 additions & 7 deletions gpjax/decision_making/utility_functions/thompson_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@
SinglePointUtilityFunction,
)
from gpjax.decision_making.utils import OBJECTIVE
from gpjax.gps import (
ConjugatePosterior,
NonConjugatePosterior,
)
from gpjax.gps import ConjugatePosterior
from gpjax.typing import KeyArray


Expand Down Expand Up @@ -59,7 +56,7 @@ def __post_init__(self):

def build_utility_function(
self,
posteriors: Mapping[str, ConjugatePosterior | NonConjugatePosterior],
posteriors: Mapping[str, ConjugatePosterior],
datasets: Mapping[str, Dataset],
key: KeyArray,
) -> SinglePointUtilityFunction:
Expand All @@ -69,8 +66,8 @@ def build_utility_function(
are *maximised*.
Args:
posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
used to form the utility function. One of the posteriors must correspond
posteriors (Mapping[str, ConjugatePosterior]): Dictionary of posteriors to
be used to form the utility function. One of the posteriors must correspond
to the `OBJECTIVE` key, as we sample from the objective posterior to form
the utility function.
datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
Expand Down
14 changes: 14 additions & 0 deletions gpjax/decision_making/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
Dict,
Final,
)
import jax.numpy as jnp

from gpjax.dataset import Dataset
from gpjax.gps import AbstractPosterior
from gpjax.typing import (
Array,
Float,
Expand Down Expand Up @@ -48,3 +50,15 @@ def build_function_evaluator(
dictionary of datasets storing the evaluated points.
"""
return lambda x: {tag: Dataset(x, f(x)) for tag, f in functions.items()}


def get_best_latent_observation_val(
    posterior: AbstractPosterior, dataset: Dataset
) -> Float[Array, ""]:
    """
    Return the best (latent) function value at the locations in `dataset`.

    The posterior is evaluated at the dataset's own inputs `dataset.X`, conditioned
    on `dataset`, and the minimum of the resulting posterior mean is returned. In
    the noiseless case, this corresponds to the minimum value in the dataset.

    Args:
        posterior (AbstractPosterior): Posterior to evaluate.
        dataset (Dataset): Dataset providing both the conditioning data and the
            evaluation locations.

    Returns:
        Float[Array, ""]: Minimum posterior mean over the dataset's input locations.
    """
    latent_at_observed = posterior(dataset.X, dataset)
    return jnp.min(latent_at_observed.mean())
Loading

0 comments on commit b69be96

Please sign in to comment.