Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move new dispatch utils out of preview #3468

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ax.core.observation import ObservationFeatures
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.generation_strategy.dispatch_utils import choose_generation_strategy
from ax.generation_strategy.dispatch_utils import choose_generation_strategy_legacy
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -50,7 +50,7 @@ def test_compute_for_requires_a_gs(self) -> None:
def test_compute_for_requires_trials(self) -> None:
analysis = PredictedEffectsPlot(metric_name="branin")
experiment = get_branin_experiment()
generation_strategy = choose_generation_strategy(
generation_strategy = choose_generation_strategy_legacy(
search_space=experiment.search_space,
experiment=experiment,
)
Expand All @@ -62,7 +62,7 @@ def test_compute_for_requires_trials(self) -> None:
def test_compute_for_requires_a_model_that_predicts(self) -> None:
analysis = PredictedEffectsPlot(metric_name="branin")
experiment = get_branin_experiment(with_batch=True, with_completed_batch=True)
generation_strategy = choose_generation_strategy(
generation_strategy = choose_generation_strategy_legacy(
search_space=experiment.search_space,
experiment=experiment,
)
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_it_does_not_plot_abandoned_trials(self) -> None:
def test_it_works_for_non_batch_experiments(self) -> None:
# GIVEN an experiment with the default generation strategy
experiment = get_branin_experiment(with_batch=False)
generation_strategy = choose_generation_strategy(
generation_strategy = choose_generation_strategy_legacy(
search_space=experiment.search_space,
experiment=experiment,
)
Expand Down
205 changes: 204 additions & 1 deletion ax/generation_strategy/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import logging
import warnings
from enum import Enum
from math import ceil
from typing import Any, cast

Expand All @@ -16,10 +17,15 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.generation_strategy.generation_strategy import (
GenerationNode,
GenerationStep,
GenerationStrategy,
)
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import MinTrials
from ax.modelbridge.registry import (
Generators,
MODEL_KEY_TO_MODEL_SETUP,
Expand All @@ -30,10 +36,13 @@
from ax.models.torch.botorch_modular.model import (
BoTorchGenerator as ModularBoTorchGenerator,
)
from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from ax.models.types import TConfig
from ax.models.winsorization_config import WinsorizationConfig
from ax.utils.common.deprecation import _validate_force_random_search
from ax.utils.common.logger import get_logger
from botorch.models.transforms.input import Normalize, Warp
from gpytorch.kernels.linear_kernel import LinearKernel
from pyre_extensions import none_throws


Expand All @@ -54,6 +63,200 @@
)


class GenerationMethod(Enum):
"""An enum to specify the desired candidate generation method for the experiment.
This is used in ``GenerationStrategyConfig``, along with the properties of the
experiment, to determine the generation strategy to use for candidate generation.

NOTE: New options should be rarely added to this enum. This is not intended to be
a list of generation strategies for the user to choose from. Instead, this enum
should only provide high level guidance to the underlying generation strategy
dispatch logic, which is responsible for determinining the exact details.

Available options are:
BALANCED: A balanced generation method that may utilize (per-metric) model
selection to achieve a good model accuracy. This method excludes expensive
methods, such as the fully Bayesian SAASBO model. Used by default.
FAST: A faster generation method that uses the built-in defaults from the
Modular BoTorch Model without any model selection.
RANDOM_SEARCH: Primarily intended for pure exploration experiments, this
method utilizes quasi-random Sobol sequences for candidate generation.
"""

BALANCED = "balanced"
FAST = "fast"
RANDOM_SEARCH = "random_search"


def _get_sobol_node(
initialization_budget: int | None = None,
initialization_random_seed: int | None = None,
use_existing_trials_for_initialization: bool = True,
min_observed_initialization_trials: int | None = None,
allow_exceeding_initialization_budget: bool = False,
) -> GenerationNode:
"""Constructs a Sobol node based on inputs from ``gs_config``.
The Sobol generator utilizes `initialization_random_seed` if specified.

This node always transitions to "MBM", using the following transition criteria:
- MinTrials enforcing the initialization budget.
- If the initialization budget is not specified, it defaults to 5.
- The TC will not block generation if `allow_exceeding_initialization_budget`
is set to True.
- The TC is currently not restricted to any trial statuses and will
count all trials.
- `use_existing_trials_for_initialization` controls whether trials previously
attached to the experiment are counted as part of the initialization budget.
- MinTrials enforcing the minimum number of observed initialization trials.
- If `min_observed_initialization_trials` is not specified, it defaults
to `max(1, initialization_budget // 2)`.
- The TC currently only counts trials in status COMPLETED (with data attached)
as observed trials.
- `use_existing_trials_for_initialization` controls whether trials previously
attached to the experiment are counted as part of the required number of
observed initialization trials.
"""
# Set the default options.
if initialization_budget is None:
initialization_budget = 5
if min_observed_initialization_trials is None:
min_observed_initialization_trials = max(1, initialization_budget // 2)
# Construct the transition criteria.
transition_criteria = [
MinTrials( # This represents the initialization budget.
threshold=initialization_budget,
transition_to="MBM",
block_gen_if_met=(not allow_exceeding_initialization_budget),
block_transition_if_unmet=True,
use_all_trials_in_exp=use_existing_trials_for_initialization,
),
MinTrials( # This represents minimum observed trials requirement.
threshold=min_observed_initialization_trials,
transition_to="MBM",
block_gen_if_met=False,
block_transition_if_unmet=True,
use_all_trials_in_exp=use_existing_trials_for_initialization,
only_in_statuses=[TrialStatus.COMPLETED],
count_only_trials_with_data=True,
),
]
return GenerationNode(
node_name="Sobol",
model_specs=[
GeneratorSpec(
model_enum=Generators.SOBOL,
model_kwargs={"seed": initialization_random_seed},
)
],
transition_criteria=transition_criteria,
should_deduplicate=True,
)


def _get_mbm_node(
method: GenerationMethod = GenerationMethod.FAST,
torch_device: str | None = None,
) -> GenerationNode:
"""Constructs an MBM node based on the method specified in ``gs_config``.

The ``SurrogateSpec`` takes the following form for the given method:
- BALANCED: Two model configs: one with MBM defaults, the other with
linear kernel with input warping.
- FAST: An empty model config that utilizes MBM defaults.
"""
# Construct the surrogate spec.
if method == GenerationMethod.FAST:
model_configs = [ModelConfig(name="MBM defaults")]
elif method == GenerationMethod.BALANCED:
model_configs = [
ModelConfig(name="MBM defaults"),
ModelConfig(
covar_module_class=LinearKernel,
input_transform_classes=[Warp, Normalize],
input_transform_options={"Normalize": {"center": 0.0}},
name="LinearKernel with Warp",
),
]
else:
raise UnsupportedError(f"Unsupported generation method: {method}.")

return GenerationNode(
node_name="MBM",
model_specs=[
GeneratorSpec(
model_enum=Generators.BOTORCH_MODULAR,
model_kwargs={
"surrogate_spec": SurrogateSpec(model_configs=model_configs),
"torch_device": None
if torch_device is None
else torch.device(torch_device),
},
)
],
should_deduplicate=True,
)


def choose_generation_strategy(
method: GenerationMethod = GenerationMethod.FAST,
# Initialization options
initialization_budget: int | None = None,
initialization_random_seed: int | None = None,
use_existing_trials_for_initialization: bool = True,
min_observed_initialization_trials: int | None = None,
allow_exceeding_initialization_budget: bool = False,
# Misc options
torch_device: str | None = None,
) -> GenerationStrategy:
"""Choose a generation strategy based on the properties of the experiment
and the inputs provided in ``gs_config``.

NOTE: The behavior of this function is subject to change. It will be updated to
produce best general purpose generation strategies based on benchmarking results.

Args:
gs_config: A ``GenerationStrategyConfig`` object that informs
the choice of generation strategy.

Returns:
A generation strategy.
"""
# Handle the random search case.
if method == GenerationMethod.RANDOM_SEARCH:
return GenerationStrategy(
name="QuasiRandomSearch",
nodes=[
GenerationNode(
node_name="Sobol",
model_specs=[
GeneratorSpec(
model_enum=Generators.SOBOL,
model_kwargs={"seed": initialization_random_seed},
)
],
)
],
)
# Construct the nodes.
sobol_node = _get_sobol_node(
initialization_budget=initialization_budget,
initialization_random_seed=initialization_random_seed,
use_existing_trials_for_initialization=use_existing_trials_for_initialization,
min_observed_initialization_trials=min_observed_initialization_trials,
allow_exceeding_initialization_budget=allow_exceeding_initialization_budget,
)
# Construct the MBM node.
mbm_node = _get_mbm_node(
method=method,
torch_device=torch_device,
)

return GenerationStrategy(
name=f"Sobol+MBM:{method.value}",
nodes=[sobol_node, mbm_node],
)


def _make_sobol_step(
num_trials: int = -1,
min_trials_observed: int | None = None,
Expand Down Expand Up @@ -294,7 +497,7 @@ def calculate_num_initialization_trials(
return max(ret, 5)


def choose_generation_strategy(
def choose_generation_strategy_legacy(
search_space: SearchSpace,
*,
use_batch_trials: bool = False,
Expand Down
Loading