From d8393785fa86aff33c96ccc86b1f54e2ef68bad9 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Oct 2024 13:57:29 +0800 Subject: [PATCH] Python API for speaker diarization. --- .github/scripts/test-python.sh | 15 +++ .github/workflows/windows-x64.yaml | 2 +- .github/workflows/windows-x86.yaml | 2 +- .../offline-speaker-diarization.py | 118 ++++++++++++++++++ ...ffline-speaker-diarization-pyannote-impl.h | 2 +- .../csrc/offline-speaker-diarization-result.h | 2 + .../csrc/offline-speaker-diarization.h | 7 +- sherpa-onnx/python/csrc/CMakeLists.txt | 2 + .../offline-speaker-diarization-result.cc | 32 +++++ .../csrc/offline-speaker-diarization-result.h | 16 +++ .../csrc/offline-speaker-diarization.cc | 92 ++++++++++++++ .../python/csrc/offline-speaker-diarization.h | 16 +++ sherpa-onnx/python/csrc/sherpa-onnx.cc | 12 +- sherpa-onnx/python/sherpa_onnx/__init__.py | 6 + 14 files changed, 315 insertions(+), 9 deletions(-) create mode 100755 python-api-examples/offline-speaker-diarization.py create mode 100644 sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc create mode 100644 sherpa-onnx/python/csrc/offline-speaker-diarization-result.h create mode 100644 sherpa-onnx/python/csrc/offline-speaker-diarization.cc create mode 100644 sherpa-onnx/python/csrc/offline-speaker-diarization.h diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index de7297f2c..8c9d303b0 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,21 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +log "test offline speaker diarization" + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 +rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +python3 ./python-api-examples/offline-speaker-diarization.py + +rm -rf *.wav *.onnx ./sherpa-onnx-pyannote-segmentation-3-0 + + log "test_clustering" pushd /tmp/ mkdir test-cluster diff --git a/.github/workflows/windows-x64.yaml b/.github/workflows/windows-x64.yaml index c67f3e0b5..758935918 100644 --- a/.github/workflows/windows-x64.yaml +++ b/.github/workflows/windows-x64.yaml @@ -93,7 +93,7 @@ jobs: shell: bash run: | du -h -d1 . - export PATH=$PWD/build/bin:$PATH + export PATH=$PWD/build/bin/Release:$PATH export EXE=sherpa-onnx-offline-speaker-diarization.exe .github/scripts/test-speaker-diarization.sh diff --git a/.github/workflows/windows-x86.yaml b/.github/workflows/windows-x86.yaml index 30394e90e..b9e473184 100644 --- a/.github/workflows/windows-x86.yaml +++ b/.github/workflows/windows-x86.yaml @@ -93,7 +93,7 @@ jobs: shell: bash run: | du -h -d1 . - export PATH=$PWD/build/bin:$PATH + export PATH=$PWD/build/bin/Release:$PATH export EXE=sherpa-onnx-offline-speaker-diarization.exe .github/scripts/test-speaker-diarization.sh diff --git a/python-api-examples/offline-speaker-diarization.py b/python-api-examples/offline-speaker-diarization.py new file mode 100755 index 000000000..3e3ff1618 --- /dev/null +++ b/python-api-examples/offline-speaker-diarization.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 +# Copyright (c) 2024 Xiaomi Corporation + +""" +This file shows how to use sherpa-onnx Python API for +offline/non-streaming speaker diarization. + +Usage: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available test wave files. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +Step 4. Run it + + python3 ./python-api-examples/offline-speaker-diarization.py + +""" +from pathlib import Path + +import sherpa_onnx +import soundfile as sf + + +def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5): + """ + Args: + num_speakers: + If you know the actual number of speakers in the wave file, then please + specify it. Otherwise, leave it to -1 + cluster_threshold: + If num_speakers is -1, then this threshold is used for clustering. + A smaller cluster_threshold leads to more clusters, i.e., more speakers. + A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers. + """ + segmentation_model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx" + embedding_extractor_model = ( + "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx" + ) + + config = sherpa_onnx.OfflineSpeakerDiarizationConfig( + segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig( + pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig( + model=segmentation_model + ), + ), + embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig( + model=embedding_extractor_model + ), + clustering=sherpa_onnx.FastClusteringConfig( + num_clusters=num_speakers, threshold=cluster_threshold + ), + min_duration_on=0.3, + min_duration_off=0.5, + ) + if not config.validate(): + raise RuntimeError( + "Please check your config and make sure all required files exist" + ) + + return sherpa_onnx.OfflineSpeakerDiarization(config) + + +def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int: + progress = num_processed_chunk / num_total_chunks * 100 + print(f"Progress: {progress:.3f}%") + return 0 + + +def main(): + wave_filename = "./0-four-speakers-zh.wav" + if not Path(wave_filename).is_file(): + raise RuntimeError(f"{wave_filename} does not exist") + + audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True) + audio = audio[:, 0] # only use the first channel + + # Since we know there are 4 speakers in the above test wave file, we use + # num_speakers 4 here + sd = init_speaker_diarization(num_speakers=4) + if sample_rate != sd.sample_rate: + raise RuntimeError( + f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}" + ) + + show_porgress = True + + if show_porgress: + result = sd.process(audio, callback=progress_callback).sort_by_start_time() + else: + result = sd.process(audio).sort_by_start_time() + + for r in result: + print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}") + # print(r) # this one is simpler + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index bcd0c93a4..64b087c00 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels); Matrix2D embeddings = ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second, - callback, callback_arg); + std::move(callback), callback_arg); std::vector cluster_labels = clustering_.Cluster( &embeddings(0, 0), embeddings.rows(), embeddings.cols()); diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/csrc/offline-speaker-diarization-result.h index e71d054e5..5fb144f5c 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-result.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-result.h @@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment { const std::string &Text() const { return text_; } float Duration() const { return end_ - start_; } + void SetText(const std::string &text) { text_ = text; } + std::string ToString() const; private: diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index ab9a440aa..e5d02c473 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig { OfflineSpeakerDiarizationConfig( const OfflineSpeakerSegmentationModelConfig &segmentation, const SpeakerEmbeddingExtractorConfig &embedding, - const FastClusteringConfig &clustering) + const FastClusteringConfig &clustering, float min_duration_on, + float min_duration_off) : segmentation(segmentation), embedding(embedding), - clustering(clustering) {} + clustering(clustering), + min_duration_on(min_duration_on), + min_duration_off(min_duration_off) {} void Register(ParseOptions *po); bool Validate() const; diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index 7fd5efa33..2e971581a 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -62,6 +62,8 @@ endif() if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) list(APPEND srcs fast-clustering.cc + offline-speaker-diarization-result.cc + offline-speaker-diarization.cc ) endif() diff --git a/sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc b/sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc new file mode 100644 index 000000000..d058c26a2 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc @@ -0,0 +1,32 @@ +// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h" + +#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" + +namespace sherpa_onnx { + +static void PybindOfflineSpeakerDiarizationSegment(py::module *m) { + using PyClass = OfflineSpeakerDiarizationSegment; + py::class_(*m, "OfflineSpeakerDiarizationSegment") + .def_property_readonly("start", &PyClass::Start) + .def_property_readonly("end", &PyClass::End) + .def_property_readonly("duration", &PyClass::Duration) + .def_property_readonly("speaker", &PyClass::Speaker) + .def_property("text", &PyClass::Text, &PyClass::SetText) + .def("__str__", &PyClass::ToString); +} + +void PybindOfflineSpeakerDiarizationResult(py::module *m) { + PybindOfflineSpeakerDiarizationSegment(m); + using PyClass = OfflineSpeakerDiarizationResult; + py::class_(*m, "OfflineSpeakerDiarizationResult") + .def_property_readonly("num_speakers", &PyClass::NumSpeakers) + .def_property_readonly("num_segments", &PyClass::NumSegments) + .def("sort_by_start_time", &PyClass::SortByStartTime) + .def("sort_by_speaker", &PyClass::SortBySpeaker); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-speaker-diarization-result.h b/sherpa-onnx/python/csrc/offline-speaker-diarization-result.h new file mode 100644 index 000000000..2c11e4073 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-speaker-diarization-result.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSpeakerDiarizationResult(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_ diff --git a/sherpa-onnx/python/csrc/offline-speaker-diarization.cc b/sherpa-onnx/python/csrc/offline-speaker-diarization.cc new file mode 100644 index 000000000..c77979b3c --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-speaker-diarization.cc @@ -0,0 +1,92 @@ +// sherpa-onnx/python/csrc/offline-speaker-diarization.cc +// +// Copyright (c) 2024 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h" + +#include +#include + +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" +#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h" + +namespace sherpa_onnx { + +static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) { + using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig; + py::class_(*m, "OfflineSpeakerSegmentationPyannoteModelConfig") + .def(py::init<>()) + .def(py::init(), py::arg("model")) + .def_readwrite("model", &PyClass::model) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) { + PybindOfflineSpeakerSegmentationPyannoteModelConfig(m); + + using PyClass = OfflineSpeakerSegmentationModelConfig; + py::class_(*m, "OfflineSpeakerSegmentationModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("pyannote"), py::arg("num_threads") = 1, + py::arg("debug") = false, py::arg("provider") = "cpu") + .def_readwrite("pyannote", &PyClass::pyannote) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("debug", &PyClass::debug) + .def_readwrite("provider", &PyClass::provider) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +static void PybindOfflineSpeakerDiarizationConfig(py::module *m) { + PybindOfflineSpeakerSegmentationModelConfig(m); + + using PyClass = OfflineSpeakerDiarizationConfig; + py::class_(*m, "OfflineSpeakerDiarizationConfig") + .def(py::init(), + py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"), + py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5) + .def_readwrite("segmentation", &PyClass::segmentation) + .def_readwrite("embedding", &PyClass::embedding) + .def_readwrite("clustering", &PyClass::clustering) + .def_readwrite("min_duration_on", &PyClass::min_duration_on) + .def_readwrite("min_duration_off", &PyClass::min_duration_off) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +void PybindOfflineSpeakerDiarization(py::module *m) { + PybindOfflineSpeakerDiarizationConfig(m); + + using PyClass = OfflineSpeakerDiarization; + py::class_(*m, "OfflineSpeakerDiarization") + .def(py::init(), + py::arg("config")) + .def_property_readonly("sample_rate", &PyClass::SampleRate) + .def( + "process", + [](const PyClass &self, const std::vector samples, + std::function callback) { + if (!callback) { + return self.Process(samples.data(), samples.size()); + } + + std::function callback_wrapper = + [callback](int32_t processed_chunks, int32_t num_chunks, + void *) -> int32_t { + callback(processed_chunks, num_chunks); + return 0; + }; + + return self.Process(samples.data(), samples.size(), + callback_wrapper); + }, + py::arg("samples"), py::arg("callback") = py::none()); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/offline-speaker-diarization.h b/sherpa-onnx/python/csrc/offline-speaker-diarization.h new file mode 100644 index 000000000..523343062 --- /dev/null +++ b/sherpa-onnx/python/csrc/offline-speaker-diarization.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/offline-speaker-diarization.h +// +// Copyright (c) 2024 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ +#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindOfflineSpeakerDiarization(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index f668d626c..c73022f17 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -37,6 +37,8 @@ #if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 #include "sherpa-onnx/python/csrc/fast-clustering.h" +#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h" +#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h" #endif namespace sherpa_onnx { @@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOfflineTts(&m); #endif -#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 - PybindFastClustering(&m); -#endif - PybindSpeakerEmbeddingExtractor(&m); PybindSpeakerEmbeddingManager(&m); PybindSpokenLanguageIdentification(&m); +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 + PybindFastClustering(&m); + PybindOfflineSpeakerDiarizationResult(&m); + PybindOfflineSpeakerDiarization(&m); +#endif + PybindAlsa(&m); } diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index 3568447b3..2d5e456dc 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -11,6 +11,12 @@ OfflinePunctuation, OfflinePunctuationConfig, OfflinePunctuationModelConfig, + OfflineSpeakerDiarization, + OfflineSpeakerDiarizationConfig, + OfflineSpeakerDiarizationResult, + OfflineSpeakerDiarizationSegment, + OfflineSpeakerSegmentationModelConfig, + OfflineSpeakerSegmentationPyannoteModelConfig, OfflineStream, OfflineTts, OfflineTtsConfig,