File tree Expand file tree Collapse file tree 1 file changed +13
-11
lines changed
benchmarks/domains/regression Expand file tree Collapse file tree 1 file changed +13
-11
lines changed Original file line number Diff line number Diff line change @@ -57,6 +57,18 @@ def spearman_rho_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
57
57
return rho
58
58
59
59
60
+ REGRESSION_METRICS = [
61
+ root_mean_squared_error ,
62
+ mean_squared_error ,
63
+ r2_score ,
64
+ mean_absolute_error ,
65
+ max_error ,
66
+ explained_variance_score ,
67
+ kendall_tau_score ,
68
+ spearman_rho_score ,
69
+ ]
70
+
71
+
60
72
def run_tl_regression_benchmark (
61
73
settings : TransferLearningRegressionBenchmarkSettings ,
62
74
load_data_fn : Callable [..., pd .DataFrame ],
@@ -387,20 +399,10 @@ def _calculate_metrics(
387
399
Returns:
388
400
Dictionary with metric names as keys and metric values as values
389
401
"""
390
- regression_metrics = [
391
- root_mean_squared_error ,
392
- mean_squared_error ,
393
- r2_score ,
394
- mean_absolute_error ,
395
- max_error ,
396
- explained_variance_score ,
397
- kendall_tau_score ,
398
- spearman_rho_score ,
399
- ]
400
402
results = {}
401
403
pred_values = predictions [f"{ target_column } _mean" ].values
402
404
403
- for metric_func in regression_metrics :
405
+ for metric_func in REGRESSION_METRICS :
404
406
metric_value = metric_func (true_values , pred_values )
405
407
results [metric_func .__name__ ] = metric_value
406
408
You can’t perform that action at this time.
0 commit comments