Skip to content

Commit b6a0a5c

Browse files
sdaultonfacebook-github-bot
authored andcommitted
fix constraint handling in single objective MBM (#1771)
Summary: Pull Request resolved: #1771 X-link: meta-pytorch/botorch#1973 Currently, constraints are not used in single objective AFs in MBM due to a name mismatch between `outcome_constraints` and `constraints`. Reviewed By: SebastianAment Differential Revision: D48176978 fbshipit-source-id: 9495708002c11a874bb6b8c06327f0f4643039df
1 parent cdde361 commit b6a0a5c

File tree

3 files changed

+163
-58
lines changed

3 files changed

+163
-58
lines changed

ax/models/torch/botorch_modular/acquisition.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
optimize_acqf_discrete_local_search,
4949
optimize_acqf_mixed,
5050
)
51+
from botorch.utils.constraints import get_outcome_constraint_transforms
5152
from torch import Tensor
5253

5354

@@ -277,7 +278,9 @@ def __init__(
277278
"X_baseline": unique_Xs_observed,
278279
"X_pending": unique_Xs_pending,
279280
"objective_thresholds": objective_thresholds,
280-
"outcome_constraints": outcome_constraints,
281+
"constraints": get_outcome_constraint_transforms(
282+
outcome_constraints=outcome_constraints
283+
),
281284
"target_fidelities": search_space_digest.target_fidelities,
282285
"bounds": search_space_digest.bounds,
283286
**acqf_model_kwarg,

ax/models/torch/tests/test_acquisition.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from botorch.acquisition.objective import LinearMCObjective
4040
from botorch.models.gp_regression import SingleTaskGP
41+
from botorch.utils.constraints import get_outcome_constraint_transforms
4142
from botorch.utils.datasets import SupervisedDataset
4243
from botorch.utils.testing import MockPosterior
4344
from torch import Tensor
@@ -118,6 +119,9 @@ def setUp(self) -> None:
118119
torch.tensor([[1.0]], **tkwargs),
119120
torch.tensor([[0.5]], **tkwargs),
120121
)
122+
self.constraints = get_outcome_constraint_transforms(
123+
outcome_constraints=self.outcome_constraints
124+
)
121125
self.linear_constraints = None
122126
self.fixed_features = {1: 2.0}
123127
self.options = {"best_f": 0.0, "cache_root": False, "prune_baseline": False}
@@ -225,31 +229,44 @@ def test_init(
225229
self.mock_input_constructor.reset_mock()
226230
mock_botorch_acqf_class.reset_mock()
227231
self.options[Keys.SUBSET_MODEL] = False
228-
acquisition = Acquisition(
229-
surrogates={"surrogate": self.surrogate},
230-
search_space_digest=self.search_space_digest,
231-
torch_opt_config=self.torch_opt_config,
232-
botorch_acqf_class=self.botorch_acqf_class,
233-
options=self.options,
234-
)
235-
mock_subset_model.assert_not_called()
236-
# Check `get_botorch_objective_and_transform` kwargs
237-
mock_get_objective_and_transform.assert_called_once()
238-
_, ckwargs = mock_get_objective_and_transform.call_args
239-
self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model)
240-
self.assertIs(ckwargs["objective_weights"], self.objective_weights)
241-
self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints)
242-
self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1]))
243-
# Check final `acqf` creation
244-
model_deps = {Keys.CURRENT_VALUE: 1.2}
245-
self.mock_input_constructor.assert_called_once()
246-
mock_botorch_acqf_class.assert_called_once()
247-
_, ckwargs = self.mock_input_constructor.call_args
248-
self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model)
249-
self.assertIs(ckwargs["objective"], botorch_objective)
250-
self.assertTrue(torch.equal(ckwargs["X_pending"], self.pending_observations[0]))
251-
for k, v in chain(self.options.items(), model_deps.items()):
252-
self.assertEqual(ckwargs[k], v)
232+
with mock.patch(
233+
f"{ACQUISITION_PATH}.get_outcome_constraint_transforms",
234+
return_value=self.constraints,
235+
) as mock_get_outcome_constraint_transforms:
236+
acquisition = Acquisition(
237+
surrogates={"surrogate": self.surrogate},
238+
search_space_digest=self.search_space_digest,
239+
torch_opt_config=self.torch_opt_config,
240+
botorch_acqf_class=self.botorch_acqf_class,
241+
options=self.options,
242+
)
243+
mock_subset_model.assert_not_called()
244+
# Check `get_botorch_objective_and_transform` kwargs
245+
mock_get_objective_and_transform.assert_called_once()
246+
_, ckwargs = mock_get_objective_and_transform.call_args
247+
self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model)
248+
self.assertIs(ckwargs["objective_weights"], self.objective_weights)
249+
self.assertIs(ckwargs["outcome_constraints"], self.outcome_constraints)
250+
self.assertTrue(torch.equal(ckwargs["X_observed"], self.X[:1]))
251+
# Check final `acqf` creation
252+
model_deps = {Keys.CURRENT_VALUE: 1.2}
253+
self.mock_input_constructor.assert_called_once()
254+
mock_botorch_acqf_class.assert_called_once()
255+
_, ckwargs = self.mock_input_constructor.call_args
256+
self.assertIs(ckwargs["model"], acquisition.surrogates["surrogate"].model)
257+
self.assertIs(ckwargs["objective"], botorch_objective)
258+
self.assertTrue(
259+
torch.equal(ckwargs["X_pending"], self.pending_observations[0])
260+
)
261+
for k, v in chain(self.options.items(), model_deps.items()):
262+
self.assertEqual(ckwargs[k], v)
263+
self.assertIs(
264+
ckwargs["constraints"],
265+
self.constraints,
266+
)
267+
mock_get_outcome_constraint_transforms.assert_called_once_with(
268+
outcome_constraints=self.outcome_constraints
269+
)
253270

