Skip to content

Commit 767e7d8

Browse files
generatedunixname89002005232357facebook-github-bot
authored andcommitted
Revert D42270109: Multisect successfully blamed D42270109 for test or build failures
Summary: This diff is reverting D42270109 D42270109: [Ax] Update the model in ModelSpec.fit if all non-data args are the same by saitcakmak has been identified to be causing the following test or build failures: Tests affected: - [pts:pts_test - pts.tests.test_client.PTSClientTest: test_plot_slice](https://www.internalfb.com/intern/test/562950033661552/) Here's the Multisect link: https://www.internalfb.com/multisect/1699980 Here are the tasks that are relevant to this breakage: We're generating a revert to back out the changes in this diff, please note the backout may land if someone accepts it. Reviewed By: Balandat Differential Revision: D44089870 fbshipit-source-id: 594ac92b106ef46c878d843eda4f8f815d2132b0
1 parent 3f36e59 commit 767e7d8

File tree

9 files changed

+46
-184
lines changed

9 files changed

+46
-184
lines changed

ax/modelbridge/base.py

Lines changed: 30 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ def __init__(
9494
optimization_config: Optional[OptimizationConfig] = None,
9595
fit_out_of_design: bool = False,
9696
fit_abandoned: bool = False,
97-
fit_on_init: bool = True,
9897
) -> None:
9998
"""
10099
Applies transforms and fits model.
@@ -123,19 +122,12 @@ def __init__(
123122
fit_abandoned: Whether data for abandoned arms or trials should be
124123
included in model training data. If ``False``, only
125124
non-abandoned points are returned.
126-
fit_on_init: Whether to fit the model on initialization. This can
127-
be used to skip model fitting when a fitted model is not needed.
128-
To fit the model afterwards, use `_process_and_transform_data`
129-
to get the transformed inputs and call `_fit_if_implemented` with
130-
the transformed inputs.
131125
"""
132-
t_fit_start = time.monotonic()
126+
t_fit_start = time.time()
133127
transforms = transforms or []
134128
# pyre-ignore: Cast is a Tranform
135129
transforms: List[Type[Transform]] = [Cast] + transforms
136130

137-
self.fit_time: float = 0.0
138-
self.fit_time_since_gen: float = 0.0
139131
self._metric_names: Set[str] = set()
140132
self._training_data: List[Observation] = []
141133
self._optimization_config: Optional[OptimizationConfig] = optimization_config
@@ -146,89 +138,56 @@ def __init__(
146138
self._model_key: Optional[str] = None
147139
self._model_kwargs: Optional[Dict[str, Any]] = None
148140
self._bridge_kwargs: Optional[Dict[str, Any]] = None
149-
self._model_space: SearchSpace = search_space.clone()
141+
142+
# pyre-fixme[4]: Attribute must be annotated.
143+
self._model_space = search_space.clone()
150144
self._raw_transforms = transforms
151145
self._transform_configs: Optional[Dict[str, TConfig]] = transform_configs
152146
self._fit_out_of_design = fit_out_of_design
153147
self._fit_abandoned = fit_abandoned
154-
self._experiment_has_immutable_search_space_and_opt_config: bool = (
155-
experiment is not None and experiment.immutable_search_space_and_opt_config
156-
)
148+
imm = experiment and experiment.immutable_search_space_and_opt_config
149+
# pyre-fixme[4]: Attribute must be annotated.
150+
self._experiment_has_immutable_search_space_and_opt_config = imm
157151
if experiment is not None:
158152
if self._optimization_config is None:
159153
self._optimization_config = experiment.optimization_config
160154
self._arms_by_signature = experiment.arms_by_signature
161155

162-
# Convert Data to Observations, transform observations & search space.
163-
observations, search_space = self._process_and_transform_data(
164-
experiment=experiment, data=data
165-
)
156+
# Convert Data to Observations
157+
observations = self._prepare_observations(experiment=experiment, data=data)
166158

167-
# Set model status quo.
159+
observations_raw = self._set_training_data(
160+
observations=observations, search_space=search_space
161+
)
162+
# Set model status quo
168163
# NOTE: training data must be set before setting the status quo.
169164
self._set_status_quo(
170165
experiment=experiment,
171166
status_quo_name=status_quo_name,
172167
status_quo_features=status_quo_features,
173168
)
169+
observations, search_space = self._transform_data(
170+
observations=observations_raw,
171+
search_space=search_space,
172+
transforms=transforms,
173+
transform_configs=transform_configs,
174+
)
174175

175-
# Save model, apply terminal transform, and fit.
176+
# Save model, apply terminal transform, and fit
176177
self.model = model
177-
if fit_on_init:
178-
self._fit_if_implemented(
179-
search_space=search_space,
180-
observations=observations,
181-
time_so_far=time.monotonic() - t_fit_start,
182-
)
183-
184-
def _fit_if_implemented(
185-
self,
186-
search_space: SearchSpace,
187-
observations: List[Observation],
188-
time_so_far: float,
189-
) -> None:
190-
r"""Fits the model if `_fit` is implemented and stores fit time.
191-
192-
Args:
193-
search_space: A transformed search space for fitting the model.
194-
observations: The observations to fit the model with. These should
195-
also be transformed.
196-
time_so_far: Time spent in initializing the model up to
197-
`_fit_if_implemented` call.
198-
"""
199178
try:
200-
t_fit_start = time.monotonic()
201179
self._fit(
202-
model=self.model,
180+
model=model,
203181
search_space=search_space,
204182
observations=observations,
205183
)
206-
self.fit_time += time.monotonic() - t_fit_start + time_so_far
207-
self.fit_time_since_gen += self.fit_time
184+
# pyre-fixme[4]: Attribute must be annotated.
185+
self.fit_time = time.time() - t_fit_start
186+
# pyre-fixme[4]: Attribute must be annotated.
187+
self.fit_time_since_gen = float(self.fit_time)
208188
except NotImplementedError:
209-
pass
210-
211-
def _process_and_transform_data(
212-
self,
213-
experiment: Optional[Experiment] = None,
214-
data: Optional[Data] = None,
215-
) -> Tuple[List[Observation], SearchSpace]:
216-
r"""Processes the data into observations and returns transformed
217-
observations and the search space. This packages the following methods:
218-
* self._prepare_observations
219-
* self._set_training_data
220-
* self._transform_data
221-
"""
222-
observations = self._prepare_observations(experiment=experiment, data=data)
223-
observations_raw = self._set_training_data(
224-
observations=observations, search_space=self._model_space
225-
)
226-
return self._transform_data(
227-
observations=observations_raw,
228-
search_space=self._model_space,
229-
transforms=self._raw_transforms,
230-
transform_configs=self._transform_configs,
231-
)
189+
self.fit_time = 0.0
190+
self.fit_time_since_gen = 0.0
232191

233192
def _prepare_observations(
234193
self, experiment: Optional[Experiment], data: Optional[Data]
@@ -607,7 +566,7 @@ def update(self, new_data: Data, experiment: Experiment) -> None:
607566
`update`.
608567
experiment: Experiment, in which this data was obtained.
609568
"""
610-
t_update_start = time.monotonic()
569+
t_update_start = time.time()
611570
observations = self._prepare_observations(experiment=experiment, data=new_data)
612571
obs_raw = self._extend_training_data(observations=observations)
613572
observations, search_space = self._transform_data(
@@ -620,8 +579,8 @@ def update(self, new_data: Data, experiment: Experiment) -> None:
620579
search_space=search_space,
621580
observations=observations,
622581
)
623-
self.fit_time += time.monotonic() - t_update_start
624-
self.fit_time_since_gen += time.monotonic() - t_update_start
582+
self.fit_time += time.time() - t_update_start
583+
self.fit_time_since_gen += time.time() - t_update_start
625584

626585
def _update(
627586
self,

ax/modelbridge/map_torch.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def __init__(
6363
status_quo_features: Optional[ObservationFeatures] = None,
6464
optimization_config: Optional[OptimizationConfig] = None,
6565
fit_out_of_design: bool = False,
66-
fit_on_init: bool = True,
6766
default_model_gen_options: Optional[TConfig] = None,
6867
map_data_limit_rows_per_metric: Optional[int] = None,
6968
map_data_limit_rows_per_group: Optional[int] = None,
@@ -94,11 +93,6 @@ def __init__(
9493
the model.
9594
fit_out_of_design: If specified, all training data is returned.
9695
Otherwise, only in design points are returned.
97-
fit_on_init: Whether to fit the model on initialization. This can
98-
be used to skip model fitting when a fitted model is not needed.
99-
To fit the model afterwards, use `_process_and_transform_data`
100-
to get the transformed inputs and call `_fit_if_implemented` with
101-
the transformed inputs.
10296
default_model_gen_options: Options passed down to `model.gen(...)`.
10397
map_data_limit_rows_per_metric: Subsample the map data so that the
10498
total number of rows per metric is limited by this value.
@@ -129,7 +123,6 @@ def __init__(
129123
status_quo_features=status_quo_features,
130124
optimization_config=optimization_config,
131125
fit_out_of_design=fit_out_of_design,
132-
fit_on_init=fit_on_init,
133126
default_model_gen_options=default_model_gen_options,
134127
)
135128

ax/modelbridge/model_spec.py

Lines changed: 5 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,6 @@ class ModelSpec(Base):
6868
# stored cross validation diagnostics set in cross validate
6969
_diagnostics: Optional[CVDiagnostics] = None
7070

71-
# Stored to check if the model can be safely updated in fit.
72-
_last_fit_arg_ids: Optional[Dict[str, int]] = None
73-
7471
def __post_init__(self) -> None:
7572
self.model_kwargs = self.model_kwargs or {}
7673
self.model_gen_kwargs = self.model_gen_kwargs or {}
@@ -128,28 +125,11 @@ def fit(
128125
# adding contents of `model_kwargs` passed to this method, to
129126
# `self.model_kwargs`.
130127
combined_model_kwargs = {**(self.model_kwargs or {}), **model_kwargs}
131-
if self._fitted_model is not None and self._safe_to_update(
132-
experiment=experiment, combined_model_kwargs=combined_model_kwargs
133-
):
134-
# Update the data on the modelbridge and call `_fit`.
135-
# This will skip model fitting if the data has not changed.
136-
observations, search_space = self.fitted_model._process_and_transform_data(
137-
experiment=experiment, data=data
138-
)
139-
self.fitted_model._fit_if_implemented(
140-
search_space=search_space, observations=observations, time_so_far=0.0
141-
)
142-
143-
else:
144-
# Fit from scratch.
145-
self._fitted_model = self.model_enum(
146-
experiment=experiment,
147-
data=data,
148-
**combined_model_kwargs,
149-
)
150-
self._last_fit_arg_ids = self._get_fit_arg_ids(
151-
experiment=experiment, combined_model_kwargs=combined_model_kwargs
152-
)
128+
self._fitted_model = self.model_enum(
129+
experiment=experiment,
130+
data=data,
131+
**combined_model_kwargs,
132+
)
153133

154134
def cross_validate(
155135
self,
@@ -235,32 +215,6 @@ def copy(self) -> ModelSpec:
235215
model_cv_kwargs=deepcopy(self.model_cv_kwargs),
236216
)
237217

238-
def _safe_to_update(
239-
self,
240-
experiment: Experiment,
241-
combined_model_kwargs: Dict[str, Any],
242-
) -> bool:
243-
"""Checks if the object id of any of the non-data fit arguments has changed.
244-
245-
This is a cheap way of checking that we're attempting to re-fit the same
246-
model for the same experiment, which is a very reasonable expectation
247-
since this all happens on the same `ModelSpec` instance.
248-
"""
249-
return self._last_fit_arg_ids == self._get_fit_arg_ids(
250-
experiment=experiment, combined_model_kwargs=combined_model_kwargs
251-
)
252-
253-
def _get_fit_arg_ids(
254-
self,
255-
experiment: Experiment,
256-
combined_model_kwargs: Dict[str, Any],
257-
) -> Dict[str, int]:
258-
"""Construct a dictionary mapping arg name to object id."""
259-
return {
260-
"experiment": id(experiment),
261-
**{k: id(v) for k, v in combined_model_kwargs.items()},
262-
}
263-
264218
def _assert_fitted(self) -> None:
265219
"""Helper that verifies a model was fitted, raising an error if not"""
266220
if self._fitted_model is None:

ax/modelbridge/tests/test_base_modelbridge.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,6 @@ def test_ModelBridge(
258258
modelbridge.transform_observation_features([get_observation2().features])
259259
mock_tr.assert_called_with(modelbridge, [get_observation2trans().features])
260260

261-
# Test that fit is not called when fit_on_init = False.
262-
mock_fit.reset_mock()
263-
modelbridge = ModelBridge(
264-
search_space=ss,
265-
model=Model(),
266-
fit_on_init=False,
267-
)
268-
self.assertEqual(mock_fit.call_count, 0)
269-
270261
@mock.patch(
271262
"ax.modelbridge.base.observations_from_data",
272263
autospec=True,

ax/modelbridge/tests/test_generation_strategy.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ def setUp(self) -> None:
5959
f"{TorchModelBridge.__module__}.TorchModelBridge", spec=True
6060
)
6161
self.mock_torch_model_bridge = self.torch_model_bridge_patcher.start()
62-
mock_mb = self.mock_torch_model_bridge.return_value
63-
mock_mb.gen.return_value = self.gr
64-
mock_mb._process_and_transform_data.return_value = (None, None)
62+
self.mock_torch_model_bridge.return_value.gen.return_value = self.gr
6563

6664
# Mock out slow TS.
6765
self.discrete_model_bridge_patcher = patch(
@@ -304,7 +302,6 @@ def test_sobol_GPEI_strategy(self) -> None:
304302
"transforms": Cont_X_trans,
305303
"fit_out_of_design": False,
306304
"fit_abandoned": False,
307-
"fit_on_init": True,
308305
},
309306
)
310307
ms = g._model_state_after_gen

ax/modelbridge/tests/test_model_spec.py

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import warnings
8-
from unittest import mock
9-
from unittest.mock import MagicMock, Mock, patch
8+
from unittest.mock import Mock, patch
109

1110
from ax.core.observation import ObservationFeatures
1211
from ax.exceptions.core import UserInputError
1312
from ax.modelbridge.factory import get_sobol
1413
from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
15-
from ax.modelbridge.modelbridge_utils import extract_search_space_digest
1614
from ax.modelbridge.registry import Models
1715
from ax.utils.common.testutils import TestCase
1816
from ax.utils.testing.core_stubs import get_branin_experiment
@@ -41,49 +39,23 @@ def test_construct(self) -> None:
4139
with self.assertRaises(NotImplementedError):
4240
ms.update(experiment=self.experiment, new_data=self.data)
4341

44-
@fast_botorch_optimize
45-
# We can use `extract_search_space_digest` as a surrogate for executing
46-
# the full TorchModelBridge._fit.
47-
@mock.patch(
48-
"ax.modelbridge.torch.extract_search_space_digest",
49-
wraps=extract_search_space_digest,
50-
)
51-
def test_fit(self, wrapped_extract_ssd: Mock) -> None:
52-
ms = ModelSpec(model_enum=Models.GPEI)
53-
# This should fit the model as usual.
54-
ms.fit(experiment=self.experiment, data=self.data)
55-
wrapped_extract_ssd.assert_called_once()
56-
self.assertIsNotNone(ms._last_fit_arg_ids)
57-
self.assertEqual(ms._last_fit_arg_ids["experiment"], id(self.experiment))
58-
# This should skip the model fit.
59-
with mock.patch("ax.modelbridge.torch.logger") as mock_logger:
60-
ms.fit(experiment=self.experiment, data=self.data)
61-
mock_logger.info.assert_called_with(
62-
"The observations are identical to the last set of observations "
63-
"used to fit the model. Skipping model fitting."
64-
)
65-
wrapped_extract_ssd.assert_called_once()
66-
6742
def test_model_key(self) -> None:
6843
ms = ModelSpec(model_enum=Models.GPEI)
6944
self.assertEqual(ms.model_key, "GPEI")
7045

7146
@patch(f"{ModelSpec.__module__}.compute_diagnostics")
7247
@patch(f"{ModelSpec.__module__}.cross_validate", return_value=["fake-cv-result"])
73-
def test_cross_validate_with_GP_model(
74-
self, mock_cv: Mock, mock_diagnostics: Mock
75-
) -> None:
48+
# pyre-fixme[3]: Return type must be annotated.
49+
def test_cross_validate_with_GP_model(self, mock_cv: Mock, mock_diagnostics: Mock):
7650
mock_enum = Mock()
77-
fake_mb = MagicMock()
78-
fake_mb._process_and_transform_data = MagicMock(return_value=(None, None))
79-
mock_enum.return_value = fake_mb
51+
mock_enum.return_value = "fake-modelbridge"
8052
ms = ModelSpec(model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"})
8153
ms.fit(
8254
experiment=self.experiment,
8355
data=self.experiment.trials[0].fetch_data(),
8456
)
8557
cv_results, cv_diagnostics = ms.cross_validate()
86-
mock_cv.assert_called_with(model=fake_mb, test_key="test-value")
58+
mock_cv.assert_called_with(model="fake-modelbridge", test_key="test-value")
8759
mock_diagnostics.assert_called_with(["fake-cv-result"])
8860

8961
self.assertIsNotNone(cv_results)
@@ -112,14 +84,15 @@ def test_cross_validate_with_GP_model(
11284

11385
self.assertIsNotNone(cv_results)
11486
self.assertIsNotNone(cv_diagnostics)
115-
mock_cv.assert_called_with(model=fake_mb, test_key="test-value")
87+
mock_cv.assert_called_with(model="fake-modelbridge", test_key="test-value")
11688
mock_diagnostics.assert_called_with(["fake-cv-result"])
11789

11890
@patch(f"{ModelSpec.__module__}.compute_diagnostics")
11991
@patch(f"{ModelSpec.__module__}.cross_validate", side_effect=NotImplementedError)
92+
# pyre-fixme[3]: Return type must be annotated.
12093
def test_cross_validate_with_non_GP_model(
12194
self, mock_cv: Mock, mock_diagnostics: Mock
122-
) -> None:
95+
):
12396
mock_enum = Mock()
12497
mock_enum.return_value = "fake-modelbridge"
12598
ms = ModelSpec(model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"})

0 commit comments

Comments
 (0)