Skip to content

Commit 54bc504

Browse files
authored
Add Python API example for CED audio tagging. (#793)
1 parent c1608b3 commit 54bc504

File tree

2 files changed

+122
-3
lines changed

2 files changed

+122
-3
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
This script shows how to use audio tagging Python APIs to tag a file.
5+
6+
Please read the code to download the required model files and test wave file.
7+
"""
8+
9+
import logging
10+
import time
11+
from pathlib import Path
12+
13+
import numpy as np
14+
import sherpa_onnx
15+
import soundfile as sf
16+
17+
18+
def read_test_wave():
19+
# Please download the model files and test wave files from
20+
# https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
21+
test_wave = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/test_wavs/6.wav"
22+
23+
if not Path(test_wave).is_file():
24+
raise ValueError(
25+
f"Please download {test_wave} from "
26+
"https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
27+
)
28+
29+
# See https://python-soundfile.readthedocs.io/en/0.11.0/#soundfile.read
30+
data, sample_rate = sf.read(
31+
test_wave,
32+
always_2d=True,
33+
dtype="float32",
34+
)
35+
data = data[:, 0] # use only the first channel
36+
samples = np.ascontiguousarray(data)
37+
38+
# samples is a 1-d array of dtype float32
39+
# sample_rate is a scalar
40+
return samples, sample_rate
41+
42+
43+
def create_audio_tagger():
44+
# Please download the model files and test wave files from
45+
# https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models
46+
model_file = "./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/model.int8.onnx"
47+
label_file = (
48+
"./sherpa-onnx-ced-mini-audio-tagging-2024-04-19/class_labels_indices.csv"
49+
)
50+
51+
if not Path(model_file).is_file():
52+
raise ValueError(
53+
f"Please download {model_file} from "
54+
"https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
55+
)
56+
57+
if not Path(label_file).is_file():
58+
raise ValueError(
59+
f"Please download {label_file} from "
60+
"https://github.com/k2-fsa/sherpa-onnx/releases/tag/audio-tagging-models"
61+
)
62+
63+
config = sherpa_onnx.AudioTaggingConfig(
64+
model=sherpa_onnx.AudioTaggingModelConfig(
65+
ced=model_file,
66+
num_threads=1,
67+
debug=True,
68+
provider="cpu",
69+
),
70+
labels=label_file,
71+
top_k=5,
72+
)
73+
if not config.validate():
74+
raise ValueError(f"Please check the config: {config}")
75+
76+
print(config)
77+
78+
return sherpa_onnx.AudioTagging(config)
79+
80+
81+
def main():
82+
logging.info("Create audio tagger")
83+
audio_tagger = create_audio_tagger()
84+
85+
logging.info("Read test wave")
86+
samples, sample_rate = read_test_wave()
87+
88+
logging.info("Computing")
89+
90+
start_time = time.time()
91+
92+
stream = audio_tagger.create_stream()
93+
stream.accept_waveform(sample_rate=sample_rate, waveform=samples)
94+
result = audio_tagger.compute(stream)
95+
end_time = time.time()
96+
97+
elapsed_seconds = end_time - start_time
98+
audio_duration = len(samples) / sample_rate
99+
100+
real_time_factor = elapsed_seconds / audio_duration
101+
logging.info(f"Elapsed seconds: {elapsed_seconds:.3f}")
102+
logging.info(f"Audio duration in seconds: {audio_duration:.3f}")
103+
logging.info(
104+
f"RTF: {elapsed_seconds:.3f}/{audio_duration:.3f} = {real_time_factor:.3f}"
105+
)
106+
107+
s = "\n"
108+
for i, e in enumerate(result):
109+
s += f"{i}: {e}\n"
110+
111+
logging.info(s)
112+
113+
114+
if __name__ == "__main__":
115+
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
116+
117+
logging.basicConfig(format=formatter, level=logging.INFO)
118+
119+
main()

sherpa-onnx/python/csrc/audio-tagging.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ static void PybindAudioTaggingModelConfig(py::module *m) {
2929
.def(py::init<>())
3030
.def(py::init<const OfflineZipformerAudioTaggingModelConfig &,
3131
const std::string &, int32_t, bool, const std::string &>(),
32-
py::arg("zipformer"), py::arg("ced") = "",
33-
py::arg("num_threads") = 1, py::arg("debug") = false,
34-
py::arg("provider") = "cpu")
32+
py::arg("zipformer") = OfflineZipformerAudioTaggingModelConfig{},
33+
py::arg("ced") = "", py::arg("num_threads") = 1,
34+
py::arg("debug") = false, py::arg("provider") = "cpu")
3535
.def_readwrite("zipformer", &PyClass::zipformer)
3636
.def_readwrite("num_threads", &PyClass::num_threads)
3737
.def_readwrite("debug", &PyClass::debug)

0 commit comments

Comments
 (0)