Skip to content

Commit

Permalink
Updating passing of data loader configs (facebook#3465)
Browse files · Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3465

This commit changes the way data loader configs are passed throughout the code, deprecating the old way of passing individual flags (e.g. `fit_out_of_design`) in favor of passing a single `DataLoaderConfig` object.

Reviewed By: sdaulton, saitcakmak

Differential Revision: D70646419
  • Loading branch information
SebastianAment authored and facebook-github-bot committed Mar 6, 2025
1 parent 36d55fc commit cf44fa6
Show file tree
Hide file tree
Showing 14 changed files with 60 additions and 33 deletions.
6 changes: 5 additions & 1 deletion ax/early_stopping/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ax.core.objective import MultiObjective
from ax.core.trial_status import TrialStatus
from ax.early_stopping.utils import estimate_early_stopping_savings
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.map_torch import MapTorchAdapter
from ax.modelbridge.modelbridge_utils import (
_unpack_observations,
Expand Down Expand Up @@ -535,5 +536,8 @@ def get_transform_helper_model(
data=data,
model=TorchGenerator(),
transforms=transforms,
fit_out_of_design=True,
data_loader_config=DataLoaderConfig(
fit_out_of_design=True,
latest_rows_per_group=None,
),
)
9 changes: 7 additions & 2 deletions ax/generation_strategy/dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GenerationStep,
GenerationStrategy,
)
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.registry import (
Generators,
MODEL_KEY_TO_MODEL_SETUP,
Expand Down Expand Up @@ -108,7 +109,9 @@ def _make_botorch_step(
model_kwargs["transform_configs"]["Derelativize"] = (
derelativization_transform_config
)
model_kwargs["fit_out_of_design"] = fit_out_of_design
model_kwargs["data_loader_config"] = DataLoaderConfig(
fit_out_of_design=fit_out_of_design
)

if not no_winsorization:
_, default_bridge_kwargs = model.view_defaults()
Expand Down Expand Up @@ -525,7 +528,9 @@ def choose_generation_strategy(

model_kwargs: dict[str, Any] = {
"torch_device": torch_device,
"fit_out_of_design": fit_out_of_design,
"data_loader_config": DataLoaderConfig(
fit_out_of_design=fit_out_of_design,
),
}

# Create `generation_strategy`, adding first Sobol step
Expand Down
11 changes: 6 additions & 5 deletions ax/generation_strategy/tests/test_dispatch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
choose_generation_strategy,
DEFAULT_BAYESIAN_PARALLELISM,
)
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans
from ax.modelbridge.transforms.log_y import LogY
from ax.modelbridge.transforms.winsorize import Winsorize
Expand Down Expand Up @@ -60,7 +61,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device": None,
"transforms": expected_transforms,
"transform_configs": expected_transform_configs,
"fit_out_of_design": False,
"data_loader_config": DataLoaderConfig(fit_out_of_design=False),
}
self.assertEqual(sobol_gpei._steps[1].model_kwargs, expected_model_kwargs)
device = torch.device("cpu")
Expand Down Expand Up @@ -126,7 +127,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device",
"transforms",
"transform_configs",
"fit_out_of_design",
"data_loader_config",
},
)
self.assertGreater(len(model_kwargs["transforms"]), 0)
Expand Down Expand Up @@ -204,7 +205,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device": None,
"transforms": [Winsorize] + Mixed_transforms + Y_trans,
"transform_configs": expected_transform_configs,
"fit_out_of_design": False,
"data_loader_config": DataLoaderConfig(fit_out_of_design=False),
}
self.assertEqual(bo_mixed._steps[1].model_kwargs, expected_model_kwargs)
with self.subTest("BO_MIXED (mixed search space)"):
Expand All @@ -219,7 +220,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device": None,
"transforms": [Winsorize] + Mixed_transforms + Y_trans,
"transform_configs": expected_transform_configs,
"fit_out_of_design": False,
"data_loader_config": DataLoaderConfig(fit_out_of_design=False),
}
self.assertEqual(bo_mixed._steps[1].model_kwargs, expected_model_kwargs)
with self.subTest("BO_MIXED (mixed multi-objective optimization)"):
Expand All @@ -241,7 +242,7 @@ def test_choose_generation_strategy(self) -> None:
"torch_device",
"transforms",
"transform_configs",
"fit_out_of_design",
"data_loader_config",
},
)
self.assertGreater(len(model_kwargs["transforms"]), 0)
Expand Down
10 changes: 6 additions & 4 deletions ax/generation_strategy/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,9 @@ def test_sobol_MBM_strategy(self) -> None:
"status_quo_features": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
"fit_abandoned": False,
"fit_out_of_design": None, # False by DataLoaderConfig default
"fit_abandoned": None, # False by DataLoaderConfig default
"data_loader_config": None,
"fit_tracking_metrics": True,
"fit_on_init": True,
},
Expand Down Expand Up @@ -1560,8 +1561,9 @@ def test_gs_with_generation_nodes(self) -> None:
"status_quo_features": None,
"transform_configs": None,
"transforms": Cont_X_trans,
"fit_out_of_design": False,
"fit_abandoned": False,
"fit_out_of_design": None, # False by DataLoaderConfig default
"fit_abandoned": None, # False by DataLoaderConfig default
"data_loader_config": None,
"fit_tracking_metrics": True,
"fit_on_init": True,
},
Expand Down
10 changes: 6 additions & 4 deletions ax/modelbridge/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ax.core.search_space import SearchSpace
from ax.core.types import TParamValueList
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import Adapter, GenResults
from ax.modelbridge.base import Adapter, DataLoaderConfig, GenResults
from ax.modelbridge.modelbridge_utils import (
array_to_observation_data,
get_fixed_features,
Expand Down Expand Up @@ -58,11 +58,12 @@ def __init__(
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
expand_model_space: bool = True,
fit_out_of_design: bool = False,
fit_abandoned: bool = False,
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
fit_only_completed_map_metrics: bool = True,
data_loader_config: DataLoaderConfig | None = None,
fit_out_of_design: bool | None = None,
fit_abandoned: bool | None = None,
fit_only_completed_map_metrics: bool | None = None,
) -> None:
# These are set in _fit.
self.parameters: list[str] = []
Expand All @@ -77,6 +78,7 @@ def __init__(
status_quo_features=status_quo_features,
optimization_config=optimization_config,
expand_model_space=expand_model_space,
data_loader_config=data_loader_config,
fit_out_of_design=fit_out_of_design,
fit_abandoned=fit_abandoned,
fit_tracking_metrics=fit_tracking_metrics,
Expand Down
10 changes: 7 additions & 3 deletions ax/modelbridge/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ax.core.experiment import Experiment
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.discrete import DiscreteAdapter
from ax.modelbridge.random import RandomAdapter
from ax.modelbridge.registry import Cont_X_trans, Generators, Y_trans
Expand Down Expand Up @@ -146,7 +147,10 @@ def get_botorch(
def get_factorial(search_space: SearchSpace) -> DiscreteAdapter:
"""Instantiates a factorial generator."""
return assert_is_instance(
Generators.FACTORIAL(search_space=search_space, fit_out_of_design=True),
Generators.FACTORIAL(
search_space=search_space,
data_loader_config=DataLoaderConfig(fit_out_of_design=True),
),
DiscreteAdapter,
)

Expand All @@ -170,7 +174,7 @@ def get_empirical_bayes_thompson(
num_samples=num_samples,
min_weight=min_weight,
uniform_weights=uniform_weights,
fit_out_of_design=True,
data_loader_config=DataLoaderConfig(fit_out_of_design=True),
),
DiscreteAdapter,
)
Expand All @@ -195,7 +199,7 @@ def get_thompson(
num_samples=num_samples,
min_weight=min_weight,
uniform_weights=uniform_weights,
fit_out_of_design=True,
data_loader_config=DataLoaderConfig(fit_out_of_design=True),
),
DiscreteAdapter,
)
4 changes: 2 additions & 2 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def __init__(
map_data_limit_rows_per_metric: int | None = None,
map_data_limit_rows_per_group: int | None = None,
data_loader_config: DataLoaderConfig | None = None,
fit_out_of_design: bool = False,
fit_abandoned: bool = False,
fit_out_of_design: bool | None = None,
fit_abandoned: bool | None = None,
) -> None:
"""In addition to common arguments documented in the ``Adapter`` and
``TorchAdapter`` classes, ``MapTorchAdapter`` accepts the following arguments.
Expand Down
8 changes: 5 additions & 3 deletions ax/modelbridge/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.modelbridge.base import Adapter, GenResults
from ax.modelbridge.base import Adapter, DataLoaderConfig, GenResults
from ax.modelbridge.modelbridge_utils import (
extract_parameter_constraints,
extract_search_space_digest,
Expand Down Expand Up @@ -45,10 +45,11 @@ def __init__(
transform_configs: Mapping[str, TConfig] | None = None,
status_quo_features: ObservationFeatures | None = None,
optimization_config: OptimizationConfig | None = None,
fit_out_of_design: bool = False,
fit_abandoned: bool = False,
fit_tracking_metrics: bool = True,
fit_on_init: bool = True,
data_loader_config: DataLoaderConfig | None = None,
fit_out_of_design: bool | None = None,
fit_abandoned: bool | None = None,
) -> None:
self.parameters: list[str] = []
super().__init__(
Expand All @@ -61,6 +62,7 @@ def __init__(
status_quo_features=status_quo_features,
optimization_config=optimization_config,
expand_model_space=False,
data_loader_config=data_loader_config,
fit_out_of_design=fit_out_of_design,
fit_abandoned=fit_abandoned,
fit_tracking_metrics=fit_tracking_metrics,
Expand Down
5 changes: 3 additions & 2 deletions ax/modelbridge/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,9 @@ def test_view_defaults(self) -> None:
"transforms": Cont_X_trans,
"transform_configs": None,
"status_quo_features": None,
"fit_out_of_design": False,
"fit_abandoned": False,
"fit_out_of_design": None, # False by DataLoaderConfig default
"fit_abandoned": None, # False by DataLoaderConfig default
"data_loader_config": None,
"fit_tracking_metrics": True,
"fit_on_init": True,
},
Expand Down
5 changes: 4 additions & 1 deletion ax/plot/pareto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ax.core.search_space import RobustSearchSpace, SearchSpace
from ax.core.types import TParameterization
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.modelbridge_utils import (
_get_modelbridge_training_data,
get_pareto_frontier_and_configs,
Expand Down Expand Up @@ -337,7 +338,9 @@ def get_tensor_converter_model(experiment: Experiment, data: Data) -> TorchAdapt
data=data,
model=TorchGenerator(),
transforms=[SearchSpaceToFloat],
fit_out_of_design=True,
data_loader_config=DataLoaderConfig(
fit_out_of_design=True,
),
)


Expand Down
4 changes: 0 additions & 4 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,8 +1091,6 @@ def test_retries(self) -> None:
# Should raise after 3 retries.
with self.assertRaisesRegex(RuntimeError, ".* testing .*"):
scheduler.run_all_trials()
# pyre-fixme[16]: `Scheduler` has no attribute `run_trial_call_count`.
self.assertEqual(scheduler.run_trial_call_count, 3)

def test_retries_nonretriable_error(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
Expand All @@ -1114,8 +1112,6 @@ def test_retries_nonretriable_error(self) -> None:
# Should raise right away since ValueError is non-retriable.
with self.assertRaisesRegex(ValueError, ".* testing .*"):
scheduler.run_all_trials()
# pyre-fixme[16]: `Scheduler` has no attribute `run_trial_call_count`.
self.assertEqual(scheduler.run_trial_call_count, 1)

def test_set_ttl(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
Expand Down
2 changes: 2 additions & 0 deletions ax/storage/json_store/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
from ax.metrics.l2norm import L2NormMetric
from ax.metrics.noisy_function import NoisyFunctionMetric
from ax.metrics.sklearn import SklearnDataset, SklearnMetric, SklearnModelType
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.factory import Generators
from ax.modelbridge.registry import ModelRegistryBase
from ax.modelbridge.transforms.base import Transform
Expand Down Expand Up @@ -323,6 +324,7 @@
"ChoiceParameter": ChoiceParameter,
"ComparisonOp": ComparisonOp,
"Data": Data,
"DataLoaderConfig": DataLoaderConfig,
"DataType": DataType,
"DomainType": DomainType,
"Experiment": Experiment,
Expand Down
3 changes: 3 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ax.exceptions.storage import JSONDecodeError, JSONEncodeError
from ax.generation_strategy.generation_node import GenerationStep
from ax.generation_strategy.generation_strategy import GenerationStrategy
from ax.modelbridge.base import DataLoaderConfig
from ax.modelbridge.registry import Generators
from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel
from ax.models.torch.botorch_modular.surrogate import SurrogateSpec
Expand Down Expand Up @@ -170,6 +171,8 @@
("BraninMetric", get_branin_metric),
("ChainedInputTransform", get_chained_input_transform),
("ChoiceParameter", get_choice_parameter),
# testing with non-default argument
("DataLoaderConfig", partial(DataLoaderConfig, fit_out_of_design=True)),
("Experiment", get_experiment_with_batch_and_single_trial),
("Experiment", get_experiment_with_trial_with_ttl),
("Experiment", get_experiment_with_data),
Expand Down
6 changes: 4 additions & 2 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
self.assertEqual(experiment_w_aux_exp, loaded_experiment)
self.assertEqual(len(loaded_experiment.auxiliary_experiments_by_purpose), 1)

def test_saving_and_loading_experiment_with_cross_referencing_aux_exp(self) -> None:
def test_saving_and_loading_experiment_with_cross_referencing_aux_exp(
self,
) -> None:
exp1_name = "test_aux_exp_in_SQAStoreTest1"
exp2_name = "test_aux_exp_in_SQAStoreTest2"
# pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute
Expand Down Expand Up @@ -509,7 +511,7 @@ def test_ExperimentSaveAndLoadReducedState(
self.assertEqual(len(mkw), 6)
bkw = gr._bridge_kwargs
self.assertIsNotNone(bkw)
self.assertEqual(len(bkw), 8)
self.assertEqual(len(bkw), 9)
# This has seed, generated points and init position.
ms = gr._model_state_after_gen
self.assertIsNotNone(ms)
Expand Down

0 comments on commit cf44fa6

Please sign in to comment.