diff --git a/examples/validation/_03_custom_scorer.py b/examples/validation/_03_custom_scorer.py
index 376271f..481b516 100644
--- a/examples/validation/_03_custom_scorer.py
+++ b/examples/validation/_03_custom_scorer.py
@@ -517,9 +517,8 @@ def score(
 # There we have full control over the results and can calculate the precision, recall and f1-score across all
 # datapoints and remove the raw matches from the final results.
 
-def score(
-    pipeline: MyPipeline, datapoint: ECGExampleData, *, tolerance_s: float
-):
+
+def score(pipeline: MyPipeline, datapoint: ECGExampleData, *, tolerance_s: float):
     # We use the `safe_run` wrapper instead of just run. This is always a good idea.
     # We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
     # will clone it again.
@@ -537,7 +536,10 @@ def score(
         "_raw": no_agg(matches),
     }
 
-def final_aggregator(agg_results: dict[str, float], raw_results: dict[str, list], pipeline: MyPipeline, dataset:ECGExampleData):
+
+def final_aggregator(
+    agg_results: dict[str, float], raw_results: dict[str, list], pipeline: MyPipeline, dataset: ECGExampleData
+):
     # We use pop, as we don't want to have the raw matches in our final results, but this is up to you.
     raw_matches = raw_results.pop("_raw")
     matches = np.vstack(raw_matches)
@@ -547,7 +549,10 @@ def final_aggregator(agg_results: dict[str, float], raw_results: dict[str, list]
     agg_results["per_sample__f1_score"] = f1_score
     return agg_results, raw_results
 
-agg_final_agg, raw_final_agg = Scorer(partial(score, tolerance_s=0.02), final_aggregator=final_aggregator)(pipe, example_data)
+
+agg_final_agg, raw_final_agg = Scorer(partial(score, tolerance_s=0.02), final_aggregator=final_aggregator)(
+    pipe, example_data
+)
 agg_final_agg
 
 