Commit 10f6342

AVHopp authored and facebook-github-bot committed
Fix hard-coded double precision in test_functions to default dtype (#2597)
Summary:

## Motivation

This PR replaces the hard-coded double precision used in the initialization of `test_functions/base.py` with `torch.get_default_dtype()`. See #2596 for more details.

### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)?

Yes, I have read it.

Pull Request resolved: #2597

Test Plan: I ran code formatting via `ufmt` and checked the code via `pytest -ra`. All tests related to the changes here were adjusted in the second commit of the branch. Locally, two tests failed for me, but it seems to me that these are not related to the fix implemented here. If it turns out they are, I'd be more than happy to further adjust.

## Related PRs

None, but #2596 is related.

Reviewed By: saitcakmak

Differential Revision: D65066231

Pulled By: Balandat

fbshipit-source-id: 4beac1fc9a1e5094fd4806958ac2441a12506eb7
1 parent 9e44749 commit 10f6342

File tree — 6 files changed: +117 −36 lines

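Before the per-file diffs, a minimal usage sketch (not part of the commit itself) of the new `dtype` keyword that the constructors below gain. `AugmentedHartmann` is one of the classes updated in this commit, and the printed dtypes are what the patched constructors are expected to register for the bounds buffer:

```python
import torch

from botorch.test_functions.multi_fidelity import AugmentedHartmann

# Default behavior is unchanged: bounds are still registered in double precision.
problem = AugmentedHartmann()
print(problem.bounds.dtype)  # expected: torch.float64

# With the new keyword, the bounds buffer should follow the requested dtype.
problem_f32 = AugmentedHartmann(dtype=torch.float32)
print(problem_f32.bounds.dtype)  # expected: torch.float32
```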

botorch/test_functions/base.py

+7 −3

@@ -29,6 +29,7 @@ def __init__(
         self,
         noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""Base constructor for test functions.
 
@@ -37,6 +38,7 @@ def __init__(
                 provided, specifies separate noise standard deviations for each
                 objective in a multiobjective problem.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         super().__init__()
         self.noise_std = noise_std
@@ -47,7 +49,8 @@ def __init__(
                 f"Got {self.dim=} and {len(self._bounds)=}."
             )
         self.register_buffer(
-            "bounds", torch.tensor(self._bounds, dtype=torch.double).transpose(-1, -2)
+            "bounds",
+            torch.tensor(self._bounds, dtype=dtype).transpose(-1, -2),
         )
 
     def forward(self, X: Tensor, noise: bool = True) -> Tensor:
@@ -166,6 +169,7 @@ def __init__(
         self,
         noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""Base constructor for multi-objective test functions.
 
@@ -180,8 +184,8 @@ def __init__(
                 f"If specified as a list, length of noise_std ({len(noise_std)}) "
                 f"must match the number of objectives ({len(self._ref_point)})"
             )
-        super().__init__(noise_std=noise_std, negate=negate)
-        ref_point = torch.tensor(self._ref_point, dtype=torch.get_default_dtype())
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
+        ref_point = torch.tensor(self._ref_point, dtype=dtype)
         if negate:
             ref_point *= -1
         self.register_buffer("ref_point", ref_point)
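The base-class change above is what all the per-class updates below build on: the bounds buffer is now created with the `dtype` argument instead of a hard-coded `torch.double`. A minimal sketch of a custom subclass picking this up (the `Parabola` class and its bounds are hypothetical, used only to illustrate the dtype plumbing):

```python
import torch
from torch import Tensor

from botorch.test_functions.base import BaseTestProblem


class Parabola(BaseTestProblem):
    # Hypothetical toy problem, only for illustrating the new dtype argument.
    dim = 2
    _bounds = [(-1.0, 1.0), (-1.0, 1.0)]

    def evaluate_true(self, X: Tensor) -> Tensor:
        return (X**2).sum(dim=-1)


problem = Parabola(dtype=torch.float32)
print(problem.bounds.dtype)  # expected: torch.float32, following the constructor argument
print(problem.bounds.shape)  # 2 x dim, since the bounds tensor is transposed on registration
```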

botorch/test_functions/multi_fidelity.py

+15 −4

@@ -74,13 +74,19 @@ class AugmentedHartmann(SyntheticTestFunction):
     _optimizers = [(0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573, 1.0)]
     _check_grad_at_opt = False
 
-    def __init__(self, noise_std: float | None = None, negate: bool = False) -> None:
+    def __init__(
+        self,
+        noise_std: float | None = None,
+        negate: bool = False,
+        dtype: torch.dtype = torch.double,
+    ) -> None:
         r"""
         Args:
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
         self.register_buffer("ALPHA", torch.tensor([1.0, 1.2, 3.0, 3.2]))
         A = [
             [10, 3, 17, 3.5, 1.7, 8],
@@ -126,13 +132,18 @@ class AugmentedRosenbrock(SyntheticTestFunction):
     _optimal_value = 0.0
 
     def __init__(
-        self, dim=3, noise_std: float | None = None, negate: bool = False
+        self,
+        dim=3,
+        noise_std: float | None = None,
+        negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             dim: The (input) dimension. Must be at least 3.
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         if dim < 3:
             raise ValueError(
@@ -141,7 +152,7 @@ def __init__(
         self.dim = dim
         self._bounds = [(-5.0, 10.0) for _ in range(self.dim)]
         self._optimizers = [tuple(1.0 for _ in range(self.dim))]
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
     def evaluate_true(self, X: Tensor) -> Tensor:
         X_curr = X[..., :-3]

botorch/test_functions/multi_objective.py

+21 −7

@@ -119,13 +119,15 @@ def __init__(
         self,
         noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the objectives.
+            dtype: The dtype that is used for the bounds of the function.
         """
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
         self._branin = Branin()
 
     def _rescaled_branin(self, X: Tensor) -> Tensor:
@@ -179,12 +181,14 @@ def __init__(
         dim: int,
         noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             dim: The (input) dimension.
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         if dim < self._min_dim:
             raise ValueError(f"dim must be >= {self._min_dim}, but got dim={dim}!")
@@ -194,7 +198,7 @@ def __init__(
         ]
         # max_hv is the area of the box minus the area of the curve formed by the PF.
         self._max_hv = self._ref_point[0] * self._ref_point[1] - self._area_under_curve
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
     @abstractmethod
     def _h(self, X: Tensor) -> Tensor:
@@ -339,13 +343,15 @@ def __init__(
         num_objectives: int = 2,
         noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             dim: The (input) dimension of the function.
             num_objectives: Must be less than dim.
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         if dim <= num_objectives:
             raise ValueError(
@@ -356,7 +362,7 @@ def __init__(
         self.k = self.dim - self.num_objectives + 1
         self._bounds = [(0.0, 1.0) for _ in range(self.dim)]
         self._ref_point = [self._ref_val for _ in range(num_objectives)]
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
 
 class DTLZ1(DTLZ):
@@ -608,12 +614,14 @@ def __init__(
         noise_std: None | float | list[float] = None,
         negate: bool = False,
         num_objectives: int = 2,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the objectives.
             num_objectives: The number of objectives.
+            dtype: The dtype that is used for the bounds of the function.
         """
         if num_objectives not in (2, 3, 4):
             raise UnsupportedError("GMM only currently supports 2 to 4 objectives.")
@@ -623,7 +631,7 @@ def __init__(
         if num_objectives > 3:
             self._ref_point.append(-0.1866)
         self.num_objectives = num_objectives
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
         gmm_pos = torch.tensor(
             [
                 [[0.2, 0.2], [0.8, 0.2], [0.5, 0.7]],
@@ -935,13 +943,15 @@ def __init__(
         num_objectives: int = 2,
         noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             dim: The (input) dimension of the function.
             num_objectives: Number of objectives. Must not be larger than dim.
             noise_std: Standard deviation of the observation noise.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         if num_objectives != 2:
             raise NotImplementedError(
@@ -954,7 +964,7 @@ def __init__(
         self.num_objectives = num_objectives
         self.dim = dim
         self._bounds = [(0.0, 1.0) for _ in range(self.dim)]
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
     @staticmethod
     def _g(X: Tensor) -> Tensor:
@@ -1246,15 +1256,17 @@ def __init__(
         noise_std: None | float | list[float] = None,
         constraint_noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             noise_std: Standard deviation of the observation noise of the objectives.
             constraint_noise_std: Standard deviation of the observation noise of the
                 constraint.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
         con_bounds = torch.tensor(self._con_bounds, dtype=self.bounds.dtype).transpose(
             -1, -2
         )
@@ -1357,6 +1369,7 @@ def __init__(
         noise_std: None | float | list[float] = None,
         constraint_noise_std: None | float | list[float] = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
@@ -1365,12 +1378,13 @@ def __init__(
             constraint_noise_std: Standard deviation of the observation noise of the
                 constraints.
             negate: If True, negate the function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         if dim < 2:
             raise ValueError("dim must be greater than or equal to 2.")
         self.dim = dim
         self._bounds = [(0.0, 1.0) for _ in range(self.dim)]
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
         self.constraint_noise_std = constraint_noise_std
 
     def LA2(self, A, B, C, D, theta) -> Tensor:
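For the multi-objective problems above, both registered buffers should now track the argument: the `bounds` come from the base class and the `ref_point` from `MultiObjectiveTestProblem.__init__`. A short sketch (again illustrative, not part of the commit) using `BraninCurrin`, one of the classes updated above:

```python
import torch

from botorch.test_functions.multi_objective import BraninCurrin

problem = BraninCurrin(negate=True, dtype=torch.float32)
print(problem.bounds.dtype)     # expected: torch.float32
print(problem.ref_point.dtype)  # expected: torch.float32 (sign-flipped, since negate=True)
```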

botorch/test_functions/sensitivity_analysis.py

+19 −6

@@ -24,13 +24,18 @@ class Ishigami(SyntheticTestFunction):
     """
 
     def __init__(
-        self, b: float = 0.1, noise_std: float | None = None, negate: bool = False
+        self,
+        b: float = 0.1,
+        noise_std: float | None = None,
+        negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             b: the b constant, should be 0.1 or 0.05.
             noise_std: Standard deviation of the observation noise.
             negative: If True, negative the objective.
+            dtype: The dtype that is used for the bounds of the function.
         """
         self._optimizers = None
         if b not in (0.1, 0.05):
@@ -52,7 +57,7 @@ def __init__(
         self.dgsm_gradient_square = [2.8, 24.5, 11]
         self._bounds = [(-math.pi, math.pi) for _ in range(self.dim)]
         self.b = b
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
     @property
     def _optimal_value(self) -> float:
@@ -127,13 +132,15 @@ def __init__(
         a: list = None,
         noise_std: float | None = None,
         negate: bool = False,
+        dtype: torch.dtype = torch.double,
     ) -> None:
         r"""
         Args:
             dim: Dimensionality of the problem. If 6, 8, or 15, will use standard a.
             a: a parameter, unless dim is 6, 8, or 15.
             noise_std: Standard deviation of observation noise.
-            negate: Return negatie of function.
+            negate: Return negative of function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         self._optimizers = None
         self.dim = dim
@@ -163,7 +170,7 @@ def __init__(
         else:
             self.a = a
         self.optimal_sobol_indicies()
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
     @property
     def _optimal_value(self) -> float:
@@ -207,11 +214,17 @@ class Morris(SyntheticTestFunction):
     Proposed to test sensitivity analysis methods
     """
 
-    def __init__(self, noise_std: float | None = None, negate: bool = False) -> None:
+    def __init__(
+        self,
+        noise_std: float | None = None,
+        negate: bool = False,
+        dtype: torch.dtype = torch.double,
+    ) -> None:
         r"""
         Args:
             noise_std: Standard deviation of observation noise.
             negate: Return negative of function.
+            dtype: The dtype that is used for the bounds of the function.
         """
         self._optimizers = None
         self.dim = 20
@@ -238,7 +251,7 @@ def __init__(self, noise_std: float | None = None, negate: bool = False) -> None
             0,
             0,
         ]
-        super().__init__(noise_std=noise_std, negate=negate)
+        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)
 
     @property
     def _optimal_value(self) -> float:
