20
20
21
21
import json
22
22
import time
23
- from typing import Any , Optional
23
+ from typing import Any , Callable
24
24
25
25
from absl import logging
26
26
from humanfriendly import format_timespan
27
27
import PIL .ExifTags
28
28
import PIL .Image
29
29
30
30
from speciesnet .constants import Classification
31
- from speciesnet .constants import Detection
32
31
from speciesnet .constants import Failure
32
+ from speciesnet .ensemble_prediction_combiner import combine_predictions_for_single_item
33
33
from speciesnet .geofence_utils import geofence_animal_classification
34
34
from speciesnet .geofence_utils import roll_up_labels_to_first_matching_level
35
35
from speciesnet .utils import ModelInfo
44
44
class SpeciesNetEnsemble :
45
45
"""Ensemble component of SpeciesNet."""
46
46
47
- def __init__ (self , model_name : str , geofence : bool = True ) -> None :
47
+ def __init__ (
48
+ self ,
49
+ model_name : str ,
50
+ geofence : bool = True ,
51
+ prediction_combiner : Callable = combine_predictions_for_single_item ,
52
+ ) -> None :
48
53
"""Loads the ensemble resources.
49
54
50
55
Args:
@@ -62,6 +67,7 @@ def __init__(self, model_name: str, geofence: bool = True) -> None:
62
67
self .enable_geofence = geofence
63
68
self .taxonomy_map = self .load_taxonomy ()
64
69
self .geofence_map = self .load_geofence ()
70
+ self .prediction_combiner = prediction_combiner
65
71
66
72
end_time = time .time ()
67
73
logging .info (
@@ -101,153 +107,6 @@ def load_geofence(self):
101
107
geofence_map = json .load (fp )
102
108
return geofence_map
103
109
104
- def _combine_predictions_for_single_item (
105
- self ,
106
- classifications : dict [str , list ],
107
- detections : list [dict ],
108
- country : Optional [str ],
109
- admin1_region : Optional [str ],
110
- ) -> PredictionType :
111
- """Ensembles classifications and detections for a single image.
112
-
113
- This operation leverages multiple heuristics to make the most of the classifier
114
- and the detector predictions through a complex set of decisions. It introduces
115
- various thresholds to identify humans, vehicles, blanks, animals at species
116
- level, animals at higher taxonomy levels and even unknowns.
117
-
118
- Args:
119
- classifications:
120
- Dict of classification results. "classes" and "scores" are expected to
121
- be provided among the dict keys.
122
- detections:
123
- List of detection results, sorted in decreasing order of their
124
- confidence score. Each detection is expected to be a dict providing
125
- "label" and "conf" among its keys.
126
- country:
127
- Country (in ISO 3166-1 alpha-3 format) associated with predictions.
128
- Optional.
129
- admin1_region:
130
- First-level administrative division (in ISO 3166-2 format) associated
131
- with predictions. Optional.
132
-
133
- Returns:
134
- A tuple of <label, score, prediction_source> describing the ensemble result.
135
- """
136
-
137
- top_classification_class = classifications ["classes" ][0 ]
138
- top_classification_score = classifications ["scores" ][0 ]
139
- top_detection_class = detections [0 ]["label" ] if detections else Detection .ANIMAL
140
- top_detection_score = detections [0 ]["conf" ] if detections else 0.0
141
-
142
- if top_detection_class == Detection .HUMAN :
143
- # Threshold #1a: high-confidence HUMAN detections.
144
- if top_detection_score > 0.7 :
145
- return Classification .HUMAN , top_detection_score , "detector"
146
-
147
- # Threshold #1b: mid-confidence HUMAN detections + high-confidence
148
- # HUMAN/VEHICLE classifications.
149
- if (
150
- top_detection_score > 0.2
151
- and top_classification_class
152
- in {Classification .HUMAN , Classification .VEHICLE }
153
- and top_classification_score > 0.5
154
- ):
155
- return Classification .HUMAN , top_classification_score , "classifier"
156
-
157
- if top_detection_class == Detection .VEHICLE :
158
- # Threshold #2a: mid-confidence VEHICLE detections + high-confidence HUMAN
159
- # classifications.
160
- if (
161
- top_detection_score > 0.2
162
- and top_classification_class == Classification .HUMAN
163
- and top_classification_score > 0.5
164
- ):
165
- return Classification .HUMAN , top_classification_score , "classifier"
166
-
167
- # Threshold #2b: high-confidence VEHICLE detections.
168
- if top_detection_score > 0.7 :
169
- return Classification .VEHICLE , top_detection_score , "detector"
170
-
171
- # Threshold #2c: mid-confidence VEHICLE detections + high-confidence VEHICLE
172
- # classifications.
173
- if (
174
- top_detection_score > 0.2
175
- and top_classification_class == Classification .VEHICLE
176
- and top_classification_score > 0.4
177
- ):
178
- return Classification .VEHICLE , top_classification_score , "classifier"
179
-
180
- # Threshold #3a: high-confidence BLANK "detections" + high-confidence BLANK
181
- # classifications.
182
- if (
183
- top_detection_score < 0.2
184
- and top_classification_class == Classification .BLANK
185
- and top_classification_score > 0.5
186
- ):
187
- return Classification .BLANK , top_classification_score , "classifier"
188
-
189
- # Threshold #3b: extra-high-confidence BLANK classifications.
190
- if (
191
- top_classification_class == Classification .BLANK
192
- and top_classification_score > 0.99
193
- ):
194
- return Classification .BLANK , top_classification_score , "classifier"
195
-
196
- if top_classification_class not in {
197
- Classification .BLANK ,
198
- Classification .HUMAN ,
199
- Classification .VEHICLE ,
200
- }:
201
- # Threshold #4a: extra-high-confidence ANIMAL classifications.
202
- if top_classification_score > 0.8 :
203
- return geofence_animal_classification (
204
- labels = classifications ["classes" ],
205
- scores = classifications ["scores" ],
206
- country = country ,
207
- admin1_region = admin1_region ,
208
- taxonomy_map = self .taxonomy_map ,
209
- geofence_map = self .geofence_map ,
210
- enable_geofence = self .enable_geofence ,
211
- )
212
-
213
- # Threshold #4b: high-confidence ANIMAL classifications + mid-confidence
214
- # ANIMAL detections.
215
- if (
216
- top_classification_score > 0.65
217
- and top_detection_class == Detection .ANIMAL
218
- and top_detection_score > 0.2
219
- ):
220
- return geofence_animal_classification (
221
- labels = classifications ["classes" ],
222
- scores = classifications ["scores" ],
223
- country = country ,
224
- admin1_region = admin1_region ,
225
- taxonomy_map = self .taxonomy_map ,
226
- geofence_map = self .geofence_map ,
227
- enable_geofence = self .enable_geofence ,
228
- )
229
-
230
- # Threshold #5a: high-confidence ANIMAL rollups.
231
- rollup = roll_up_labels_to_first_matching_level (
232
- labels = classifications ["classes" ],
233
- scores = classifications ["scores" ],
234
- country = country ,
235
- admin1_region = admin1_region ,
236
- target_taxonomy_levels = ["genus" , "family" , "order" , "class" , "kingdom" ],
237
- non_blank_threshold = 0.65 ,
238
- taxonomy_map = self .taxonomy_map ,
239
- geofence_map = self .geofence_map ,
240
- enable_geofence = self .enable_geofence ,
241
- )
242
- if rollup :
243
- return rollup
244
-
245
- # Threshold #5b: mid-confidence ANIMAL detections.
246
- if top_detection_class == Detection .ANIMAL and top_detection_score > 0.5 :
247
- return Classification .ANIMAL , top_detection_score , "detector"
248
-
249
- return Classification .UNKNOWN , top_classification_score , "classifier"
250
-
251
110
def combine ( # pylint: disable=too-many-positional-arguments
252
111
self ,
253
112
filepaths : list [str ],
@@ -331,11 +190,16 @@ def combine( # pylint: disable=too-many-positional-arguments
331
190
332
191
# Most importantly, ensemble everything into a single prediction.
333
192
if classifications is not None and detections is not None :
334
- prediction , score , source = self ._combine_predictions_for_single_item (
193
+ prediction , score , source = self .prediction_combiner (
335
194
classifications = classifications ,
336
195
detections = detections ,
337
196
country = geolocation .get ("country" ),
338
197
admin1_region = geolocation .get ("admin1_region" ),
198
+ taxonomy_map = self .taxonomy_map ,
199
+ geofence_map = self .geofence_map ,
200
+ enable_geofence = self .enable_geofence ,
201
+ geofence_fn = geofence_animal_classification ,
202
+ roll_up_fn = roll_up_labels_to_first_matching_level ,
339
203
)
340
204
result ["prediction" ] = (
341
205
prediction .value
0 commit comments