Skip to content

Commit

Permalink
Add ability to extract feature embeddings (#112)
Browse files Browse the repository at this point in the history
* Add ability to extract embeddings

* Add docs for recording.extract_embeddings()

* Raise not implemented error for large recording object

* Increment for 0.17.0

* Reduce test tolerance to 4 decimal points

* Fix issue with LargeRecording extractions
  • Loading branch information
joeweiss authored Apr 22, 2024
1 parent d1e9901 commit 2427d3c
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 3 deletions.
15 changes: 15 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,22 @@ recording.analyze()
print(recording.detections)
```

#### Embeddings

To extract feature embeddings instead of class predictions, use the `extract_embeddings` method.

```python
from birdnetlib import Recording
from birdnetlib.analyzer import Analyzer

analyzer = Analyzer()
recording = Recording(
analyzer,
"sample.mp3",
)
recording.extract_embeddings()
print(recording.embeddings)
```

### RecordingFileObject

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ exclude = [

[project]
name = "birdnetlib"
version = "0.16.0"
version = "0.17.0"
authors = [
{ name="Joe Weiss", email="[email protected]" },
]
Expand Down
25 changes: 24 additions & 1 deletion src/birdnetlib/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(

self.labels = []
self.results = []
self.embeddings = []
self.custom_species_list = []

# Set model versions.
Expand Down Expand Up @@ -362,6 +363,22 @@ def analyze_recording(self, recording):
self.results = results
recording.detection_list = self.detections

def extract_embeddings_for_recording(self, recording):
    """Compute one embedding vector per chunk of ``recording``.

    Every item of ``recording.chunks`` is wrapped in a one-item float32
    batch and passed to ``self._return_embeddings``. The outcome is stored
    on ``self.embeddings`` as a list of dicts with the keys ``start_time``,
    ``end_time`` and ``embeddings`` (the vector converted to a plain list).
    Nothing is returned.

    The first window starts at 0 and is ``recording.sample_secs`` long; each
    following window starts ``recording.sample_secs - recording.overlap``
    later.
    """
    print("extract_embeddings_for_recording", recording.filename)

    window_start = 0
    window_end = recording.sample_secs
    extracted = []
    for chunk in recording.chunks:
        # Wrap the chunk in a one-item batch; ``[0]`` unwraps the result.
        batch = np.array([chunk], dtype="float32")
        vector = self._return_embeddings(batch)[0].tolist()
        extracted.append(
            {
                "start_time": window_start,
                "end_time": window_end,
                "embeddings": vector,
            }
        )

        # Slide the window forward for the next chunk.
        window_start += recording.sample_secs - recording.overlap
        window_end = window_start + recording.sample_secs

    self.embeddings = extracted

def load_model(self):
print("load model", not self.use_custom_classifier)
# Load TFLite model and allocate tensors.
Expand Down Expand Up @@ -420,7 +437,13 @@ def _return_embeddings(self, data):
self.input_layer_index, np.array(data, dtype="float32")
)
self.interpreter.invoke()
features = self.interpreter.get_tensor(self.output_layer_index)

# Embeddings use the custom classifier's output layer index.
output_layer_index = self.output_layer_index
if not self.use_custom_classifier:
output_layer_index = output_layer_index - 1

features = self.interpreter.get_tensor(output_layer_index)
return features

def predict_with_custom_classifier(self, sample):
Expand Down
31 changes: 30 additions & 1 deletion src/birdnetlib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def __init__(
self.detections_dict = {} # Old format
self.detection_list = []
self.analyzed = False
self.embeddings_extracted = False
self.embeddings_list = []
self.week_48 = week_48
self.date = date
self.sensitivity = max(0.5, min(1.0 - (sensitivity - 1.0), 1.5))
Expand Down Expand Up @@ -70,6 +72,28 @@ def analyze(self):
self.analyzer.analyze_recording(self)
self.analyzed = True

def extract_embeddings(self):
    """Read the audio and store per-chunk embeddings on this recording.

    After a successful run ``self.embeddings_list`` holds the analyzer's
    ``embeddings`` and ``self.embeddings_extracted`` is ``True``.

    Raises:
        IncompatibleAnalyzerError: if ``self.analyzer`` is a
            ``LargeRecordingAnalyzer``.
    """
    # Guard: this path must not run with a LargeRecordingAnalyzer.
    if isinstance(self.analyzer, LargeRecordingAnalyzer):
        message = (
            "LargeRecordingAnalyzer can only be used with the LargeRecording class"
        )
        raise IncompatibleAnalyzerError(message)

    # Load the audio, then let the analyzer compute the embeddings.
    self.read_audio_data()
    self.analyzer.extract_embeddings_for_recording(self)

    # Copy the analyzer's result onto the recording and flag completion.
    self.embeddings_list = self.analyzer.embeddings
    self.embeddings_extracted = True

@property
def embeddings(self):
    """Return ``self.embeddings_list``.

    An ``AnalyzerRuntimeWarning`` is emitted first when
    ``self.embeddings_extracted`` is falsy; the list is returned either way.
    """
    if self.embeddings_extracted:
        return self.embeddings_list

    # Not extracted yet: warn, but still hand back the stored list.
    warnings.warn(
        "'extract_embeddings' method has not been called. Call .extract_embeddings() before accessing embeddings.",
        AnalyzerRuntimeWarning,
    )
    return self.embeddings_list

@property
def detections(self):
if not self.analyzed:
Expand Down Expand Up @@ -455,6 +479,11 @@ def analyze(self):
self.analyzer.analyze_recording(self)
self.analyzed = True

def extract_embeddings(self):
    """Always raise: embedding extraction is not implemented here yet.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "Extraction of embeddings is not yet implemented for LargeRecordingAnalyzer. Use Analyzer if possible."
    raise NotImplementedError(message)

def get_extract_array(self, start_sec, end_sec):
# Returns ndarray trimmed for start_sec:end_sec
print(start_sec, end_sec)
Expand All @@ -463,7 +492,7 @@ def get_extract_array(self, start_sec, end_sec):
self.path,
sr=sr,
mono=True,
offset=start_sec / sr,
offset=start_sec,
duration=(end_sec - start_sec),
)

Expand Down
118 changes: 118 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from birdnetlib import Recording, LargeRecording
from birdnetlib.analyzer import Analyzer, LargeRecordingAnalyzer

from pprint import pprint
import pytest
import os
import tempfile
import csv
from unittest.mock import patch
import numpy as np


def test_embeddings():
    """Compare birdnetlib embeddings with BirdNET-Analyzer's command line output.

    Runs ``embeddings.py`` from the bundled ``BirdNET-Analyzer`` directory on
    the sample soundscape, parses the tab-separated output, then extracts
    embeddings with ``Recording.extract_embeddings()`` and asserts that both
    produce 40 entries whose vectors agree within ``1e-4``.
    """
    # Local imports keep this test self-contained.
    import subprocess
    import sys

    lon = -120.7463
    lat = 35.4244
    week_48 = 18
    min_conf = 0.25
    input_path = os.path.join(os.path.dirname(__file__), "test_files/soundscape.wav")

    # Process using python script as is.
    birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")

    # The context manager closes the temporary file even if an assertion or
    # the command line run fails.
    with tempfile.NamedTemporaryFile(suffix=".csv") as tf:
        output_path = tf.name

        # An argument list plus ``cwd`` avoids shell quoting problems with
        # paths, and ``check=True`` surfaces a failing run immediately.
        subprocess.run(
            [sys.executable, "embeddings.py", "--i", input_path, f"--o={output_path}"],
            cwd=birdnet_analyzer_path,
            check=True,
        )

        with open(output_path, newline="") as tsvfile:
            commandline_results = [
                {
                    "start_time": float(row[0]),
                    "end_time": float(row[1]),
                    "embeddings": [float(value) for value in row[2].split(",")],
                }
                for row in csv.reader(tsvfile, delimiter="\t")
            ]

    assert len(commandline_results) == 40

    analyzer = Analyzer()
    recording = Recording(
        analyzer,
        input_path,
        lat=lat,
        lon=lon,
        week_48=week_48,
        min_conf=min_conf,
        return_all_detections=True,
    )
    recording.extract_embeddings()

    # Check that birdnetlib results match command line results.
    library_results = recording.embeddings
    assert len(library_results) == 40

    # 4 decimal points tolerance between BirdNET and birdnetlib.
    tolerance = 1e-4
    for expected, actual in zip(commandline_results, library_results):
        assert np.allclose(
            expected["embeddings"], actual["embeddings"], atol=tolerance
        )


def test_largefile_embeddings():
    """Confirm ``LargeRecording.extract_embeddings()`` raises NotImplementedError.

    The BirdNET-Analyzer command line run and its 40-row check are kept so
    the reference output is ready once the feature is implemented for
    ``LargeRecording``.
    """
    # Local imports keep this test self-contained.
    import re
    import subprocess
    import sys

    lon = -120.7463
    lat = 35.4244
    week_48 = 18
    min_conf = 0.25
    input_path = os.path.join(os.path.dirname(__file__), "test_files/soundscape.wav")

    # Process using python script as is.
    birdnet_analyzer_path = os.path.join(os.path.dirname(__file__), "BirdNET-Analyzer")

    # The context manager closes the temporary file even if an assertion or
    # the command line run fails.
    with tempfile.NamedTemporaryFile(suffix=".csv") as tf:
        output_path = tf.name

        # An argument list plus ``cwd`` avoids shell quoting problems with
        # paths, and ``check=True`` surfaces a failing run immediately.
        subprocess.run(
            [sys.executable, "embeddings.py", "--i", input_path, f"--o={output_path}"],
            cwd=birdnet_analyzer_path,
            check=True,
        )

        with open(output_path, newline="") as tsvfile:
            commandline_results = [
                {
                    "start_time": float(row[0]),
                    "end_time": float(row[1]),
                    "embeddings": [float(value) for value in row[2].split(",")],
                }
                for row in csv.reader(tsvfile, delimiter="\t")
            ]

    assert len(commandline_results) == 40

    # TODO: Implement for LargeRecording.
    # Confirm that LargeRecording return not implemented.
    large_analyzer = LargeRecordingAnalyzer()
    recording = LargeRecording(
        large_analyzer,
        input_path,
        lat=lat,
        lon=lon,
        week_48=week_48,
        min_conf=min_conf,
        return_all_detections=True,
    )
    msg = "Extraction of embeddings is not yet implemented for LargeRecordingAnalyzer. Use Analyzer if possible."
    # ``match`` is a regular expression; escape the message so its "."
    # characters are matched literally.
    with pytest.raises(NotImplementedError, match=re.escape(msg)):
        recording.extract_embeddings()

0 comments on commit 2427d3c

Please sign in to comment.