From 2427d3c94fb34175a5e9719e52ca52c5cc7aa026 Mon Sep 17 00:00:00 2001
From: Joe Weiss <joe.weiss@gmail.com>
Date: Mon, 22 Apr 2024 12:52:24 -0400
Subject: [PATCH] Add ability to extract feature embeddings (#112)

* Add ability to extract embeddings

* Add docs for recording.extract_embeddings()

* Raise not implemented error for large recording object

* Increment for 0.17.0

* Reduce test tolerance to 4 decimal points

* Fix issue with LargeRecording extractions
---
 docs/api.md                |  15 +++++
 pyproject.toml             |   2 +-
 src/birdnetlib/analyzer.py |  25 +++++++-
 src/birdnetlib/main.py     |  31 +++++++++-
 tests/test_embeddings.py   | 118 +++++++++++++++++++++++++++++++++++++
 5 files changed, 188 insertions(+), 3 deletions(-)
 create mode 100644 tests/test_embeddings.py

diff --git a/docs/api.md b/docs/api.md
index 14d6209..8eacfeb 100644
--- a/docs/api.md
+++ b/docs/api.md
@@ -67,7 +67,22 @@ recording.analyze()
 print(recording.detections)
 ```
 
+#### Embeddings
+To extract feature embeddings instead of class predictions, use the `extract_embeddings` method.
+
+```python
+from birdnetlib import Recording
+from birdnetlib.analyzer import Analyzer
+
+analyzer = Analyzer()
+recording = Recording(
+    analyzer,
+    "sample.mp3",
+)
+recording.extract_embeddings()
+print(recording.embeddings)
+```
 
 ### RecordingFileObject
 
 
diff --git a/pyproject.toml b/pyproject.toml
index ddde791..2361283 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ exclude = [
 
 [project]
 name = "birdnetlib"
-version = "0.16.0"
+version = "0.17.0"
 authors = [
   { name="Joe Weiss", email="joe.weiss@gmail.com" },
 ]
diff --git a/src/birdnetlib/analyzer.py b/src/birdnetlib/analyzer.py
index 7dba198..363b3b0 100644
--- a/src/birdnetlib/analyzer.py
+++ b/src/birdnetlib/analyzer.py
@@ -85,6 +85,7 @@ def __init__(
 
         self.labels = []
         self.results = []
+        self.embeddings = []
         self.custom_species_list = []
 
         # Set model versions.
@@ -362,6 +363,22 @@ def analyze_recording(self, recording):
         self.results = results
         recording.detection_list = self.detections
 
+    def extract_embeddings_for_recording(self, recording):
+        print("extract_embeddings_for_recording", recording.filename)
+        start = 0
+        end = recording.sample_secs
+        results = []
+        for sample in recording.chunks:
+            data = np.array([sample], dtype="float32")
+            e = self._return_embeddings(data)[0].tolist()
+            results.append({"start_time": start, "end_time": end, "embeddings": e})
+
+            # Increment start and end
+            start += recording.sample_secs - recording.overlap
+            end = start + recording.sample_secs
+
+        self.embeddings = results
+
     def load_model(self):
         print("load model", not self.use_custom_classifier)
         # Load TFLite model and allocate tensors.
@@ -420,7 +437,13 @@ def _return_embeddings(self, data):
             self.input_layer_index, np.array(data, dtype="float32")
         )
         self.interpreter.invoke()
-        features = self.interpreter.get_tensor(self.output_layer_index)
+
+        # Embeddings use the custom classifier output layer index.
+        output_layer_index = self.output_layer_index
+        if not self.use_custom_classifier:
+            output_layer_index = output_layer_index - 1
+
+        features = self.interpreter.get_tensor(output_layer_index)
         return features
 
     def predict_with_custom_classifier(self, sample):
diff --git a/src/birdnetlib/main.py b/src/birdnetlib/main.py
index aa38261..63bbc19 100644
--- a/src/birdnetlib/main.py
+++ b/src/birdnetlib/main.py
@@ -35,6 +35,8 @@ def __init__(
         self.detections_dict = {}  # Old format
         self.detection_list = []
         self.analyzed = False
+        self.embeddings_extracted = False
+        self.embeddings_list = []
         self.week_48 = week_48
         self.date = date
         self.sensitivity = max(0.5, min(1.0 - (sensitivity - 1.0), 1.5))
@@ -70,6 +72,28 @@ def analyze(self):
         self.analyzer.analyze_recording(self)
         self.analyzed = True
 
+    def extract_embeddings(self):
+        # Check that analyzer is not LargeRecordingAnalyzer
+        if isinstance(self.analyzer, LargeRecordingAnalyzer):
+            raise IncompatibleAnalyzerError(
+                "LargeRecordingAnalyzer can only be used with the LargeRecording class"
+            )
+
+        # Read and analyze.
+        self.read_audio_data()
+        self.analyzer.extract_embeddings_for_recording(self)
+        self.embeddings_list = self.analyzer.embeddings
+        self.embeddings_extracted = True
+
+    @property
+    def embeddings(self):
+        if not self.embeddings_extracted:
+            warnings.warn(
+                "'extract_embeddings' method has not been called. Call .extract_embeddings() before accessing embeddings.",
+                AnalyzerRuntimeWarning,
+            )
+        return self.embeddings_list
+
     @property
     def detections(self):
         if not self.analyzed:
@@ -455,6 +479,11 @@ def analyze(self):
         self.analyzer.analyze_recording(self)
         self.analyzed = True
 
+    def extract_embeddings(self):
+        raise NotImplementedError(
+            "Extraction of embeddings is not yet implemented for LargeRecordingAnalyzer. Use Analyzer if possible."
+        )
+
     def get_extract_array(self, start_sec, end_sec):
         # Returns ndarray trimmed for start_sec:end_sec
         print(start_sec, end_sec)
@@ -463,7 +492,7 @@
             self.path,
             sr=sr,
             mono=True,
-            offset=start_sec / sr,
+            offset=start_sec,
             duration=(end_sec - start_sec),
         )
 
diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py
new file mode 100644
index 0000000..4c72b6e
--- /dev/null
+++ b/tests/test_embeddings.py
@@ -0,0 +1,118 @@
+from birdnetlib import Recording, LargeRecording
+from birdnetlib.analyzer import Analyzer, LargeRecordingAnalyzer
+
+from pprint import pprint
+import pytest
+import os
+import tempfile
+import csv
+from unittest.mock import patch
+import numpy as np
+
+
+def test_embeddings():
+    # Process file with command line utility, then process with python library and ensure the results are equal.
+
+    lon = -120.7463
+    lat = 35.4244
+    week_48 = 18
+    min_conf = 0.25
+    input_path = os.path.join(os.path.dirname(__file__), "test_files/soundscape.wav")
+
+    tf = tempfile.NamedTemporaryFile(suffix=".csv")
+    output_path = tf.name
+
+    # Process using python script as is.
+    birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")
+
+    cmd = f"python embeddings.py --i '{input_path}' --o={output_path}"
+    os.system(f"cd {birdnet_analyzer_path}; {cmd}")
+
+    with open(tf.name, newline="") as tsvfile:
+        tsvreader = csv.reader(tsvfile, delimiter="\t")
+        commandline_results = []
+        for row in tsvreader:
+            commandline_results.append(
+                {
+                    "start_time": float(row[0]),
+                    "end_time": float(row[1]),
+                    "embeddings": [float(i) for i in row[2].split(",")],
+                }
+            )
+
+    # pprint(commandline_results)
+    assert len(commandline_results) == 40
+
+    analyzer = Analyzer()
+    recording = Recording(
+        analyzer,
+        input_path,
+        lat=lat,
+        lon=lon,
+        week_48=week_48,
+        min_conf=min_conf,
+        return_all_detections=True,
+    )
+    recording.extract_embeddings()
+
+    # Check that birdnetlib results match command line results.
+    assert len(recording.embeddings) == 40
+
+    for idx, i in enumerate(commandline_results):
+        # Specify the tolerance level
+        tolerance = 1e-4  # 4 decimal points tolerance between BirdNET and birdnetlib.
+
+        # Assert that the arrays are almost equal within the tolerance
+        assert np.allclose(
+            i["embeddings"], recording.embeddings[idx]["embeddings"], atol=tolerance
+        )
+
+
+def test_largefile_embeddings():
+    # Process file with command line utility, then process with python library and ensure the results are equal.
+
+    lon = -120.7463
+    lat = 35.4244
+    week_48 = 18
+    min_conf = 0.25
+    input_path = os.path.join(os.path.dirname(__file__), "test_files/soundscape.wav")
+
+    tf = tempfile.NamedTemporaryFile(suffix=".csv")
+    output_path = tf.name
+
+    # Process using python script as is.
+    birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")
+
+    cmd = f"python embeddings.py --i '{input_path}' --o={output_path}"
+    os.system(f"cd {birdnet_analyzer_path}; {cmd}")
+
+    with open(tf.name, newline="") as tsvfile:
+        tsvreader = csv.reader(tsvfile, delimiter="\t")
+        commandline_results = []
+        for row in tsvreader:
+            commandline_results.append(
+                {
+                    "start_time": float(row[0]),
+                    "end_time": float(row[1]),
+                    "embeddings": [float(i) for i in row[2].split(",")],
+                }
+            )
+
+    # pprint(commandline_results)
+    assert len(commandline_results) == 40
+
+    # TODO: Implement for LargeRecording.
+    # Confirm that LargeRecording raises NotImplementedError.
+    large_analyzer = LargeRecordingAnalyzer()
+    recording = LargeRecording(
+        large_analyzer,
+        input_path,
+        lat=lat,
+        lon=lon,
+        week_48=week_48,
+        min_conf=min_conf,
+        return_all_detections=True,
+    )
+    msg = "Extraction of embeddings is not yet implemented for LargeRecordingAnalyzer. Use Analyzer if possible."
+    with pytest.raises(NotImplementedError, match=msg):
+        recording.extract_embeddings()
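
Usage note (not part of the patch itself): a minimal sketch of how the API added here is expected to be used, assuming the version bumped in this patch (0.17.0) and a local audio file named `sample.mp3`, which is only a placeholder. Each entry in `recording.embeddings` mirrors what `extract_embeddings_for_recording()` builds: a dict with `start_time`, `end_time`, and `embeddings` keys, one per analyzed chunk.

```python
# Minimal usage sketch for the embeddings API added in this patch.
# "sample.mp3" is a placeholder file name, not part of the repository.
from birdnetlib import Recording
from birdnetlib.analyzer import Analyzer

analyzer = Analyzer()
recording = Recording(analyzer, "sample.mp3")

# Accessing recording.embeddings before calling extract_embeddings()
# only emits an AnalyzerRuntimeWarning and returns an empty list.
recording.extract_embeddings()

# Each entry covers one analyzed chunk:
# {"start_time": ..., "end_time": ..., "embeddings": [...]}
for entry in recording.embeddings:
    print(entry["start_time"], entry["end_time"], len(entry["embeddings"]))
```

Calling `extract_embeddings()` on a `LargeRecording` currently raises `NotImplementedError`, as the tests above confirm.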