254271
@mock.patch(f"{ACQUISITION_PATH}.optimize_acqf")
255272
def test_optimize(self, mock_optimize_acqf: Mock) -> None:

ax/models/torch/tests/test_model.py

Lines changed: 117 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from ax.models.torch.botorch_modular.acquisition import Acquisition
2020
from ax.models.torch.botorch_modular.model import BoTorchModel, SurrogateSpec
2121
from ax.models.torch.botorch_modular.surrogate import Surrogate
22-
from ax.models.torch.botorch_modular.utils import choose_model_class
22+
from ax.models.torch.botorch_modular.utils import (
23+
choose_model_class,
24+
construct_acquisition_and_optimizer_options,
25+
)
2326
from ax.models.torch.utils import _filter_X_observed
2427
from ax.models.torch_base import TorchOptConfig
2528
from ax.utils.common.constants import Keys
@@ -37,11 +40,13 @@
3740
qNoisyExpectedHypervolumeImprovement,
3841
)
3942
from botorch.acquisition.multi_objective.objective import WeightedMCMultiOutputObjective
43+
from botorch.acquisition.objective import GenericMCObjective
4044
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP
4145
from botorch.models.gp_regression import FixedNoiseGP, SingleTaskGP
4246
from botorch.models.gp_regression_fidelity import FixedNoiseMultiFidelityGP
4347
from botorch.models.model import ModelList
4448
from botorch.sampling.normal import SobolQMCNormalSampler
49+
from botorch.utils.constraints import get_outcome_constraint_transforms
4550
from botorch.utils.datasets import SupervisedDataset
4651
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
4752

@@ -106,11 +111,18 @@ def setUp(self) -> None:
106111
self.optimizer_options = {Keys.NUM_RESTARTS: 40, Keys.RAW_SAMPLES: 1024}
107112
self.model_gen_options = {Keys.OPTIMIZER_KWARGS: self.optimizer_options}
108113
self.objective_weights = torch.tensor([1.0], **tkwargs)
114+
self.outcome_constraints = (
115+
torch.tensor([[1.0]], **tkwargs),
116+
torch.tensor([[-5.0]], **tkwargs),
117+
)
109118
self.moo_objective_weights = torch.tensor([1.0, 1.5, 0.0], **tkwargs)
110119
self.moo_objective_thresholds = torch.tensor(
111120
[0.5, 1.5, float("nan")], **tkwargs
112121
)
113-
self.outcome_constraints = None
122+
self.moo_outcome_constraints = (
123+
torch.tensor([[1.0, 0.0, 0.0]], **tkwargs),
124+
torch.tensor([[-5.0]], **tkwargs),
125+
)
114126
self.linear_constraints = None
115127
self.fixed_features = None
116128
self.pending_observations = None
@@ -136,6 +148,7 @@ def setUp(self) -> None:
136148
self.torch_opt_config,
137149
objective_weights=self.moo_objective_weights,
138150
objective_thresholds=self.moo_objective_thresholds,
151+
outcome_constraints=self.moo_outcome_constraints,
139152
)
140153

141154
def test_init(self) -> None:
@@ -491,12 +504,9 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
491504

