Skip to content

Commit 4f36140

Browse files
sdaultonfacebook-github-bot
authored andcommitted
update tests for MergeRepeatedMeasurements (#1607)
Summary: Pull Request resolved: #1607 updates per comments after landing D45558427 Reviewed By: esantorella Differential Revision: D45608725 fbshipit-source-id: 2ee32e87e575af7b8dbe5ecba6688a8104170e7f
1 parent 50b9951 commit 4f36140

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

ax/modelbridge/tests/test_merge_repeated_measurements_transform.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,16 @@ def compare_obs(
1818
test.assertEqual(obs1.data.metric_names, obs2.data.metric_names)
1919
test.assertTrue(np.array_equal(obs1.data.means, obs2.data.means))
2020
discrep = np.max(np.abs(obs1.data.covariance - obs2.data.covariance))
21-
test.assertTrue(discrep <= discrepancy_tol)
22-
test.assertTrue(obs1.features.parameters == obs2.features.parameters)
21+
test.assertLessEqual(discrep, discrepancy_tol)
22+
test.assertEqual(obs1.features.parameters, obs2.features.parameters)
2323

2424

2525
class MergeRepeatedMeasurementsTransformTest(TestCase):
2626
def testTransform(self) -> None:
2727
obs_feats1 = ObservationFeatures(parameters={"a": 0.0})
28-
with self.assertRaises(RuntimeError):
28+
with self.assertRaisesRegex(
29+
RuntimeError, "MergeRepeatedMeasurements requires observations"
30+
):
2931
# test that observations are required
3032
MergeRepeatedMeasurements()
3133
# test nan in covariance
@@ -37,7 +39,9 @@ def testTransform(self) -> None:
3739
),
3840
features=obs_feats1,
3941
)
40-
with self.assertRaises(NotImplementedError):
42+
with self.assertRaisesRegex(
43+
NotImplementedError, "All metrics must have noise observations."
44+
):
4145
MergeRepeatedMeasurements(observations=[observation])
4246
# test full covariance
4347
observation = Observation(
@@ -48,7 +52,9 @@ def testTransform(self) -> None:
4852
),
4953
features=obs_feats1,
5054
)
51-
with self.assertRaises(NotImplementedError):
55+
with self.assertRaisesRegex(
56+
NotImplementedError, "Only independent metrics are currently supported."
57+
):
5258
MergeRepeatedMeasurements(observations=[observation])
5359

5460
# test noiseless, different means
@@ -71,7 +77,11 @@ def testTransform(self) -> None:
7177
features=obs_feats1,
7278
),
7379
]
74-
with self.assertRaises(ValueError):
80+
with self.assertRaisesRegex(
81+
ValueError,
82+
"All repeated arms with noiseless measurements "
83+
"must have the same means.",
84+
):
7585
MergeRepeatedMeasurements(observations=observations)
7686
# test noiseless, same means
7787
observations = [

ax/modelbridge/transforms/merge_repeated_measurements.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ def __init__(
4444
str, DefaultDict[str, DefaultDict[str, List[float]]]
4545
] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
4646
observation_features, observation_data = separate_observations(observations)
47-
#
4847
for j, obsd in enumerate(observation_data):
4948
# This intentionally ignores the trial index
5049
key = Arm.md5hash(observation_features[j].parameters)

0 commit comments

Comments
 (0)