Skip to content

Commit d7d9291

Browse files
committed
Define regression metrics at the module level
1 parent 25fd899 commit d7d9291

File tree

1 file changed

+13
-11
lines changed
  • benchmarks/domains/regression

1 file changed

+13
-11
lines changed

benchmarks/domains/regression/base.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,18 @@ def spearman_rho_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
5757
return rho
5858

5959

60+
REGRESSION_METRICS = [
61+
root_mean_squared_error,
62+
mean_squared_error,
63+
r2_score,
64+
mean_absolute_error,
65+
max_error,
66+
explained_variance_score,
67+
kendall_tau_score,
68+
spearman_rho_score,
69+
]
70+
71+
6072
def run_tl_regression_benchmark(
6173
settings: TransferLearningRegressionBenchmarkSettings,
6274
load_data_fn: Callable[..., pd.DataFrame],
@@ -387,20 +399,10 @@ def _calculate_metrics(
387399
Returns:
388400
Dictionary with metric names as keys and metric values as values
389401
"""
390-
regression_metrics = [
391-
root_mean_squared_error,
392-
mean_squared_error,
393-
r2_score,
394-
mean_absolute_error,
395-
max_error,
396-
explained_variance_score,
397-
kendall_tau_score,
398-
spearman_rho_score,
399-
]
400402
results = {}
401403
pred_values = predictions[f"{target_column}_mean"].values
402404

403-
for metric_func in regression_metrics:
405+
for metric_func in REGRESSION_METRICS:
404406
metric_value = metric_func(true_values, pred_values)
405407
results[metric_func.__name__] = metric_value
406408

0 commit comments

Comments
 (0)