492505
@mock.patch(
493506
f"{MODEL_PATH}.construct_acquisition_and_optimizer_options",
494-
return_value=(
495-
ACQ_OPTIONS,
496-
{"num_restarts": 40, "raw_samples": 1024},
497-
),
507+
wraps=construct_acquisition_and_optimizer_options,
498508
)
499-
@mock.patch(f"{CURRENT_PATH}.Acquisition")
509+
@mock.patch(f"{CURRENT_PATH}.Acquisition.optimize")
500510
@mock.patch(f"{MODEL_PATH}.get_rounding_func", return_value="func")
501511
@mock.patch(f"{MODEL_PATH}._to_inequality_constraints", return_value=[])
502512
@mock.patch(
@@ -507,10 +517,18 @@ def test_gen(
507517
mock_choose_botorch_acqf_class: Mock,
508518
mock_inequality_constraints: Mock,
509519
mock_rounding: Mock,
510-
mock_acquisition: Mock,
520+
mock_optimize: Mock,
511521
mock_construct_options: Mock,
512522
) -> None:
513-
mock_acquisition.return_value.optimize.return_value = (
523+
qEI_input_constructor = get_acqf_input_constructor(qExpectedImprovement)
524+
mock_input_constructor = mock.MagicMock(
525+
qEI_input_constructor, side_effect=qEI_input_constructor
526+
)
527+
_register_acqf_input_constructor(
528+
acqf_cls=qExpectedImprovement,
529+
input_constructor=mock_input_constructor,
530+
)
531+
mock_optimize.return_value = (
514532
torch.tensor([1.0]),
515533
torch.tensor([2.0]),
516534
)
@@ -534,10 +552,60 @@ def test_gen(
534552
)
535553
# Add search space digest reference to make the model think it's been fit
536554
model._search_space_digest = self.mf_search_space_digest
537-
model.gen(
538-
n=1,
539-
search_space_digest=self.mf_search_space_digest,
540-
torch_opt_config=self.torch_opt_config,
555+
with mock.patch.object(
556+
BoTorchModel,
557+
"_instantiate_acquisition",
558+
wraps=model._instantiate_acquisition,
559+
) as mock_init_acqf:
560+
model.gen(
561+
n=1,
562+
search_space_digest=self.mf_search_space_digest,
563+
torch_opt_config=self.torch_opt_config,
564+
)
565+
# Assert acquisition initialized with expected arguments
566+
mock_init_acqf.assert_called_once_with(
567+
search_space_digest=self.mf_search_space_digest,
568+
torch_opt_config=self.torch_opt_config,
569+
acq_options=self.acquisition_options,
570+
)
571+
ckwargs = mock_input_constructor.call_args[1]
572+
mock_input_constructor.assert_called_once()
573+
m = ckwargs["model"]
574+
self.assertIsInstance(m, SingleTaskGP)
575+
self.assertEqual(m.num_outputs, 1)
576+
training_data = ckwargs["training_data"]
577+
self.assertIsInstance(training_data, SupervisedDataset)
578+
self.assertTrue(torch.equal(training_data.X(), self.Xs[0]))
579+
self.assertTrue(
580+
torch.equal(
581+
training_data.Y(),
582+
torch.cat([ds.Y() for ds in self.block_design_training_data], dim=-1),
583+
)
584+
)
585+
self.assertIsNotNone(ckwargs["constraints"])
586+
587+
self.assertIsNone(
588+
ckwargs["X_pending"],
589+
)
590+
self.assertIsInstance(
591+
ckwargs.get("objective"),
592+
GenericMCObjective,
593+
)
594+
expected_X_baseline = _filter_X_observed(
595+
Xs=[dataset.X() for dataset in self.block_design_training_data],
596+
objective_weights=self.objective_weights,
597+
outcome_constraints=self.outcome_constraints,
598+
bounds=self.search_space_digest.bounds,
599+
linear_constraints=self.linear_constraints,
600+
fixed_features=self.fixed_features,
601+
)
602+
self.assertTrue(
603+
torch.equal(
604+
ckwargs.get("X_baseline"),
605+
# pyre-fixme[6]: For 2nd param expected `Tensor` but got
606+
# `Optional[Tensor]`.
607+
expected_X_baseline,
608+
)
541609
)
542610

543611
# Assert `construct_acquisition_and_optimizer_options` called with kwargs
@@ -548,16 +616,9 @@ def test_gen(
548616
# Assert `choose_botorch_acqf_class` is called
549617
mock_choose_botorch_acqf_class.assert_called_once()
550618
self.assertEqual(model._botorch_acqf_class, qExpectedImprovement)
551-
# Assert `acquisition_class` called with kwargs
552-
mock_acquisition.assert_called_with(
553-
surrogates={Keys.ONLY_SURROGATE: self.surrogate},
554-
botorch_acqf_class=model.botorch_acqf_class,
555-
search_space_digest=self.mf_search_space_digest,
556-
torch_opt_config=self.torch_opt_config,
557-
options=self.acquisition_options,
558-
)
619+
559620
# Assert `optimize` called with kwargs
560-
mock_acquisition.return_value.optimize.assert_called_with(
621+
mock_optimize.assert_called_with(
561622
n=1,
562623
search_space_digest=self.mf_search_space_digest,
563624
inequality_constraints=[],
@@ -566,6 +627,11 @@ def test_gen(
566627
optimizer_options=self.optimizer_options,
567628
)
568629

630+
_register_acqf_input_constructor(
631+
acqf_cls=qExpectedImprovement,
632+
input_constructor=qEI_input_constructor,
633+
)
634+
569635
def test_feature_importances(self) -> None:
570636
for botorch_model_class in [SingleTaskGP, SaasFullyBayesianSingleTaskGP]:
571637
surrogate = Surrogate(botorch_model_class=botorch_model_class)
@@ -813,11 +879,29 @@ def test_MOO(self, _) -> None:
813879
self.assertIsInstance(
814880
model.surrogates[Keys.AUTOSET_SURROGATE].model, FixedNoiseGP
815881
)
816-
gen_results = model.gen(
817-
n=1,
818-
search_space_digest=self.mf_search_space_digest,
819-
torch_opt_config=self.moo_torch_opt_config,
882+
subset_outcome_constraints = (
883+
# model is subset since last output is not used
884+
self.moo_outcome_constraints[0][:, :2],
885+
self.moo_outcome_constraints[1],
886+
)
887+
constraints = get_outcome_constraint_transforms(
888+
outcome_constraints=subset_outcome_constraints,
820889
)
890+
with mock.patch(
891+
f"{ACQUISITION_PATH}.get_outcome_constraint_transforms",
892+
# Dummy candidates and acquisition function value.
893+
return_value=constraints,
894+
) as mock_get_outcome_constraint_transforms:
895+
gen_results = model.gen(
896+
n=1,
897+
search_space_digest=self.mf_search_space_digest,
898+
torch_opt_config=self.moo_torch_opt_config,
899+
)
900+
mock_get_outcome_constraint_transforms.assert_called_once()
901+
ckwargs = mock_get_outcome_constraint_transforms.call_args[1]
902+
oc = ckwargs["outcome_constraints"]
903+
self.assertTrue(torch.equal(oc[0], subset_outcome_constraints[0]))
904+
self.assertTrue(torch.equal(oc[1], subset_outcome_constraints[1]))
821905
ckwargs = mock_input_constructor.call_args[1]
822906
self.assertIs(model.botorch_acqf_class, qNoisyExpectedHypervolumeImprovement)
823907
mock_input_constructor.assert_called_once()
@@ -844,9 +928,8 @@ def test_MOO(self, _) -> None:
844928
ckwargs["objective_thresholds"], self.moo_objective_thresholds[:2]
845929
)
846930
)
847-
self.assertIsNone(
848-
ckwargs["outcome_constraints"],
849-
)
931+
self.assertIs(ckwargs["constraints"], constraints)
932+
850933
self.assertIsNone(
851934
ckwargs["X_pending"],
852935
)
@@ -860,21 +943,21 @@ def test_MOO(self, _) -> None:
860943
)
861944
self.assertTrue(
862945
torch.equal(
863-
mock_input_constructor.call_args[1].get("objective").weights,
946+
ckwargs.get("objective").weights,
864947
self.moo_objective_weights[:2],
865948
)
866949
)
867950
expected_X_baseline = _filter_X_observed(
868951
Xs=[dataset.X() for dataset in self.moo_training_data],
869952
objective_weights=self.moo_objective_weights,
870-
outcome_constraints=self.outcome_constraints,
953+
outcome_constraints=self.moo_outcome_constraints,
871954
bounds=self.search_space_digest.bounds,
872955
linear_constraints=self.linear_constraints,
873956
fixed_features=self.fixed_features,
874957
)
875958
self.assertTrue(
876959
torch.equal(
877-
mock_input_constructor.call_args[1].get("X_baseline"),
960+
ckwargs.get("X_baseline"),
878961
# pyre-fixme[6]: For 2nd param expected `Tensor` but got
879962
# `Optional[Tensor]`.
880963
expected_X_baseline,
@@ -936,6 +1019,8 @@ def test_MOO(self, _) -> None:
9361019
self.assertTrue(torch.equal(obj_t[:2], torch.tensor([9.9, 3.3])))
9371020
self.assertTrue(np.isnan(obj_t[2].item()))
9381021

1022+
# test outcome constraints
1023+
9391024
# Avoid polluting the registry for other tests; re-register correct input
9401025
# contructor for qNEHVI.
9411026
_register_acqf_input_constructor(

0 commit comments

Comments
 (0)