diff --git a/src/birdnetlib/analyzer.py b/src/birdnetlib/analyzer.py index 16716fb..ece315c 100644 --- a/src/birdnetlib/analyzer.py +++ b/src/birdnetlib/analyzer.py @@ -53,6 +53,7 @@ def __init__( custom_species_list=None, classifier_model_path=None, classifier_labels_path=None, + fetch_embeddings=None, ): self.name = "Analyzer" self.model_name = "BirdNET-Analyzer" @@ -71,7 +72,12 @@ def __init__( self.labels = [] self.results = [] self.custom_species_list = [] - + + if fetch_embeddings: + self.fetch_embeddings = True + else: + self.fetch_embeddings = False + self.classifier_model_path = classifier_model_path self.classifier_labels_path = classifier_labels_path self.use_custom_classifier = ( @@ -200,10 +206,15 @@ def analyze_recording(self, recording): start = 0 end = recording.sample_secs results = {} + features = [] for c in recording.chunks: if self.use_custom_classifier: pred = self.predict_with_custom_classifier(c)[0] + elif self.fetch_embeddings == True: + feature = self.predict_with_custom_classifier(c) + features.append(feature) + continue else: pred = self.predict(c)[0] @@ -226,7 +237,9 @@ def analyze_recording(self, recording): end = start + recording.sample_secs self.results = results + self.features_list = features recording.detection_list = self.detections + recording.features_list = self.features_list def load_model(self): print("load model", not self.use_custom_classifier) @@ -246,7 +259,7 @@ def load_model(self): self.input_layer_index = self.input_details[0]["index"] # Get classification output or feature embeddings - if self.use_custom_classifier: + if self.use_custom_classifier or self.fetch_embeddings: self.output_layer_index = self.output_details[0]["index"] - 1 else: self.output_layer_index = self.output_details[0]["index"] @@ -299,7 +312,9 @@ def predict_with_custom_classifier(self, sample): features = INTERPRETER.get_tensor(OUTPUT_LAYER_INDEX) feature_vector = features - + if self.fetch_embeddings: + return feature_vector + C_INTERPRETER = self.custom_interpreter C_INPUT_LAYER_INDEX = self.custom_input_layer_index C_OUTPUT_LAYER_INDEX = self.custom_output_layer_index diff --git a/src/birdnetlib/main.py b/src/birdnetlib/main.py index 338e647..e5bb7a9 100644 --- a/src/birdnetlib/main.py +++ b/src/birdnetlib/main.py @@ -29,6 +29,7 @@ def __init__( self.analyzer = analyzer self.detections_dict = {} # Old format self.detection_list = [] + self.features_list = [] self.analyzed = False self.week_48 = week_48 self.date = date @@ -94,6 +95,9 @@ def detections(self): qualified_detections.append(detection) return qualified_detections + @property + def features(self): + return self.features_list @property def as_dict(self): @@ -299,4 +303,4 @@ def as_dict(self): if __name__ == "__main__": - pass \ No newline at end of file + pass