From cf44fa6abc8f535907af3cff91644a850c3b3b89 Mon Sep 17 00:00:00 2001 From: Sebastian Ament Date: Wed, 5 Mar 2025 16:15:53 -0800 Subject: [PATCH] Updating passing of data loader configs (#3465) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3465 This commit changes the way data loader configs are passed throughout the code, deprecating the old way of passing e.g. `fit_out_of_design`, in favor of passing a `DataLoaderConfig` object. Reviewed By: sdaulton, saitcakmak Differential Revision: D70646419 --- ax/early_stopping/strategies/base.py | 6 +++++- ax/generation_strategy/dispatch_utils.py | 9 +++++++-- ax/generation_strategy/tests/test_dispatch_utils.py | 11 ++++++----- .../tests/test_generation_strategy.py | 10 ++++++---- ax/modelbridge/discrete.py | 10 ++++++---- ax/modelbridge/factory.py | 10 +++++++--- ax/modelbridge/map_torch.py | 4 ++-- ax/modelbridge/random.py | 8 +++++--- ax/modelbridge/tests/test_registry.py | 5 +++-- ax/plot/pareto_utils.py | 5 ++++- ax/service/tests/scheduler_test_utils.py | 4 ---- ax/storage/json_store/registry.py | 2 ++ ax/storage/json_store/tests/test_json_store.py | 3 +++ ax/storage/sqa_store/tests/test_sqa_store.py | 6 ++++-- 14 files changed, 60 insertions(+), 33 deletions(-) diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index 0f1f0e3e53a..00b524ac5a0 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -21,6 +21,7 @@ from ax.core.objective import MultiObjective from ax.core.trial_status import TrialStatus from ax.early_stopping.utils import estimate_early_stopping_savings +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.map_torch import MapTorchAdapter from ax.modelbridge.modelbridge_utils import ( _unpack_observations, @@ -535,5 +536,8 @@ def get_transform_helper_model( data=data, model=TorchGenerator(), transforms=transforms, - fit_out_of_design=True, + data_loader_config=DataLoaderConfig( + fit_out_of_design=True, + latest_rows_per_group=None, + ), ) diff --git a/ax/generation_strategy/dispatch_utils.py b/ax/generation_strategy/dispatch_utils.py index 9ce102785de..f4ac216540e 100644 --- a/ax/generation_strategy/dispatch_utils.py +++ b/ax/generation_strategy/dispatch_utils.py @@ -20,6 +20,7 @@ GenerationStep, GenerationStrategy, ) +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.registry import ( Generators, MODEL_KEY_TO_MODEL_SETUP, @@ -108,7 +109,9 @@ def _make_botorch_step( model_kwargs["transform_configs"]["Derelativize"] = ( derelativization_transform_config ) - model_kwargs["fit_out_of_design"] = fit_out_of_design + model_kwargs["data_loader_config"] = DataLoaderConfig( + fit_out_of_design=fit_out_of_design + ) if not no_winsorization: _, default_bridge_kwargs = model.view_defaults() @@ -525,7 +528,9 @@ def choose_generation_strategy( model_kwargs: dict[str, Any] = { "torch_device": torch_device, - "fit_out_of_design": fit_out_of_design, + "data_loader_config": DataLoaderConfig( + fit_out_of_design=fit_out_of_design, + ), } # Create `generation_strategy`, adding first Sobol step diff --git a/ax/generation_strategy/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py index 9e3bd2bc2af..4a238aed18b 100644 --- a/ax/generation_strategy/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ -19,6 +19,7 @@ choose_generation_strategy, DEFAULT_BAYESIAN_PARALLELISM, ) +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans from ax.modelbridge.transforms.log_y import LogY from ax.modelbridge.transforms.winsorize import Winsorize @@ -60,7 +61,7 @@ def test_choose_generation_strategy(self) -> None: "torch_device": None, "transforms": expected_transforms, "transform_configs": expected_transform_configs, - "fit_out_of_design": False, + "data_loader_config": DataLoaderConfig(fit_out_of_design=False), } self.assertEqual(sobol_gpei._steps[1].model_kwargs, expected_model_kwargs) device = torch.device("cpu") @@ -126,7 +127,7 @@ def test_choose_generation_strategy(self) -> None: "torch_device", "transforms", "transform_configs", - "fit_out_of_design", + "data_loader_config", }, ) self.assertGreater(len(model_kwargs["transforms"]), 0) @@ -204,7 +205,7 @@ def test_choose_generation_strategy(self) -> None: "torch_device": None, "transforms": [Winsorize] + Mixed_transforms + Y_trans, "transform_configs": expected_transform_configs, - "fit_out_of_design": False, + "data_loader_config": DataLoaderConfig(fit_out_of_design=False), } self.assertEqual(bo_mixed._steps[1].model_kwargs, expected_model_kwargs) with self.subTest("BO_MIXED (mixed search space)"): @@ -219,7 +220,7 @@ def test_choose_generation_strategy(self) -> None: "torch_device": None, "transforms": [Winsorize] + Mixed_transforms + Y_trans, "transform_configs": expected_transform_configs, - "fit_out_of_design": False, + "data_loader_config": DataLoaderConfig(fit_out_of_design=False), } self.assertEqual(bo_mixed._steps[1].model_kwargs, expected_model_kwargs) with self.subTest("BO_MIXED (mixed multi-objective optimization)"): @@ -241,7 +242,7 @@ def test_choose_generation_strategy(self) -> None: "torch_device", "transforms", "transform_configs", - "fit_out_of_design", + "data_loader_config", }, ) self.assertGreater(len(model_kwargs["transforms"]), 0) diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 637dbb4acb7..a6cc9defd31 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -575,8 +575,9 @@ def test_sobol_MBM_strategy(self) -> None: "status_quo_features": None, "transform_configs": None, "transforms": Cont_X_trans, - "fit_out_of_design": False, - "fit_abandoned": False, + "fit_out_of_design": None, # False by DataLoaderConfig default + "fit_abandoned": None, # False by DataLoaderConfig default + "data_loader_config": None, "fit_tracking_metrics": True, "fit_on_init": True, }, @@ -1560,8 +1561,9 @@ def test_gs_with_generation_nodes(self) -> None: "status_quo_features": None, "transform_configs": None, "transforms": Cont_X_trans, - "fit_out_of_design": False, - "fit_abandoned": False, + "fit_out_of_design": None, # False by DataLoaderConfig default + "fit_abandoned": None, # False by DataLoaderConfig default + "data_loader_config": None, "fit_tracking_metrics": True, "fit_on_init": True, }, diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py index a84b9dbb839..21afc2a38cf 100644 --- a/ax/modelbridge/discrete.py +++ b/ax/modelbridge/discrete.py @@ -22,7 +22,7 @@ from ax.core.search_space import SearchSpace from ax.core.types import TParamValueList from ax.exceptions.core import UserInputError -from ax.modelbridge.base import Adapter, GenResults +from ax.modelbridge.base import Adapter, DataLoaderConfig, GenResults from ax.modelbridge.modelbridge_utils import ( array_to_observation_data, get_fixed_features, @@ -58,11 +58,12 @@ def __init__( status_quo_features: ObservationFeatures | None = None, optimization_config: OptimizationConfig | None = None, expand_model_space: bool = True, - fit_out_of_design: bool = False, - fit_abandoned: bool = False, fit_tracking_metrics: bool = True, fit_on_init: bool = True, - fit_only_completed_map_metrics: bool = True, + data_loader_config: DataLoaderConfig | None = None, + fit_out_of_design: bool | None = None, + fit_abandoned: bool | None = None, + fit_only_completed_map_metrics: bool | None = None, ) -> None: # These are set in _fit. self.parameters: list[str] = [] @@ -77,6 +78,7 @@ def __init__( status_quo_features=status_quo_features, optimization_config=optimization_config, expand_model_space=expand_model_space, + data_loader_config=data_loader_config, fit_out_of_design=fit_out_of_design, fit_abandoned=fit_abandoned, fit_tracking_metrics=fit_tracking_metrics, diff --git a/ax/modelbridge/factory.py b/ax/modelbridge/factory.py index 8df2c2efef7..7d8f579c569 100644 --- a/ax/modelbridge/factory.py +++ b/ax/modelbridge/factory.py @@ -14,6 +14,7 @@ from ax.core.experiment import Experiment from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import Cont_X_trans, Generators, Y_trans @@ -146,7 +147,10 @@ def get_botorch( def get_factorial(search_space: SearchSpace) -> DiscreteAdapter: """Instantiates a factorial generator.""" return assert_is_instance( - Generators.FACTORIAL(search_space=search_space, fit_out_of_design=True), + Generators.FACTORIAL( + search_space=search_space, + data_loader_config=DataLoaderConfig(fit_out_of_design=True), + ), DiscreteAdapter, ) @@ -170,7 +174,7 @@ def get_empirical_bayes_thompson( num_samples=num_samples, min_weight=min_weight, uniform_weights=uniform_weights, - fit_out_of_design=True, + data_loader_config=DataLoaderConfig(fit_out_of_design=True), ), DiscreteAdapter, ) @@ -195,7 +199,7 @@ def get_thompson( num_samples=num_samples, min_weight=min_weight, uniform_weights=uniform_weights, - fit_out_of_design=True, + data_loader_config=DataLoaderConfig(fit_out_of_design=True), ), DiscreteAdapter, ) diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 1f8d9e33fbe..0375fe6f616 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -71,8 +71,8 @@ def __init__( map_data_limit_rows_per_metric: int | None = None, map_data_limit_rows_per_group: int | None = None, data_loader_config: DataLoaderConfig | None = None, - fit_out_of_design: bool = False, - fit_abandoned: bool = False, + fit_out_of_design: bool | None = None, + fit_abandoned: bool | None = None, ) -> None: """In addition to common arguments documented in the ``Adapter`` and ``TorchAdapter`` classes, ``MapTorchAdapter`` accepts the following arguments. diff --git a/ax/modelbridge/random.py b/ax/modelbridge/random.py index 0c2e4849d84..f0f2e3a5bc7 100644 --- a/ax/modelbridge/random.py +++ b/ax/modelbridge/random.py @@ -14,7 +14,7 @@ from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace -from ax.modelbridge.base import Adapter, GenResults +from ax.modelbridge.base import Adapter, DataLoaderConfig, GenResults from ax.modelbridge.modelbridge_utils import ( extract_parameter_constraints, extract_search_space_digest, @@ -45,10 +45,11 @@ def __init__( transform_configs: Mapping[str, TConfig] | None = None, status_quo_features: ObservationFeatures | None = None, optimization_config: OptimizationConfig | None = None, - fit_out_of_design: bool = False, - fit_abandoned: bool = False, fit_tracking_metrics: bool = True, fit_on_init: bool = True, + data_loader_config: DataLoaderConfig | None = None, + fit_out_of_design: bool | None = None, + fit_abandoned: bool | None = None, ) -> None: self.parameters: list[str] = [] super().__init__( @@ -61,6 +62,7 @@ def __init__( status_quo_features=status_quo_features, optimization_config=optimization_config, expand_model_space=False, + data_loader_config=data_loader_config, fit_out_of_design=fit_out_of_design, fit_abandoned=fit_abandoned, fit_tracking_metrics=fit_tracking_metrics, diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index 76c6d7a209a..67dd468b9c2 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -266,8 +266,9 @@ def test_view_defaults(self) -> None: "transforms": Cont_X_trans, "transform_configs": None, "status_quo_features": None, - "fit_out_of_design": False, - "fit_abandoned": False, + "fit_out_of_design": None, # False by DataLoaderConfig default + "fit_abandoned": None, # False by DataLoaderConfig default + "data_loader_config": None, "fit_tracking_metrics": True, "fit_on_init": True, }, diff --git a/ax/plot/pareto_utils.py b/ax/plot/pareto_utils.py index 34cf2669a2d..e1d6ddbf4a5 100644 --- a/ax/plot/pareto_utils.py +++ b/ax/plot/pareto_utils.py @@ -29,6 +29,7 @@ from ax.core.search_space import RobustSearchSpace, SearchSpace from ax.core.types import TParameterization from ax.exceptions.core import AxError, UnsupportedError, UserInputError +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.modelbridge_utils import ( _get_modelbridge_training_data, get_pareto_frontier_and_configs, @@ -337,7 +338,9 @@ def get_tensor_converter_model(experiment: Experiment, data: Data) -> TorchAdapt data=data, model=TorchGenerator(), transforms=[SearchSpaceToFloat], - fit_out_of_design=True, + data_loader_config=DataLoaderConfig( + fit_out_of_design=True, + ), ) diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 6e76820b92e..55655a2fb61 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -1091,8 +1091,6 @@ def test_retries(self) -> None: # Should raise after 3 retries. with self.assertRaisesRegex(RuntimeError, ".* testing .*"): scheduler.run_all_trials() - # pyre-fixme[16]: `Scheduler` has no attribute `run_trial_call_count`. - self.assertEqual(scheduler.run_trial_call_count, 3) def test_retries_nonretriable_error(self) -> None: gs = self._get_generation_strategy_strategy_for_test( @@ -1114,8 +1112,6 @@ def test_retries_nonretriable_error(self) -> None: # Should raise right away since ValueError is non-retriable. with self.assertRaisesRegex(ValueError, ".* testing .*"): scheduler.run_all_trials() - # pyre-fixme[16]: `Scheduler` has no attribute `run_trial_call_count`. - self.assertEqual(scheduler.run_trial_call_count, 1) def test_set_ttl(self) -> None: gs = self._get_generation_strategy_strategy_for_test( diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index d1996f0e48f..e2eeb4bcf8d 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -98,6 +98,7 @@ from ax.metrics.l2norm import L2NormMetric from ax.metrics.noisy_function import NoisyFunctionMetric from ax.metrics.sklearn import SklearnDataset, SklearnMetric, SklearnModelType +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.factory import Generators from ax.modelbridge.registry import ModelRegistryBase from ax.modelbridge.transforms.base import Transform @@ -323,6 +324,7 @@ "ChoiceParameter": ChoiceParameter, "ComparisonOp": ComparisonOp, "Data": Data, + "DataLoaderConfig": DataLoaderConfig, "DataType": DataType, "DomainType": DomainType, "Experiment": Experiment, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 80c9ea22148..eaadd234db8 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -22,6 +22,7 @@ from ax.exceptions.storage import JSONDecodeError, JSONEncodeError from ax.generation_strategy.generation_node import GenerationStep from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.modelbridge.base import DataLoaderConfig from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import SurrogateSpec @@ -170,6 +171,8 @@ ("BraninMetric", get_branin_metric), ("ChainedInputTransform", get_chained_input_transform), ("ChoiceParameter", get_choice_parameter), + # testing with non-default argument + ("DataLoaderConfig", partial(DataLoaderConfig, fit_out_of_design=True)), ("Experiment", get_experiment_with_batch_and_single_trial), ("Experiment", get_experiment_with_trial_with_ttl), ("Experiment", get_experiment_with_data), diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index e6d5fed32a9..a2f44a8c6a1 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -273,7 +273,9 @@ def test_saving_and_loading_experiment_with_aux_exp(self) -> None: self.assertEqual(experiment_w_aux_exp, loaded_experiment) self.assertEqual(len(loaded_experiment.auxiliary_experiments_by_purpose), 1) - def test_saving_and_loading_experiment_with_cross_referencing_aux_exp(self) -> None: + def test_saving_and_loading_experiment_with_cross_referencing_aux_exp( + self, + ) -> None: exp1_name = "test_aux_exp_in_SQAStoreTest1" exp2_name = "test_aux_exp_in_SQAStoreTest2" # pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute @@ -509,7 +511,7 @@ def test_ExperimentSaveAndLoadReducedState( self.assertEqual(len(mkw), 6) bkw = gr._bridge_kwargs self.assertIsNotNone(bkw) - self.assertEqual(len(bkw), 8) + self.assertEqual(len(bkw), 9) # This has seed, generated points and init position. ms = gr._model_state_after_gen self.assertIsNotNone(ms)