Move new dispatch utils out of preview (#3468)
Summary:

As titled. Also refactored slightly so that we won't be importing from ax.api anywhere in the codebase. To keep our module structure easy to reason about, it is important that the ax.api module stays at the root of our dependency tree.
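
For example (illustrative sketch based on the definitions added in this diff), internal code can now import the dispatch utilities directly from ax.generation_strategy.dispatch_utils rather than via ax.api:

    from ax.generation_strategy.dispatch_utils import (
        GenerationMethod,
        choose_generation_strategy,
    )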

Differential Revision: D70647193
mpolson64 authored and facebook-github-bot committed Mar 5, 2025
1 parent 0b209d6 commit a170edf
Showing 7 changed files with 388 additions and 392 deletions.
203 changes: 203 additions & 0 deletions ax/generation_strategy/dispatch_utils.py
@@ -8,6 +8,7 @@

import logging
import warnings
from enum import Enum
from math import ceil
from typing import Any, cast

@@ -16,10 +17,15 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.core.trial_status import TrialStatus
from ax.exceptions.core import UnsupportedError
from ax.generation_strategy.generation_strategy import (
    GenerationNode,
    GenerationStep,
    GenerationStrategy,
)
from ax.generation_strategy.model_spec import GeneratorSpec
from ax.generation_strategy.transition_criterion import MinTrials
from ax.modelbridge.registry import (
    Generators,
    MODEL_KEY_TO_MODEL_SETUP,
@@ -30,10 +36,13 @@
from ax.models.torch.botorch_modular.model import (
    BoTorchGenerator as ModularBoTorchGenerator,
)
from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec
from ax.models.types import TConfig
from ax.models.winsorization_config import WinsorizationConfig
from ax.utils.common.deprecation import _validate_force_random_search
from ax.utils.common.logger import get_logger
from botorch.models.transforms.input import Normalize, Warp
from gpytorch.kernels.linear_kernel import LinearKernel
from pyre_extensions import none_throws


@@ -54,6 +63,200 @@
)


class GenerationMethod(Enum):
"""An enum to specify the desired candidate generation method for the experiment.
This is used in ``GenerationStrategyConfig``, along with the properties of the
experiment, to determine the generation strategy to use for candidate generation.
NOTE: New options should be rarely added to this enum. This is not intended to be
a list of generation strategies for the user to choose from. Instead, this enum
should only provide high level guidance to the underlying generation strategy
dispatch logic, which is responsible for determinining the exact details.
Available options are:
BALANCED: A balanced generation method that may utilize (per-metric) model
selection to achieve a good model accuracy. This method excludes expensive
methods, such as the fully Bayesian SAASBO model. Used by default.
FAST: A faster generation method that uses the built-in defaults from the
Modular BoTorch Model without any model selection.
RANDOM_SEARCH: Primarily intended for pure exploration experiments, this
method utilizes quasi-random Sobol sequences for candidate generation.
"""

BALANCED = "balanced"
FAST = "fast"
RANDOM_SEARCH = "random_search"
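
# Usage sketch (illustrative): the enum values double as stable string
# identifiers, so a method can be specified by member or looked up by value.
assert GenerationMethod("balanced") is GenerationMethod.BALANCED
assert GenerationMethod.FAST.value == "fast"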


def _get_sobol_node(
    initialization_budget: int | None = None,
    initialization_random_seed: int | None = None,
    use_existing_trials_for_initialization: bool = True,
    min_observed_initialization_trials: int | None = None,
    allow_exceeding_initialization_budget: bool = False,
) -> GenerationNode:
"""Constructs a Sobol node based on inputs from ``gs_config``.
The Sobol generator utilizes `initialization_random_seed` if specified.
This node always transitions to "MBM", using the following transition criteria:
- MinTrials enforcing the initialization budget.
- If the initialization budget is not specified, it defaults to 5.
- The TC will not block generation if `allow_exceeding_initialization_budget`
is set to True.
- The TC is currently not restricted to any trial statuses and will
count all trials.
- `use_existing_trials_for_initialization` controls whether trials previously
attached to the experiment are counted as part of the initialization budget.
- MinTrials enforcing the minimum number of observed initialization trials.
- If `min_observed_initialization_trials` is not specified, it defaults
to `max(1, initialization_budget // 2)`.
- The TC currently only counts trials in status COMPLETED (with data attached)
as observed trials.
- `use_existing_trials_for_initialization` controls whether trials previously
attached to the experiment are counted as part of the required number of
observed initialization trials.
"""
    # Set the default options.
    if initialization_budget is None:
        initialization_budget = 5
    if min_observed_initialization_trials is None:
        min_observed_initialization_trials = max(1, initialization_budget // 2)
    # Construct the transition criteria.
    transition_criteria = [
        MinTrials(  # This represents the initialization budget.
            threshold=initialization_budget,
            transition_to="MBM",
            block_gen_if_met=(not allow_exceeding_initialization_budget),
            block_transition_if_unmet=True,
            use_all_trials_in_exp=use_existing_trials_for_initialization,
        ),
        MinTrials(  # This represents the minimum observed trials requirement.
            threshold=min_observed_initialization_trials,
            transition_to="MBM",
            block_gen_if_met=False,
            block_transition_if_unmet=True,
            use_all_trials_in_exp=use_existing_trials_for_initialization,
            only_in_statuses=[TrialStatus.COMPLETED],
            count_only_trials_with_data=True,
        ),
    ]
    return GenerationNode(
        node_name="Sobol",
        model_specs=[
            GeneratorSpec(
                model_enum=Generators.SOBOL,
                model_kwargs={"seed": initialization_random_seed},
            )
        ],
        transition_criteria=transition_criteria,
        should_deduplicate=True,
    )
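
# Usage sketch (illustrative, hypothetical variable name): with no budget
# specified, the node above uses an initialization budget of 5 and requires
# max(1, 5 // 2) == 2 COMPLETED trials with data before allowing the
# transition to "MBM".
example_sobol_node = _get_sobol_node(
    initialization_budget=None,  # resolves to 5
    initialization_random_seed=0,
)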


def _get_mbm_node(
    method: GenerationMethod = GenerationMethod.FAST,
    torch_device: str | None = None,
) -> GenerationNode:
"""Constructs an MBM node based on the method specified in ``gs_config``.
The ``SurrogateSpec`` takes the following form for the given method:
- BALANCED: Two model configs: one with MBM defaults, the other with
linear kernel with input warping.
- FAST: An empty model config that utilizes MBM defaults.
"""
    # Construct the surrogate spec.
    if method == GenerationMethod.FAST:
        model_configs = [ModelConfig(name="MBM defaults")]
    elif method == GenerationMethod.BALANCED:
        model_configs = [
            ModelConfig(name="MBM defaults"),
            ModelConfig(
                covar_module_class=LinearKernel,
                input_transform_classes=[Warp, Normalize],
                input_transform_options={"Normalize": {"center": 0.0}},
                name="LinearKernel with Warp",
            ),
        ]
    else:
        raise UnsupportedError(f"Unsupported generation method: {method}.")

    return GenerationNode(
        node_name="MBM",
        model_specs=[
            GeneratorSpec(
                model_enum=Generators.BOTORCH_MODULAR,
                model_kwargs={
                    "surrogate_spec": SurrogateSpec(model_configs=model_configs),
                    "torch_device": None
                    if torch_device is None
                    else torch.device(torch_device),
                },
            )
        ],
        should_deduplicate=True,
    )
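
# Usage sketch (illustrative, hypothetical variable names): BALANCED builds a
# surrogate spec with two model configs (MBM defaults plus a linear kernel with
# input warping), whereas FAST uses a single default config.
example_balanced_node = _get_mbm_node(method=GenerationMethod.BALANCED)
example_fast_node = _get_mbm_node(method=GenerationMethod.FAST, torch_device=None)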


def choose_generation_strategy(
    method: GenerationMethod = GenerationMethod.FAST,
    # Initialization options
    initialization_budget: int | None = None,
    initialization_random_seed: int | None = None,
    use_existing_trials_for_initialization: bool = True,
    min_observed_initialization_trials: int | None = None,
    allow_exceeding_initialization_budget: bool = False,
    # Misc options
    torch_device: str | None = None,
) -> GenerationStrategy:
"""Choose a generation strategy based on the properties of the experiment
and the inputs provided in ``gs_config``.
NOTE: The behavior of this function is subject to change. It will be updated to
produce best general purpose generation strategies based on benchmarking results.
Args:
gs_config: A ``GenerationStrategyConfig`` object that informs
the choice of generation strategy.
Returns:
A generation strategy.
"""
    # Handle the random search case.
    if method == GenerationMethod.RANDOM_SEARCH:
        return GenerationStrategy(
            name="QuasiRandomSearch",
            nodes=[
                GenerationNode(
                    node_name="Sobol",
                    model_specs=[
                        GeneratorSpec(
                            model_enum=Generators.SOBOL,
                            model_kwargs={"seed": initialization_random_seed},
                        )
                    ],
                )
            ],
        )
    # Construct the Sobol node.
    sobol_node = _get_sobol_node(
        initialization_budget=initialization_budget,
        initialization_random_seed=initialization_random_seed,
        use_existing_trials_for_initialization=use_existing_trials_for_initialization,
        min_observed_initialization_trials=min_observed_initialization_trials,
        allow_exceeding_initialization_budget=allow_exceeding_initialization_budget,
    )
    # Construct the MBM node.
    mbm_node = _get_mbm_node(
        method=method,
        torch_device=torch_device,
    )

    return GenerationStrategy(
        name=f"Sobol+MBM:{method.value}",
        nodes=[sobol_node, mbm_node],
    )
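
# Usage sketch (illustrative, hypothetical variable names): a FAST strategy is a
# two-node "Sobol+MBM:fast" strategy, while RANDOM_SEARCH yields a single
# quasi-random Sobol node.
example_fast_gs = choose_generation_strategy(
    method=GenerationMethod.FAST,
    initialization_budget=8,
    initialization_random_seed=0,
)
example_random_gs = choose_generation_strategy(method=GenerationMethod.RANDOM_SEARCH)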


def _make_sobol_step(
    num_trials: int = -1,
    min_trials_observed: int | None = None,
