Skip to content

Commit 246b13d

Browse files
esantorellafacebook-github-bot
authored andcommitted
Make get_best_raw_objective_point_with_trial_index return a TModelPredictArm (#3753)
Summary: Pull Request resolved: #3753 In every usage of `get_best_raw_objective_point_with_trial_index`, its third return value is either ignored or changed into a TModelPredictArm. This PR * makes it return a TModelPredictArm, which makes some downstream functions more unitary. * Fixes some type annotations that were too broad (the TModelPredictArm won't be None) * Updates docstrings Reviewed By: Balandat Differential Revision: D74401853 fbshipit-source-id: bd184c2e6f0c5e673d08d6fb5278cf2029063cfc
1 parent bf42b8f commit 246b13d

File tree

2 files changed

+16
-27
lines changed

2 files changed

+16
-27
lines changed

ax/service/managed_loop.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -260,17 +260,10 @@ def get_best_point(self) -> tuple[TParameterization, TModelPredictArm | None]:
260260
return parameterizations, predictions
261261

262262
# Could not find through model, default to using raw objective.
263-
_, parameterization, values = get_best_raw_objective_point_with_trial_index(
264-
experiment=self.experiment
265-
)
266-
# For values, grab just the means to conform to TModelPredictArm format.
267-
return (
268-
parameterization,
269-
(
270-
{k: v[0] for k, v in values.items()}, # v[0] is mean
271-
{k: {k: v[1] * v[1]} for k, v in values.items()}, # v[1] is sem
272-
),
263+
_, parameterization, predict_arm = (
264+
get_best_raw_objective_point_with_trial_index(experiment=self.experiment)
273265
)
266+
return parameterization, predict_arm
274267

275268
def get_current_model(self) -> Adapter | None:
276269
"""Obtain the most recently used model in optimization."""

ax/service/utils/best_point.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,17 @@ def get_best_raw_objective_point_with_trial_index(
7676
experiment: Experiment,
7777
optimization_config: OptimizationConfig | None = None,
7878
trial_indices: Iterable[int] | None = None,
79-
) -> tuple[int, TParameterization, dict[str, tuple[float, float]]]:
79+
) -> tuple[int, TParameterization, TModelPredictArm]:
8080
"""Given an experiment, identifies the arm that had the best raw objective,
8181
based on the data fetched from the experiment.
8282
83-
Note: This function will error with an invalid configuration. If you would
83+
TModelPredictArm is of the form:
84+
({metric_name: mean}, {metric_name_1: {metric_name_2: cov_1_2}})
85+
86+
Note: This function will error with invalid inputs. If you would
8487
prefer for error logs rather than exceptions, use
85-
`get_best_by_raw_objective_with_trial_index`.
88+
`get_best_by_raw_objective_with_trial_index`, which returns None if
89+
inputs are invalid.
8690
8791
Args:
8892
experiment: Experiment, on which to identify best raw objective arm.
@@ -92,8 +96,7 @@ def get_best_raw_objective_point_with_trial_index(
9296
retrieve data from all available trials.
9397
9498
Returns:
95-
Tuple of parameterization and a mapping from metric name to a tuple of
96-
the corresponding objective mean and SEM.
99+
Tuple of trial index, parameterization, and model predictions for it.
97100
"""
98101
optimization_config = optimization_config or experiment.optimization_config
99102
if optimization_config is None:
@@ -171,8 +174,9 @@ def get_best_raw_objective_point_with_trial_index(
171174
row["metric_name"]: (row["mean"], row["sem"])
172175
for _, row in objective_rows.iterrows()
173176
}
177+
predict_arm = _raw_values_to_model_predict_arm(values=vals)
174178

175-
return best_trial_index, none_throws(best_arm).parameters, vals
179+
return best_trial_index, none_throws(best_arm).parameters, predict_arm
176180

177181

178182
def _extract_best_arm_from_gr(
@@ -320,7 +324,7 @@ def get_best_by_raw_objective_with_trial_index(
320324
experiment: Experiment,
321325
optimization_config: OptimizationConfig | None = None,
322326
trial_indices: Iterable[int] | None = None,
323-
) -> tuple[int, TParameterization, TModelPredictArm | None] | None:
327+
) -> tuple[int, TParameterization, TModelPredictArm] | None:
324328
"""Given an experiment, identifies the arm that had the best raw objective,
325329
based on the data fetched from the experiment.
326330
@@ -341,11 +345,7 @@ def get_best_by_raw_objective_with_trial_index(
341345
Tuple of trial index, parameterization, and model predictions for it.
342346
"""
343347
try:
344-
(
345-
trial_index,
346-
parameterization,
347-
values,
348-
) = get_best_raw_objective_point_with_trial_index(
348+
result = get_best_raw_objective_point_with_trial_index(
349349
experiment=experiment,
350350
optimization_config=optimization_config,
351351
trial_indices=trial_indices,
@@ -356,11 +356,7 @@ def get_best_by_raw_objective_with_trial_index(
356356
f"'{err}'. Returning None."
357357
)
358358
return None
359-
return (
360-
trial_index,
361-
parameterization,
362-
_raw_values_to_model_predict_arm(values),
363-
)
359+
return result
364360

365361

366362
def get_pareto_optimal_parameters(

0 commit comments

Comments
 (0)