Skip to content

Commit e79d582

Browse files
committed
Use non-sequential recommender in interpoint tests
1 parent bca9860 commit e79d582

File tree

2 files changed

+101
-6
lines changed

2 files changed

+101
-6
lines changed

tests/conftest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,19 @@ def fixture_recommender(initial_recommender, surrogate_model, acqf):
741741
)
742742

743743

744+
@pytest.fixture(name="non_sequential_recommender")
745+
def fixture_non_sequential_recommender(initial_recommender, surrogate_model, acqf):
746+
"""A recommender for non-sequential optimization."""
747+
return TwoPhaseMetaRecommender(
748+
initial_recommender=initial_recommender,
749+
recommender=BotorchRecommender(
750+
surrogate_model=surrogate_model,
751+
acquisition_function=acqf,
752+
sequential_continuous=False,
753+
),
754+
)
755+
756+
744757
@pytest.fixture(name="meta_recommender")
745758
def fixture_meta_recommender(
746759
request,

tests/constraints/test_constraints_continuous.py

Lines changed: 88 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import pytest
55
from pytest import param
66

7+
from baybe.campaign import Campaign
78
from baybe.constraints import ContinuousLinearConstraint
9+
from baybe.searchspace import SearchSpace
810
from tests.conftest import run_iterations
911

1012
TOLERANCE = 0.01
@@ -67,7 +69,21 @@ def test_inequality3(campaign, n_iterations, batch_size):
6769

6870
@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]])
6971
@pytest.mark.parametrize("constraint_names", [["InterConstraint_1"]])
70-
def test_interpoint_equality_single_parameter(campaign, n_iterations, batch_size):
72+
def test_interpoint_equality_single_parameter(
73+
non_sequential_recommender,
74+
parameters,
75+
constraints,
76+
objective,
77+
n_iterations,
78+
batch_size,
79+
):
80+
campaign = Campaign(
81+
searchspace=SearchSpace.from_product(
82+
parameters=parameters, constraints=constraints
83+
),
84+
recommender=non_sequential_recommender,
85+
objective=objective,
86+
)
7187
"""Test single parameter interpoint equality constraint."""
7288
run_iterations(campaign, n_iterations, batch_size, add_noise=False)
7389
res = campaign.measurements
@@ -79,8 +95,22 @@ def test_interpoint_equality_single_parameter(campaign, n_iterations, batch_size
7995

8096
@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]])
8197
@pytest.mark.parametrize("constraint_names", [["InterConstraint_2"]])
82-
def test_interpoint_inequality_single_parameter(campaign, n_iterations, batch_size):
98+
def test_interpoint_inequality_single_parameter(
99+
non_sequential_recommender,
100+
parameters,
101+
constraints,
102+
objective,
103+
n_iterations,
104+
batch_size,
105+
):
83106
"""Test single parameter interpoint inequality constraint."""
107+
campaign = Campaign(
108+
searchspace=SearchSpace.from_product(
109+
parameters=parameters, constraints=constraints
110+
),
111+
recommender=non_sequential_recommender,
112+
objective=objective,
113+
)
84114
run_iterations(campaign, n_iterations, batch_size, add_noise=False)
85115
res = campaign.measurements
86116

@@ -91,8 +121,22 @@ def test_interpoint_inequality_single_parameter(campaign, n_iterations, batch_si
91121

92122
@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]])
93123
@pytest.mark.parametrize("constraint_names", [["InterConstraint_3"]])
94-
def test_interpoint_equality_multiple_parameters(campaign, n_iterations, batch_size):
124+
def test_interpoint_equality_multiple_parameters(
125+
non_sequential_recommender,
126+
parameters,
127+
constraints,
128+
objective,
129+
n_iterations,
130+
batch_size,
131+
):
95132
"""Test interpoint equality constraint involving multiple parameters."""
133+
campaign = Campaign(
134+
searchspace=SearchSpace.from_product(
135+
parameters=parameters, constraints=constraints
136+
),
137+
recommender=non_sequential_recommender,
138+
objective=objective,
139+
)
96140
run_iterations(campaign, n_iterations, batch_size, add_noise=False)
97141
res = campaign.measurements
98142

@@ -106,9 +150,21 @@ def test_interpoint_equality_multiple_parameters(campaign, n_iterations, batch_s
106150
@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]])
107151
@pytest.mark.parametrize("constraint_names", [["InterConstraint_4"]])
108152
def test_geq_interpoint_inequality_multiple_parameters(
109-
campaign, n_iterations, batch_size
153+
non_sequential_recommender,
154+
parameters,
155+
constraints,
156+
objective,
157+
n_iterations,
158+
batch_size,
110159
):
111160
"""Test geq-interpoint inequality constraint involving multiple parameters."""
161+
campaign = Campaign(
162+
searchspace=SearchSpace.from_product(
163+
parameters=parameters, constraints=constraints
164+
),
165+
recommender=non_sequential_recommender,
166+
objective=objective,
167+
)
112168
run_iterations(campaign, n_iterations, batch_size, add_noise=False)
113169
res = campaign.measurements
114170

@@ -123,9 +179,21 @@ def test_geq_interpoint_inequality_multiple_parameters(
123179
@pytest.mark.parametrize("parameter_names", [["Conti_finite1", "Conti_finite2"]])
124180
@pytest.mark.parametrize("constraint_names", [["InterConstraint_5"]])
125181
def test_leq_interpoint_inequality_multiple_parameters(
126-
campaign, n_iterations, batch_size
182+
non_sequential_recommender,
183+
parameters,
184+
constraints,
185+
objective,
186+
n_iterations,
187+
batch_size,
127188
):
128189
"""Test leq-interpoint inequality constraint involving multiple parameters."""
190+
campaign = Campaign(
191+
searchspace=SearchSpace.from_product(
192+
parameters=parameters, constraints=constraints
193+
),
194+
recommender=non_sequential_recommender,
195+
objective=objective,
196+
)
129197
run_iterations(campaign, n_iterations, batch_size, add_noise=False)
130198
res = campaign.measurements
131199

@@ -140,8 +208,22 @@ def test_leq_interpoint_inequality_multiple_parameters(
140208
@pytest.mark.parametrize(
141209
"constraint_names", [["ContiConstraint_4", "InterConstraint_2"]]
142210
)
143-
def test_interpoint_normal_mix(campaign, n_iterations, batch_size):
211+
def test_interpoint_normal_mix(
212+
non_sequential_recommender,
213+
parameters,
214+
constraints,
215+
objective,
216+
n_iterations,
217+
batch_size,
218+
):
144219
"""Test mixing interpoint and normal inequality constraints."""
220+
campaign = Campaign(
221+
searchspace=SearchSpace.from_product(
222+
parameters=parameters, constraints=constraints
223+
),
224+
recommender=non_sequential_recommender,
225+
objective=objective,
226+
)
145227
run_iterations(campaign, n_iterations, batch_size, add_noise=False)
146228
res = campaign.measurements
147229

0 commit comments

Comments
 (0)