Skip to content

Commit ed752e8

Browse files
committed
Improve validation w.r.t. batch_size and constraint type
1 parent 2c0e18e commit ed752e8

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

baybe/constraints/continuous.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ def to_botorch(
140140
141141
Raises:
142142
RuntimeError: When the constraint is an interpoint constraint but
143-
``batch_size`` is ``None``.
143+
``batch_size`` is ``None`` or when providing a value for ``batch_size``
144+
while the constraint is not an interpoint constraint.
144145
"""
145146
import torch
146147

@@ -151,6 +152,17 @@ def to_botorch(
151152
# See https://botorch.readthedocs.io/en/latest/optim.html, in particular the
152153
# docstring of ``optimize_acqf`` for details.
153154

155+
if batch_size is None and self.is_interpoint:
156+
raise RuntimeError(
157+
"No ``batch_size`` set but using interpoint constraints."
158+
"This should not happen and means that there is a bug in the code."
159+
)
160+
if batch_size is not None and not self.is_interpoint:
161+
raise RuntimeError(
162+
"A ``batch_size`` was set but the constraint is not interpoint."
163+
"This should not happen and means that there is a bug in the code."
164+
)
165+
154166
param_names = [p.name for p in parameters]
155167
coefficients: list[float]
156168
torch_indices: Tensor
@@ -162,11 +174,6 @@ def to_botorch(
162174
]
163175
coefficients = self.coefficients
164176
torch_indices = torch.tensor(param_indices)
165-
elif batch_size is None:
166-
raise RuntimeError(
167-
"No ``batch_size`` set but using interpoint constraints."
168-
"This should not happen and means that there is a bug in the code."
169-
)
170177
else:
171178
param_index_dict = {
172179
name: param_names.index(name) for name in self.parameters

0 commit comments

Comments
 (0)