Skip to content

Commit 2074610

Browse files
committed
Extract ensemble prediction logic to a new class to improve readability
1 parent 2635fd9 commit 2074610

6 files changed

+596
-165
lines changed

Diff for: speciesnet/ensemble.py

+15-151
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@
2020

2121
import json
2222
import time
23-
from typing import Any, Optional
23+
from typing import Any, Callable
2424

2525
from absl import logging
2626
from humanfriendly import format_timespan
2727
import PIL.ExifTags
2828
import PIL.Image
2929

3030
from speciesnet.constants import Classification
31-
from speciesnet.constants import Detection
3231
from speciesnet.constants import Failure
32+
from speciesnet.ensemble_prediction_combiner import combine_predictions_for_single_item
3333
from speciesnet.geofence_utils import geofence_animal_classification
3434
from speciesnet.geofence_utils import roll_up_labels_to_first_matching_level
3535
from speciesnet.utils import ModelInfo
@@ -44,7 +44,12 @@
4444
class SpeciesNetEnsemble:
4545
"""Ensemble component of SpeciesNet."""
4646

47-
def __init__(self, model_name: str, geofence: bool = True) -> None:
47+
def __init__(
48+
self,
49+
model_name: str,
50+
geofence: bool = True,
51+
prediction_combiner: Callable = combine_predictions_for_single_item,
52+
) -> None:
4853
"""Loads the ensemble resources.
4954
5055
Args:
@@ -62,6 +67,7 @@ def __init__(self, model_name: str, geofence: bool = True) -> None:
6267
self.enable_geofence = geofence
6368
self.taxonomy_map = self.load_taxonomy()
6469
self.geofence_map = self.load_geofence()
70+
self.prediction_combiner = prediction_combiner
6571

6672
end_time = time.time()
6773
logging.info(
@@ -101,153 +107,6 @@ def load_geofence(self):
101107
geofence_map = json.load(fp)
102108
return geofence_map
103109

104-
def _combine_predictions_for_single_item(
105-
self,
106-
classifications: dict[str, list],
107-
detections: list[dict],
108-
country: Optional[str],
109-
admin1_region: Optional[str],
110-
) -> PredictionType:
111-
"""Ensembles classifications and detections for a single image.
112-
113-
This operation leverages multiple heuristics to make the most of the classifier
114-
and the detector predictions through a complex set of decisions. It introduces
115-
various thresholds to identify humans, vehicles, blanks, animals at species
116-
level, animals at higher taxonomy levels and even unknowns.
117-
118-
Args:
119-
classifications:
120-
Dict of classification results. "classes" and "scores" are expected to
121-
be provided among the dict keys.
122-
detections:
123-
List of detection results, sorted in decreasing order of their
124-
confidence score. Each detection is expected to be a dict providing
125-
"label" and "conf" among its keys.
126-
country:
127-
Country (in ISO 3166-1 alpha-3 format) associated with predictions.
128-
Optional.
129-
admin1_region:
130-
First-level administrative division (in ISO 3166-2 format) associated
131-
with predictions. Optional.
132-
133-
Returns:
134-
A tuple of <label, score, prediction_source> describing the ensemble result.
135-
"""
136-
137-
top_classification_class = classifications["classes"][0]
138-
top_classification_score = classifications["scores"][0]
139-
top_detection_class = detections[0]["label"] if detections else Detection.ANIMAL
140-
top_detection_score = detections[0]["conf"] if detections else 0.0
141-
142-
if top_detection_class == Detection.HUMAN:
143-
# Threshold #1a: high-confidence HUMAN detections.
144-
if top_detection_score > 0.7:
145-
return Classification.HUMAN, top_detection_score, "detector"
146-
147-
# Threshold #1b: mid-confidence HUMAN detections + high-confidence
148-
# HUMAN/VEHICLE classifications.
149-
if (
150-
top_detection_score > 0.2
151-
and top_classification_class
152-
in {Classification.HUMAN, Classification.VEHICLE}
153-
and top_classification_score > 0.5
154-
):
155-
return Classification.HUMAN, top_classification_score, "classifier"
156-
157-
if top_detection_class == Detection.VEHICLE:
158-
# Threshold #2a: mid-confidence VEHICLE detections + high-confidence HUMAN
159-
# classifications.
160-
if (
161-
top_detection_score > 0.2
162-
and top_classification_class == Classification.HUMAN
163-
and top_classification_score > 0.5
164-
):
165-
return Classification.HUMAN, top_classification_score, "classifier"
166-
167-
# Threshold #2b: high-confidence VEHICLE detections.
168-
if top_detection_score > 0.7:
169-
return Classification.VEHICLE, top_detection_score, "detector"
170-
171-
# Threshold #2c: mid-confidence VEHICLE detections + high-confidence VEHICLE
172-
# classifications.
173-
if (
174-
top_detection_score > 0.2
175-
and top_classification_class == Classification.VEHICLE
176-
and top_classification_score > 0.4
177-
):
178-
return Classification.VEHICLE, top_classification_score, "classifier"
179-
180-
# Threshold #3a: high-confidence BLANK "detections" + high-confidence BLANK
181-
# classifications.
182-
if (
183-
top_detection_score < 0.2
184-
and top_classification_class == Classification.BLANK
185-
and top_classification_score > 0.5
186-
):
187-
return Classification.BLANK, top_classification_score, "classifier"
188-
189-
# Threshold #3b: extra-high-confidence BLANK classifications.
190-
if (
191-
top_classification_class == Classification.BLANK
192-
and top_classification_score > 0.99
193-
):
194-
return Classification.BLANK, top_classification_score, "classifier"
195-
196-
if top_classification_class not in {
197-
Classification.BLANK,
198-
Classification.HUMAN,
199-
Classification.VEHICLE,
200-
}:
201-
# Threshold #4a: extra-high-confidence ANIMAL classifications.
202-
if top_classification_score > 0.8:
203-
return geofence_animal_classification(
204-
labels=classifications["classes"],
205-
scores=classifications["scores"],
206-
country=country,
207-
admin1_region=admin1_region,
208-
taxonomy_map=self.taxonomy_map,
209-
geofence_map=self.geofence_map,
210-
enable_geofence=self.enable_geofence,
211-
)
212-
213-
# Threshold #4b: high-confidence ANIMAL classifications + mid-confidence
214-
# ANIMAL detections.
215-
if (
216-
top_classification_score > 0.65
217-
and top_detection_class == Detection.ANIMAL
218-
and top_detection_score > 0.2
219-
):
220-
return geofence_animal_classification(
221-
labels=classifications["classes"],
222-
scores=classifications["scores"],
223-
country=country,
224-
admin1_region=admin1_region,
225-
taxonomy_map=self.taxonomy_map,
226-
geofence_map=self.geofence_map,
227-
enable_geofence=self.enable_geofence,
228-
)
229-
230-
# Threshold #5a: high-confidence ANIMAL rollups.
231-
rollup = roll_up_labels_to_first_matching_level(
232-
labels=classifications["classes"],
233-
scores=classifications["scores"],
234-
country=country,
235-
admin1_region=admin1_region,
236-
target_taxonomy_levels=["genus", "family", "order", "class", "kingdom"],
237-
non_blank_threshold=0.65,
238-
taxonomy_map=self.taxonomy_map,
239-
geofence_map=self.geofence_map,
240-
enable_geofence=self.enable_geofence,
241-
)
242-
if rollup:
243-
return rollup
244-
245-
# Threshold #5b: mid-confidence ANIMAL detections.
246-
if top_detection_class == Detection.ANIMAL and top_detection_score > 0.5:
247-
return Classification.ANIMAL, top_detection_score, "detector"
248-
249-
return Classification.UNKNOWN, top_classification_score, "classifier"
250-
251110
def combine( # pylint: disable=too-many-positional-arguments
252111
self,
253112
filepaths: list[str],
@@ -331,11 +190,16 @@ def combine( # pylint: disable=too-many-positional-arguments
331190

332191
# Most importantly, ensemble everything into a single prediction.
333192
if classifications is not None and detections is not None:
334-
prediction, score, source = self._combine_predictions_for_single_item(
193+
prediction, score, source = self.prediction_combiner(
335194
classifications=classifications,
336195
detections=detections,
337196
country=geolocation.get("country"),
338197
admin1_region=geolocation.get("admin1_region"),
198+
taxonomy_map=self.taxonomy_map,
199+
geofence_map=self.geofence_map,
200+
enable_geofence=self.enable_geofence,
201+
geofence_fn=geofence_animal_classification,
202+
roll_up_fn=roll_up_labels_to_first_matching_level,
339203
)
340204
result["prediction"] = (
341205
prediction.value

0 commit comments

Comments
 (0)