19
19
from ax .models .torch .botorch_modular .acquisition import Acquisition
20
20
from ax .models .torch .botorch_modular .model import BoTorchModel , SurrogateSpec
21
21
from ax .models .torch .botorch_modular .surrogate import Surrogate
22
- from ax .models .torch .botorch_modular .utils import choose_model_class
22
+ from ax .models .torch .botorch_modular .utils import (
23
+ choose_model_class ,
24
+ construct_acquisition_and_optimizer_options ,
25
+ )
23
26
from ax .models .torch .utils import _filter_X_observed
24
27
from ax .models .torch_base import TorchOptConfig
25
28
from ax .utils .common .constants import Keys
37
40
qNoisyExpectedHypervolumeImprovement ,
38
41
)
39
42
from botorch .acquisition .multi_objective .objective import WeightedMCMultiOutputObjective
43
+ from botorch .acquisition .objective import GenericMCObjective
40
44
from botorch .models .fully_bayesian import SaasFullyBayesianSingleTaskGP
41
45
from botorch .models .gp_regression import FixedNoiseGP , SingleTaskGP
42
46
from botorch .models .gp_regression_fidelity import FixedNoiseMultiFidelityGP
43
47
from botorch .models .model import ModelList
44
48
from botorch .sampling .normal import SobolQMCNormalSampler
49
+ from botorch .utils .constraints import get_outcome_constraint_transforms
45
50
from botorch .utils .datasets import SupervisedDataset
46
51
from gpytorch .mlls .exact_marginal_log_likelihood import ExactMarginalLogLikelihood
47
52
@@ -106,11 +111,18 @@ def setUp(self) -> None:
106
111
self .optimizer_options = {Keys .NUM_RESTARTS : 40 , Keys .RAW_SAMPLES : 1024 }
107
112
self .model_gen_options = {Keys .OPTIMIZER_KWARGS : self .optimizer_options }
108
113
self .objective_weights = torch .tensor ([1.0 ], ** tkwargs )
114
+ self .outcome_constraints = (
115
+ torch .tensor ([[1.0 ]], ** tkwargs ),
116
+ torch .tensor ([[- 5.0 ]], ** tkwargs ),
117
+ )
109
118
self .moo_objective_weights = torch .tensor ([1.0 , 1.5 , 0.0 ], ** tkwargs )
110
119
self .moo_objective_thresholds = torch .tensor (
111
120
[0.5 , 1.5 , float ("nan" )], ** tkwargs
112
121
)
113
- self .outcome_constraints = None
122
+ self .moo_outcome_constraints = (
123
+ torch .tensor ([[1.0 , 0.0 , 0.0 ]], ** tkwargs ),
124
+ torch .tensor ([[- 5.0 ]], ** tkwargs ),
125
+ )
114
126
self .linear_constraints = None
115
127
self .fixed_features = None
116
128
self .pending_observations = None
@@ -136,6 +148,7 @@ def setUp(self) -> None:
136
148
self .torch_opt_config ,
137
149
objective_weights = self .moo_objective_weights ,
138
150
objective_thresholds = self .moo_objective_thresholds ,
151
+ outcome_constraints = self .moo_outcome_constraints ,
139
152
)
140
153
141
154
def test_init (self ) -> None :
@@ -491,12 +504,9 @@ def test_cross_validate(self, mock_fit: Mock) -> None:
491
504
492
505
@mock .patch (
493
506
f"{ MODEL_PATH } .construct_acquisition_and_optimizer_options" ,
494
- return_value = (
495
- ACQ_OPTIONS ,
496
- {"num_restarts" : 40 , "raw_samples" : 1024 },
497
- ),
507
+ wraps = construct_acquisition_and_optimizer_options ,
498
508
)
499
- @mock .patch (f"{ CURRENT_PATH } .Acquisition" )
509
+ @mock .patch (f"{ CURRENT_PATH } .Acquisition.optimize " )
500
510
@mock .patch (f"{ MODEL_PATH } .get_rounding_func" , return_value = "func" )
501
511
@mock .patch (f"{ MODEL_PATH } ._to_inequality_constraints" , return_value = [])
502
512
@mock .patch (
@@ -507,10 +517,18 @@ def test_gen(
507
517
mock_choose_botorch_acqf_class : Mock ,
508
518
mock_inequality_constraints : Mock ,
509
519
mock_rounding : Mock ,
510
- mock_acquisition : Mock ,
520
+ mock_optimize : Mock ,
511
521
mock_construct_options : Mock ,
512
522
) -> None :
513
- mock_acquisition .return_value .optimize .return_value = (
523
+ qEI_input_constructor = get_acqf_input_constructor (qExpectedImprovement )
524
+ mock_input_constructor = mock .MagicMock (
525
+ qEI_input_constructor , side_effect = qEI_input_constructor
526
+ )
527
+ _register_acqf_input_constructor (
528
+ acqf_cls = qExpectedImprovement ,
529
+ input_constructor = mock_input_constructor ,
530
+ )
531
+ mock_optimize .return_value = (
514
532
torch .tensor ([1.0 ]),
515
533
torch .tensor ([2.0 ]),
516
534
)
@@ -534,10 +552,60 @@ def test_gen(
534
552
)
535
553
# Add search space digest reference to make the model think it's been fit
536
554
model ._search_space_digest = self .mf_search_space_digest
537
- model .gen (
538
- n = 1 ,
539
- search_space_digest = self .mf_search_space_digest ,
540
- torch_opt_config = self .torch_opt_config ,
555
+ with mock .patch .object (
556
+ BoTorchModel ,
557
+ "_instantiate_acquisition" ,
558
+ wraps = model ._instantiate_acquisition ,
559
+ ) as mock_init_acqf :
560
+ model .gen (
561
+ n = 1 ,
562
+ search_space_digest = self .mf_search_space_digest ,
563
+ torch_opt_config = self .torch_opt_config ,
564
+ )
565
+ # Assert acquisition initialized with expected arguments
566
+ mock_init_acqf .assert_called_once_with (
567
+ search_space_digest = self .mf_search_space_digest ,
568
+ torch_opt_config = self .torch_opt_config ,
569
+ acq_options = self .acquisition_options ,
570
+ )
571
+ ckwargs = mock_input_constructor .call_args [1 ]
572
+ mock_input_constructor .assert_called_once ()
573
+ m = ckwargs ["model" ]
574
+ self .assertIsInstance (m , SingleTaskGP )
575
+ self .assertEqual (m .num_outputs , 1 )
576
+ training_data = ckwargs ["training_data" ]
577
+ self .assertIsInstance (training_data , SupervisedDataset )
578
+ self .assertTrue (torch .equal (training_data .X (), self .Xs [0 ]))
579
+ self .assertTrue (
580
+ torch .equal (
581
+ training_data .Y (),
582
+ torch .cat ([ds .Y () for ds in self .block_design_training_data ], dim = - 1 ),
583
+ )
584
+ )
585
+ self .assertIsNotNone (ckwargs ["constraints" ])
586
+
587
+ self .assertIsNone (
588
+ ckwargs ["X_pending" ],
589
+ )
590
+ self .assertIsInstance (
591
+ ckwargs .get ("objective" ),
592
+ GenericMCObjective ,
593
+ )
594
+ expected_X_baseline = _filter_X_observed (
595
+ Xs = [dataset .X () for dataset in self .block_design_training_data ],
596
+ objective_weights = self .objective_weights ,
597
+ outcome_constraints = self .outcome_constraints ,
598
+ bounds = self .search_space_digest .bounds ,
599
+ linear_constraints = self .linear_constraints ,
600
+ fixed_features = self .fixed_features ,
601
+ )
602
+ self .assertTrue (
603
+ torch .equal (
604
+ ckwargs .get ("X_baseline" ),
605
+ # pyre-fixme[6]: For 2nd param expected `Tensor` but got
606
+ # `Optional[Tensor]`.
607
+ expected_X_baseline ,
608
+ )
541
609
)
542
610
543
611
# Assert `construct_acquisition_and_optimizer_options` called with kwargs
@@ -548,16 +616,9 @@ def test_gen(
548
616
# Assert `choose_botorch_acqf_class` is called
549
617
mock_choose_botorch_acqf_class .assert_called_once ()
550
618
self .assertEqual (model ._botorch_acqf_class , qExpectedImprovement )
551
- # Assert `acquisition_class` called with kwargs
552
- mock_acquisition .assert_called_with (
553
- surrogates = {Keys .ONLY_SURROGATE : self .surrogate },
554
- botorch_acqf_class = model .botorch_acqf_class ,
555
- search_space_digest = self .mf_search_space_digest ,
556
- torch_opt_config = self .torch_opt_config ,
557
- options = self .acquisition_options ,
558
- )
619
+
559
620
# Assert `optimize` called with kwargs
560
- mock_acquisition . return_value . optimize .assert_called_with (
621
+ mock_optimize .assert_called_with (
561
622
n = 1 ,
562
623
search_space_digest = self .mf_search_space_digest ,
563
624
inequality_constraints = [],
@@ -566,6 +627,11 @@ def test_gen(
566
627
optimizer_options = self .optimizer_options ,
567
628
)
568
629
630
+ _register_acqf_input_constructor (
631
+ acqf_cls = qExpectedImprovement ,
632
+ input_constructor = qEI_input_constructor ,
633
+ )
634
+
569
635
def test_feature_importances (self ) -> None :
570
636
for botorch_model_class in [SingleTaskGP , SaasFullyBayesianSingleTaskGP ]:
571
637
surrogate = Surrogate (botorch_model_class = botorch_model_class )
@@ -813,11 +879,29 @@ def test_MOO(self, _) -> None:
813
879
self .assertIsInstance (
814
880
model .surrogates [Keys .AUTOSET_SURROGATE ].model , FixedNoiseGP
815
881
)
816
- gen_results = model .gen (
817
- n = 1 ,
818
- search_space_digest = self .mf_search_space_digest ,
819
- torch_opt_config = self .moo_torch_opt_config ,
882
+ subset_outcome_constraints = (
883
+ # model is subset since last output is not used
884
+ self .moo_outcome_constraints [0 ][:, :2 ],
885
+ self .moo_outcome_constraints [1 ],
886
+ )
887
+ constraints = get_outcome_constraint_transforms (
888
+ outcome_constraints = subset_outcome_constraints ,
820
889
)
890
+ with mock .patch (
891
+ f"{ ACQUISITION_PATH } .get_outcome_constraint_transforms" ,
892
+ # Dummy candidates and acquisition function value.
893
+ return_value = constraints ,
894
+ ) as mock_get_outcome_constraint_transforms :
895
+ gen_results = model .gen (
896
+ n = 1 ,
897
+ search_space_digest = self .mf_search_space_digest ,
898
+ torch_opt_config = self .moo_torch_opt_config ,
899
+ )
900
+ mock_get_outcome_constraint_transforms .assert_called_once ()
901
+ ckwargs = mock_get_outcome_constraint_transforms .call_args [1 ]
902
+ oc = ckwargs ["outcome_constraints" ]
903
+ self .assertTrue (torch .equal (oc [0 ], subset_outcome_constraints [0 ]))
904
+ self .assertTrue (torch .equal (oc [1 ], subset_outcome_constraints [1 ]))
821
905
ckwargs = mock_input_constructor .call_args [1 ]
822
906
self .assertIs (model .botorch_acqf_class , qNoisyExpectedHypervolumeImprovement )
823
907
mock_input_constructor .assert_called_once ()
@@ -844,9 +928,8 @@ def test_MOO(self, _) -> None:
844
928
ckwargs ["objective_thresholds" ], self .moo_objective_thresholds [:2 ]
845
929
)
846
930
)
847
- self .assertIsNone (
848
- ckwargs ["outcome_constraints" ],
849
- )
931
+ self .assertIs (ckwargs ["constraints" ], constraints )
932
+
850
933
self .assertIsNone (
851
934
ckwargs ["X_pending" ],
852
935
)
@@ -860,21 +943,21 @@ def test_MOO(self, _) -> None:
860
943
)
861
944
self .assertTrue (
862
945
torch .equal (
863
- mock_input_constructor . call_args [ 1 ] .get ("objective" ).weights ,
946
+ ckwargs .get ("objective" ).weights ,
864
947
self .moo_objective_weights [:2 ],
865
948
)
866
949
)
867
950
expected_X_baseline = _filter_X_observed (
868
951
Xs = [dataset .X () for dataset in self .moo_training_data ],
869
952
objective_weights = self .moo_objective_weights ,
870
- outcome_constraints = self .outcome_constraints ,
953
+ outcome_constraints = self .moo_outcome_constraints ,
871
954
bounds = self .search_space_digest .bounds ,
872
955
linear_constraints = self .linear_constraints ,
873
956
fixed_features = self .fixed_features ,
874
957
)
875
958
self .assertTrue (
876
959
torch .equal (
877
- mock_input_constructor . call_args [ 1 ] .get ("X_baseline" ),
960
+ ckwargs .get ("X_baseline" ),
878
961
# pyre-fixme[6]: For 2nd param expected `Tensor` but got
879
962
# `Optional[Tensor]`.
880
963
expected_X_baseline ,
@@ -936,6 +1019,8 @@ def test_MOO(self, _) -> None:
936
1019
self .assertTrue (torch .equal (obj_t [:2 ], torch .tensor ([9.9 , 3.3 ])))
937
1020
self .assertTrue (np .isnan (obj_t [2 ].item ()))
938
1021
1022
+ # test outcome constraints
1023
+
939
1024
# Avoid polluting the registry for other tests; re-register correct input
940
1025
# contructor for qNEHVI.
941
1026
_register_acqf_input_constructor (
0 commit comments