Skip to content

Commit 50b9951

Browse files
sdaultonfacebook-github-bot
authored andcommitted
add transform for merging repeated measurements (#1606)
Summary: Pull Request resolved: #1606 Use inverse-variance weighting to merge repeated observations (e.g. across different trials) for a given arm. This ignores the trial_index and assumes stationarity. Reviewed By: Balandat Differential Revision: D45558427 fbshipit-source-id: b58d27c32b14e7612d7773010544074ac7c346bd
1 parent 4cecfd8 commit 50b9951

File tree

3 files changed

+302
-0
lines changed

3 files changed

+302
-0
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import numpy as np
8+
from ax.core.observation import Observation, ObservationData, ObservationFeatures
9+
from ax.modelbridge.transforms.merge_repeated_measurements import (
10+
MergeRepeatedMeasurements,
11+
)
12+
from ax.utils.common.testutils import TestCase
13+
14+
15+
def compare_obs(
    test: TestCase, obs1: Observation, obs2: Observation, discrepancy_tol: float = 1e-8
) -> None:
    """Assert that two observations agree up to an absolute tolerance.

    Metric names and arm parameters must match exactly. Means and covariances
    must have identical shapes and differ element-wise by at most
    ``discrepancy_tol``.

    Args:
        test: The test case whose assertion methods are used to report failures.
        obs1: The first (typically expected) observation.
        obs2: The second (typically actual) observation.
        discrepancy_tol: Maximum allowed absolute element-wise difference for
            means and covariances. Pass ``0.0`` to require exact equality.
    """
    test.assertEqual(obs1.data.metric_names, obs2.data.metric_names)

    means1, means2 = obs1.data.means, obs2.data.means
    # Check shapes explicitly so that broadcasting cannot hide a mismatch.
    test.assertEqual(means1.shape, means2.shape)
    # rtol=0 makes this a pure absolute-tolerance check, consistent with the
    # covariance comparison below; with atol=0 it degenerates to exact equality.
    test.assertTrue(
        np.allclose(means1, means2, rtol=0.0, atol=discrepancy_tol),
        msg=f"means differ by more than {discrepancy_tol}: {means1} vs {means2}",
    )

    cov1, cov2 = obs1.data.covariance, obs2.data.covariance
    test.assertEqual(cov1.shape, cov2.shape)
    discrep = np.max(np.abs(cov1 - cov2))
    test.assertLessEqual(discrep, discrepancy_tol)

    test.assertEqual(obs1.features.parameters, obs2.features.parameters)
23+
24+
25+
class MergeRepeatedMeasurementsTransformTest(TestCase):
    """Tests for the ``MergeRepeatedMeasurements`` transform."""

    @staticmethod
    def _make_obs(
        metric_names: list,
        means: np.ndarray,
        covariance: np.ndarray,
        features: ObservationFeatures,
    ) -> Observation:
        """Build an ``Observation`` from its data components and features."""
        return Observation(
            data=ObservationData(
                metric_names=metric_names,
                means=means,
                covariance=covariance,
            ),
            features=features,
        )

    def testTransform(self) -> None:
        obs_feats1 = ObservationFeatures(parameters={"a": 0.0})

        # Observations are required at construction time.
        with self.assertRaises(RuntimeError):
            MergeRepeatedMeasurements()

        # NaN in the covariance (unknown noise level) is not supported.
        observation = self._make_obs(
            metric_names=["m1"],
            means=np.array([1.0]),
            covariance=np.array([[float("nan")]]),
            features=obs_feats1,
        )
        with self.assertRaises(NotImplementedError):
            MergeRepeatedMeasurements(observations=[observation])

        # A full (non-diagonal) covariance is not supported.
        observation = self._make_obs(
            metric_names=["m1", "m2"],
            means=np.array([1.0, 1.0]),
            covariance=np.ones((2, 2)),
            features=obs_feats1,
        )
        with self.assertRaises(NotImplementedError):
            MergeRepeatedMeasurements(observations=[observation])

        # Noiseless repeats of the same arm with different means are contradictory.
        zero_covar = np.zeros((1, 1))
        observations = [
            self._make_obs(
                metric_names=["m1"],
                means=np.array([mean]),
                covariance=zero_covar,
                features=obs_feats1,
            )
            for mean in (1.0, 2.0)
        ]
        with self.assertRaises(ValueError):
            MergeRepeatedMeasurements(observations=observations)

        # Noiseless repeats with the same mean collapse into one observation.
        observations = [
            self._make_obs(
                metric_names=["m1"],
                means=np.array([1.0]),
                covariance=zero_covar,
                features=obs_feats1,
            ),
            self._make_obs(
                metric_names=["m1"],
                means=np.array([1.0]),
                covariance=zero_covar,
                features=obs_feats1,
            ),
            self._make_obs(
                metric_names=["m1"],
                means=np.array([2.0]),
                covariance=zero_covar,
                features=ObservationFeatures(parameters={"a": 2.0}),
            ),
        ]
        t = MergeRepeatedMeasurements(observations=observations)
        expected_obs = observations[-2:]
        transformed_obs = t.transform_observations(observations)
        # Exactly one observation per unique arm must come back.
        self.assertEqual(len(transformed_obs), len(expected_obs))
        for expected, actual in zip(expected_obs, transformed_obs):
            compare_obs(
                test=self,
                obs1=expected,
                obs2=actual,
                discrepancy_tol=0.0,
            )

        # Basic test: two noisy repeats of one arm plus a distinct arm.
        obs_feat1 = ObservationFeatures(parameters={"a": 0.0, "b": 1.0})
        obs1 = self._make_obs(
            metric_names=["m1", "m2"],
            means=np.array([1.0, 2.0]),
            covariance=np.array(
                [
                    [1.0, 0.0],
                    [0.0, 2.0],
                ]
            ),
            features=obs_feat1,
        )
        obs2 = self._make_obs(
            metric_names=["m1", "m2"],
            means=np.array([1.0, 1.0]),
            covariance=np.array(
                [
                    [1.0, 0.0],
                    [0.0, 3.0],
                ]
            ),
            features=obs_feat1,
        )
        # different arm
        obs3 = self._make_obs(
            metric_names=["m1", "m2"],
            means=np.array([3.0, 1.0]),
            covariance=np.array(
                [
                    [4.0, 0.0],
                    [0.0, 5.0],
                ]
            ),
            features=ObservationFeatures(parameters={"a": 1.0, "b": 0.0}),
        )
        # Inverse-variance weighting: m1 -> mean 1.0, var 1 / (1 + 1) = 0.5;
        # m2 -> var 1 / (1/2 + 1/3) = 1.2, mean 1.2 * (2/2 + 1/3) = 1.6.
        expected_obs = self._make_obs(
            metric_names=["m1", "m2"],
            means=np.array([1.0, 1.6]),
            covariance=np.array([[0.5, 0.0], [0.0, 1.2]]),
            features=obs_feat1,
        )
        observations = [obs1, obs2, obs3]
        t = MergeRepeatedMeasurements(observations=observations)
        observations2 = t.transform_observations(observations)
        # Two unique arms in, two observations out.
        self.assertEqual(len(observations2), 2)
        compare_obs(
            test=self, obs1=expected_obs, obs2=observations2[0], discrepancy_tol=1e-8
        )
        compare_obs(test=self, obs1=obs3, obs2=observations2[1], discrepancy_tol=0.0)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from __future__ import annotations
8+
9+
from collections import defaultdict
10+
from typing import DefaultDict, Dict, List, Optional
11+
12+
import numpy as np
13+
from ax.core.arm import Arm
14+
from ax.core.observation import Observation, ObservationData, separate_observations
15+
from ax.core.search_space import SearchSpace
16+
from ax.modelbridge.base import ModelBridge
17+
from ax.modelbridge.transforms.base import Transform
18+
from ax.models.types import TConfig
19+
20+
21+
class MergeRepeatedMeasurements(Transform):
    """Merge repeated measurements to obtain one observation per arm.

    Repeated measurements are merged via inverse variance weighting (e.g. over
    different trials). This intentionally ignores the trial index and assumes
    stationarity.

    For a metric observed ``n`` times on the same arm with means ``y_i`` and
    variances ``s_i``, the merged variance is ``1 / sum_i(1 / s_i)`` and the
    merged mean is ``sum_i(y_i / s_i) / sum_i(1 / s_i)``. If any of the repeated
    measurements is noiseless (zero variance), the merged observation is that
    noiseless value with zero variance.

    TODO: Support inverse variance weighting correlated outcomes (full covariance).

    Note: this is not reversible.
    """

    def __init__(
        self,
        search_space: Optional[SearchSpace] = None,
        observations: Optional[List[Observation]] = None,
        modelbridge: Optional[ModelBridge] = None,
        config: Optional[TConfig] = None,
    ) -> None:
        """Compute the merged mean and variance per arm and metric.

        Args:
            search_space: Unused by this transform.
            observations: Observations to merge. Required.
            modelbridge: Unused by this transform.
            config: Unused by this transform.

        Raises:
            RuntimeError: If ``observations`` is None.
            NotImplementedError: If any covariance entry is NaN, or if any
                covariance matrix has non-zero off-diagonal entries.
            ValueError: If an arm has several noiseless measurements of the
                same metric whose means differ.
        """
        if observations is None:
            raise RuntimeError("MergeRepeatedMeasurements requires observations")
        # create a mapping of arm_key -> {metric_name: {means: [], vars: []}}
        arm_to_multi_obs: DefaultDict[
            str, DefaultDict[str, DefaultDict[str, List[float]]]
        ] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        observation_features, observation_data = separate_observations(observations)
        for obsf, obsd in zip(observation_features, observation_data):
            # This intentionally ignores the trial index
            key = Arm.md5hash(obsf.parameters)
            # TODO: support inverse variance weighting for multivariate distributions
            # (full covariance)
            diag = np.diag(np.diag(obsd.covariance))
            if np.any(np.isnan(obsd.covariance)):
                raise NotImplementedError("All metrics must have noise observations.")
            # `not` rather than `~`: bitwise negation is only a logical "not"
            # for NumPy booleans and is always truthy for a plain Python bool.
            if not np.all(obsd.covariance == diag):
                raise NotImplementedError(
                    "Only independent metrics are currently supported."
                )
            for i, m in enumerate(obsd.metric_names):
                arm_to_multi_obs[key][m]["means"].append(obsd.means[i])
                arm_to_multi_obs[key][m]["vars"].append(obsd.covariance[i, i])

        # arm_key -> {metric_name: {"mean": merged mean, "var": merged variance}}
        self.arm_to_merged: DefaultDict[str, Dict[str, Dict[str, float]]] = defaultdict(
            dict
        )
        for k, metric_dict in arm_to_multi_obs.items():
            for m, v in metric_dict.items():
                var = np.array(v["vars"])
                means = np.array(v["means"])
                noiseless = var == 0
                if np.any(noiseless):
                    # A noiseless measurement is exact, so it determines the
                    # merged value; several of them must agree with each other.
                    noiseless_means = means[noiseless]
                    if (noiseless_means.shape[0] > 1) and (
                        not np.all(noiseless_means[1:] == noiseless_means[0])
                    ):
                        raise ValueError(
                            "All repeated arms with noiseless measurements "
                            "must have the same means."
                        )
                    self.arm_to_merged[k][m] = {
                        "mean": noiseless_means[0],
                        "var": 0.0,
                    }
                else:
                    # inverse variance weighting
                    inv_var = 1 / var
                    inv_sum_inv_var = 1 / np.sum(inv_var)
                    weights = inv_var * inv_sum_inv_var
                    self.arm_to_merged[k][m] = {
                        "mean": np.sum(means * weights),
                        "var": inv_sum_inv_var,
                    }

    def transform_observations(
        self,
        observations: List[Observation],
    ) -> List[Observation]:
        """Replace repeated observations with one merged observation per arm.

        The first observation of each arm (in input order) is replaced by the
        merged observation; later observations of the same arm are dropped, as
        are observations of arms that were not seen at construction time.

        This does not modify the fitted state, so it can be called repeatedly
        on the same instance.

        Args:
            observations: Observations to transform.

        Returns:
            A new list with at most one observation per unique arm.
        """
        new_observations = []
        observation_features, observation_data = separate_observations(observations)
        # Track arms already emitted in this call instead of popping from
        # self.arm_to_merged, which would leave the transform empty after its
        # first use.
        seen = set()
        for j, obsd in enumerate(observation_data):
            key = Arm.md5hash(observation_features[j].parameters)
            if key in seen:
                continue
            # .get (not indexing) so the defaultdict does not grow on a miss.
            metric_dict = self.arm_to_merged.get(key)
            if metric_dict is None:
                continue
            seen.add(key)
            n_metrics = len(obsd.metric_names)
            merged_means = np.zeros(n_metrics)
            merged_covariance = np.zeros((n_metrics, n_metrics))
            for i, m in enumerate(obsd.metric_names):
                merged_metric = metric_dict[m]
                merged_means[i] = merged_metric["mean"]
                merged_covariance[i, i] = merged_metric["var"]
            new_obsd = ObservationData(
                metric_names=obsd.metric_names,
                means=merged_means,
                covariance=merged_covariance,
            )
            new_obs = Observation(features=observation_features[j], data=new_obsd)
            new_observations.append(new_obs)
        return new_observations

sphinx/source/modelbridge.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,14 @@ Transforms
259259
:undoc-members:
260260
:show-inheritance:
261261

262+
`ax.modelbridge.transforms.merge_repeated_measurements`
263+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
264+
265+
.. automodule:: ax.modelbridge.transforms.merge_repeated_measurements
266+
:members:
267+
:undoc-members:
268+
:show-inheritance:
269+
262270
`ax.modelbridge.transforms.metrics_as_task`
263271
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
264272

0 commit comments

Comments
 (0)