Skip to content

Commit 0b6946c

Browse files
esantorella authored and facebook-github-bot committed
Remove _get_best_row_for_scalarized_objective and _get_best_row_for_single_objective (#3751)
Summary: Pull Request resolved: #3751 * Remove `_get_best_row_for_scalarized_objective` and `_get_best_row_for_single_objective`, enabled by cutting `get_best_raw_objective_point_with_trial_index` over to use `get_trace_by_arm_pull_from_data` * Add more careful no-data handling for `get_trace_by_arm_pull_from_data`, including for empty data and observations with only some metrics observed Note: `get_best_raw_objective_point_with_trial_index` and `get_trace` are now more similar because they both use `get_trace_by_arm_pull_from_data`. However, they differ in a few aspects such as how they treat multi-objective configs and out-of-design points. I think it would be good if these were closer together, but probably not low-ROI at the moment. Reviewed By: Balandat Differential Revision: D74282536 fbshipit-source-id: 4612cab52ad4e53af85efb54d3d0c18568462b0d
1 parent 4bbc415 commit 0b6946c

File tree

3 files changed

+99
-85
lines changed

3 files changed

+99
-85
lines changed

ax/core/tests/test_objective.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def setUp(self) -> None:
3939

4040
def test_Init(self) -> None:
4141
with self.assertRaisesRegex(UserInputError, "does not specify"):
42-
(Objective(metric=self.metrics["m1"]),)
42+
Objective(metric=self.metrics["m1"])
4343
with self.assertRaisesRegex(
4444
UserInputError, "doesn't match the specified optimization direction"
4545
):

ax/service/tests/test_best_point_utils.py

Lines changed: 58 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,46 @@ def test_get_hypervolume_trace_of_outcomes_multi_objective(self) -> None:
177177
self.assertEqual(hvs, [0.0, 2.0, 0.0, 2.0])
178178

179179
def test_get_trace_by_arm_pull_from_data(self) -> None:
180+
objective = Objective(metric=Metric("m1"), minimize=False)
181+
optimzation_config = OptimizationConfig(
182+
objective=objective,
183+
outcome_constraints=[
184+
OutcomeConstraint(
185+
Metric("m2"), op=ComparisonOp.GEQ, bound=0.0, relative=False
186+
)
187+
],
188+
)
189+
190+
with self.subTest("No data"):
191+
df = pd.DataFrame()
192+
result = get_trace_by_arm_pull_from_data(
193+
df=df, optimization_config=optimzation_config
194+
)
195+
self.assertTrue(
196+
result.equals(
197+
pd.DataFrame(columns=["trial_index", "arm_name", "value"])
198+
)
199+
)
200+
201+
with self.subTest("Data for wrong metric"):
202+
df = pd.DataFrame.from_records(
203+
data=[
204+
{
205+
"trial_index": 0,
206+
"arm_name": "0_0",
207+
"metric_name": "wrong",
208+
"mean": 1.0,
209+
"sem": None,
210+
},
211+
]
212+
)
213+
with self.assertRaisesRegex(
214+
ValueError, "Some metrics are not present for all trials and arms"
215+
):
216+
result = get_trace_by_arm_pull_from_data(
217+
df=df, optimization_config=optimzation_config
218+
)
219+
180220
df = pd.DataFrame.from_records(
181221
data=[
182222
{"trial_index": 0, "arm_name": "0_0", "metric_name": "m1", "mean": 1.0},
@@ -200,9 +240,8 @@ def test_get_trace_by_arm_pull_from_data(self) -> None:
200240
]
201241
).assign(sem=None)
202242

203-
objective = Objective(metric=Metric("m1"), minimize=False)
204243
with self.subTest("Relative optimization config not supported"):
205-
optimization_config = OptimizationConfig(
244+
rel_optimization_config = OptimizationConfig(
206245
objective=objective,
207246
outcome_constraints=[
208247
OutcomeConstraint(
@@ -217,17 +256,9 @@ def test_get_trace_by_arm_pull_from_data(self) -> None:
217256
ValueError, "Relativized optimization config not supported"
218257
):
219258
get_trace_by_arm_pull_from_data(
220-
df=df, optimization_config=optimization_config
259+
df=df, optimization_config=rel_optimization_config
221260
)
222261

223-
optimzation_config = OptimizationConfig(
224-
objective=objective,
225-
outcome_constraints=[
226-
OutcomeConstraint(
227-
Metric("m2"), op=ComparisonOp.GEQ, bound=0.0, relative=False
228-
)
229-
],
230-
)
231262
with self.subTest("Single objective, cumulative"):
232263
result = get_trace_by_arm_pull_from_data(
233264
df=df, optimization_config=optimzation_config, use_cumulative_best=True
@@ -382,13 +413,17 @@ def test_best_raw_objective_point(self) -> None:
382413
generator_run=GeneratorRun(arms=[Arm(parameters={"x1": 5.0, "x2": 5.0})])
383414
).run().complete()
384415
exp.fetch_data()
385-
# pyre-fixme[16]: Optional type has no attribute `clone`.
386-
opt_conf = exp.optimization_config.clone()
387-
opt_conf.objective.metric._name = "not_branin"
388-
with self.assertRaisesRegex(ValueError, "No data has been logged"):
389-
get_best_raw_objective_point_with_trial_index(
390-
experiment=exp, optimization_config=opt_conf
416+
417+
with self.subTest("Data present but not for needed metrics"):
418+
opt_conf = OptimizationConfig(
419+
objective=Objective(metric=get_branin_metric(name="not_branin"))
391420
)
421+
with self.assertRaisesRegex(
422+
ValueError, "Some metrics are not present for all trials and arms"
423+
):
424+
get_best_raw_objective_point_with_trial_index(
425+
experiment=exp, optimization_config=opt_conf
426+
)
392427

393428
# Test constraints work as expected.
394429
observations = [[1.0, 2.0], [3.0, 4.0], [-5.0, -6.0]]
@@ -446,13 +481,13 @@ def test_best_raw_objective_point_unsatisfiable_relative(self) -> None:
446481
opt_conf.outcome_constraints[0].relative = True
447482
opt_conf.outcome_constraints[0].bound = 9999
448483

449-
with self.assertLogs(logger=best_point_logger, level="WARN") as lg:
484+
with self.assertRaisesRegex(
485+
DataRequiredError,
486+
"Optimization config has relative constraint, but model was not fit"
487+
" with status quo.",
488+
):
450489
get_best_raw_objective_point_with_trial_index(
451-
exp, optimization_config=opt_conf
452-
)
453-
self.assertTrue(
454-
any("No status quo provided" in warning for warning in lg.output),
455-
msg=lg.output,
490+
experiment=exp, optimization_config=opt_conf
456491
)
457492

458493
exp.status_quo = exp.trials[0].arms[0]

ax/service/utils/best_point.py

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def get_best_raw_objective_point_with_trial_index(
8080
"""Given an experiment, identifies the arm that had the best raw objective,
8181
based on the data fetched from the experiment.
8282
83+
Note: This function will error with an invalid configuration. If you would
84+
prefer for error logs rather than exceptions, use
85+
`get_best_by_raw_objective_with_trial_index`.
86+
8387
Args:
8488
experiment: Experiment, on which to identify best raw objective arm.
8589
optimization_config: Optimization config to use in place of the one stored
@@ -108,15 +112,10 @@ def get_best_raw_objective_point_with_trial_index(
108112
if dat.df.empty:
109113
raise ValueError("Cannot identify best point if experiment contains no data.")
110114
if any(oc.relative for oc in optimization_config.all_constraints):
111-
if experiment.status_quo is not None:
112-
optimization_config = derelativize_opt_config(
113-
optimization_config=optimization_config,
114-
experiment=experiment,
115-
)
116-
else:
117-
logger.warning(
118-
"No status quo provided; relative constraints will be ignored."
119-
)
115+
optimization_config = derelativize_opt_config(
116+
optimization_config=optimization_config,
117+
experiment=experiment,
118+
)
120119

121120
# Only COMPLETED trials should be considered when identifying the best point
122121
completed_indices = {
@@ -146,15 +145,22 @@ def get_best_raw_objective_point_with_trial_index(
146145
raise ValueError("No feasible points are in the search space.")
147146

148147
in_design_df = feasible_df.loc[is_in_design]
148+
value_by_arm_pull = get_trace_by_arm_pull_from_data(
149+
df=in_design_df,
150+
optimization_config=optimization_config,
151+
use_cumulative_best=False,
152+
)
149153

150-
objective = optimization_config.objective
151-
best_row_helper = (
152-
_get_best_row_for_scalarized_objective
153-
if isinstance(objective, ScalarizedObjective)
154-
else _get_best_row_for_single_objective
154+
maximize = isinstance(optimization_config.objective, MultiObjective) or (
155+
not optimization_config.objective.minimize
156+
)
157+
best_row_idx = (
158+
value_by_arm_pull["value"].idxmax()
159+
if maximize
160+
else value_by_arm_pull["value"].idxmin()
155161
)
156-
# pyre-ignore Incompatible parameter type [6]
157-
best_row = best_row_helper(df=in_design_df, objective=objective)
162+
best_row = value_by_arm_pull.loc[best_row_idx]
163+
158164
best_arm = experiment.arms_by_name[best_row["arm_name"]]
159165
best_trial_index = int(best_row["trial_index"])
160166
objective_rows = dat.df.loc[
@@ -321,6 +327,9 @@ def get_best_by_raw_objective_with_trial_index(
321327
TModelPredictArm is of the form:
322328
({metric_name: mean}, {metric_name_1: {metric_name_2: cov_1_2}})
323329
330+
This is a version of `get_best_raw_objective_point_with_trial_index` that
331+
logs errors rather than letting exceptions be raised.
332+
324333
Args:
325334
experiment: Experiment, on which to identify best raw objective arm.
326335
optimization_config: Optimization config to use in place of the one stored
@@ -468,51 +477,6 @@ def get_pareto_optimal_parameters(
468477
return res
469478

470479

471-
# NOTE: This function will be removed in the next PR.
472-
def _get_best_row_for_scalarized_objective(
473-
df: pd.DataFrame,
474-
objective: ScalarizedObjective,
475-
) -> pd.Series:
476-
df = df.copy()
477-
# First, add a weight column, setting 0.0 if the metric is not part
478-
# of the objective
479-
metric_to_weight = {
480-
m.name: objective.weights[i] for i, m in enumerate(objective.metrics)
481-
}
482-
df["weight"] = df["metric_name"].apply(lambda x: metric_to_weight.get(x) or 0.0)
483-
# Now, calculate the weighted linear combination via groupby,
484-
# filtering out NaN for missing data
485-
df["weighted_mean"] = df["mean"] * df["weight"]
486-
groupby_df = (
487-
df[["arm_name", "trial_index", "weighted_mean"]]
488-
.groupby(["arm_name", "trial_index"], as_index=False)
489-
.sum(min_count=1)
490-
.dropna()
491-
)
492-
if groupby_df.empty:
493-
raise ValueError("No data has been logged for scalarized objective.")
494-
return (
495-
groupby_df.loc[groupby_df["weighted_mean"].idxmin()]
496-
if objective.minimize
497-
else groupby_df.loc[groupby_df["weighted_mean"].idxmax()]
498-
)
499-
500-
501-
# NOTE: This function will be removed in the next PR.
502-
def _get_best_row_for_single_objective(
503-
df: pd.DataFrame, objective: Objective
504-
) -> pd.Series:
505-
objective_name = objective.metric.name
506-
objective_rows = df.loc[df["metric_name"] == objective_name]
507-
if objective_rows.empty:
508-
raise ValueError(f'No data has been logged for objective "{objective_name}".')
509-
return (
510-
objective_rows.loc[objective_rows["mean"].idxmin()]
511-
if objective.minimize
512-
else objective_rows.loc[objective_rows["mean"].idxmax()]
513-
)
514-
515-
516480
def _is_row_feasible(
517481
df: pd.DataFrame,
518482
optimization_config: OptimizationConfig,
@@ -779,6 +743,11 @@ def get_trace_by_arm_pull_from_data(
779743
"`Derelativize` the optimization config, or use `get_trace`."
780744
)
781745

746+
empty_result = pd.DataFrame(columns=["trial_index", "arm_name", "value"])
747+
748+
if len(df) == 0:
749+
return empty_result
750+
782751
# reshape data to wide, using only the metrics in the optimization config
783752
metrics = list(optimization_config.metrics.keys())
784753

@@ -793,6 +762,16 @@ def get_trace_by_arm_pull_from_data(
793762
.set_index(["trial_index", "arm_name", "metric_name"])["mean"]
794763
.unstack(level="metric_name")
795764
)
765+
missing_metrics = [
766+
m for m in metrics if m not in df_wide.columns or df_wide[m].isnull().any()
767+
]
768+
if len(missing_metrics) > 0:
769+
raise ValueError(
770+
"Some metrics are not present for all trials and arms. The "
771+
f"following are missing: {missing_metrics}."
772+
)
773+
if len(df_wide) == 0:
774+
return empty_result
796775
df_wide["feasible"] = df.groupby(["trial_index", "arm_name"])["row_feasible"].all()
797776
df_wide.reset_index(inplace=True)
798777

0 commit comments

Comments (0)