Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Oct 18, 2024
1 parent c0d7379 commit 1632923
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions examples/validation/_03_custom_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,8 @@ def score(
# There we have full control over the results and can calculate the precision, recall and f1-score across all
# datapoints and remove the raw matches from the final results.

def score(
pipeline: MyPipeline, datapoint: ECGExampleData, *, tolerance_s: float
):

def score(pipeline: MyPipeline, datapoint: ECGExampleData, *, tolerance_s: float):
# We use the `safe_run` wrapper instead of just run. This is always a good idea.
# We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
# will clone it again.
Expand All @@ -537,7 +536,10 @@ def score(
"_raw": no_agg(matches),
}

def final_aggregator(agg_results: dict[str, float], raw_results: dict[str, list], pipeline: MyPipeline, dataset:ECGExampleData):

def final_aggregator(
agg_results: dict[str, float], raw_results: dict[str, list], pipeline: MyPipeline, dataset: ECGExampleData
):
# We use pop, as we don't want to have the raw matches in our final results, but this is up to you.
raw_matches = raw_results.pop("_raw")
matches = np.vstack(raw_matches)
Expand All @@ -547,7 +549,10 @@ def final_aggregator(agg_results: dict[str, float], raw_results: dict[str, list]
agg_results["per_sample__f1_score"] = f1_score
return agg_results, raw_results

agg_final_agg, raw_final_agg = Scorer(partial(score, tolerance_s=0.02), final_aggregator=final_aggregator)(pipe, example_data)

agg_final_agg, raw_final_agg = Scorer(partial(score, tolerance_s=0.02), final_aggregator=final_aggregator)(
pipe, example_data
)

agg_final_agg

Expand Down

0 comments on commit 1632923

Please sign in to comment.