From d468527f6202b61d72b8ff81532ba7c94f5b0f96 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 9 Oct 2024 17:10:03 +0800 Subject: [PATCH] C API for speaker diarization (#1402) --- .gitignore | 1 + README.md | 11 +- c-api-examples/CMakeLists.txt | 5 + .../offline-speaker-diarization-c-api.c | 131 ++++++++++++++++ sherpa-onnx/c-api/c-api.cc | 145 ++++++++++++++++++ sherpa-onnx/c-api/c-api.h | 116 +++++++++++++- sherpa-onnx/csrc/fast-clustering-config.h | 4 +- .../csrc/offline-speaker-diarization.cc | 10 ++ ...sherpa-onnx-offline-speaker-diarization.cc | 2 +- 9 files changed, 418 insertions(+), 7 deletions(-) create mode 100644 c-api-examples/offline-speaker-diarization-c-api.c diff --git a/.gitignore b/.gitignore index b0fbfae78..7e6708be4 100644 --- a/.gitignore +++ b/.gitignore @@ -120,3 +120,4 @@ vits-melo-tts-zh_en sherpa-onnx-online-punct-en-2024-08-06 *.mp4 *.mp3 +sherpa-onnx-pyannote-segmentation-3-0 diff --git a/README.md b/README.md index 890abe882..2f318d3ef 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,12 @@ ### Supported functions -|Speech recognition| Speech synthesis | Speaker verification | Speaker identification | -|------------------|------------------|----------------------|------------------------| -| ✔️ | ✔️ | ✔️ | ✔️ | +|Speech recognition| Speech synthesis | +|------------------|------------------| +| ✔️ | ✔️ | + +|Speaker identification| Speaker diarization | Speaker identification | +|----------------------|-------------------- |------------------------| +| ✔️ | ✔️ | ✔️ | | Spoken Language identification | Audio tagging | Voice activity detection | |--------------------------------|---------------|--------------------------| @@ -47,6 +51,7 @@ This repository supports running the following functions **locally** - Speech-to-text (i.e., ASR); both streaming and non-streaming are supported - Text-to-speech (i.e., TTS) + - Speaker diarization - Speaker identification - Speaker verification - Spoken language identification diff --git a/c-api-examples/CMakeLists.txt b/c-api-examples/CMakeLists.txt index 0bf526450..45ca9a156 100644 --- a/c-api-examples/CMakeLists.txt +++ b/c-api-examples/CMakeLists.txt @@ -9,6 +9,11 @@ if(SHERPA_ONNX_ENABLE_TTS) target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs) endif() +if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION) + add_executable(offline-speaker-diarization-c-api offline-speaker-diarization-c-api.c) + target_link_libraries(offline-speaker-diarization-c-api sherpa-onnx-c-api) +endif() + add_executable(spoken-language-identification-c-api spoken-language-identification-c-api.c) target_link_libraries(spoken-language-identification-c-api sherpa-onnx-c-api) diff --git a/c-api-examples/offline-speaker-diarization-c-api.c b/c-api-examples/offline-speaker-diarization-c-api.c new file mode 100644 index 000000000..d5a17dd0b --- /dev/null +++ b/c-api-examples/offline-speaker-diarization-c-api.c @@ -0,0 +1,131 @@ +// c-api-examples/offline-sepaker-diarization-c-api.c +// +// Copyright (c) 2024 Xiaomi Corporation + +// +// This file demonstrates how to implement speaker diarization with +// sherpa-onnx's C API. + +// clang-format off +/* +Usage: + +Step 1: Download a speaker segmentation model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2 + +Step 2: Download a speaker embedding extractor model + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models +for a list of available models. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx + +Step 3. Download test wave files + +Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models +for a list of available test wave files. The following is an example + + wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav + +Step 4. Run it + + */ +// clang-format on + +#include +#include + +#include "sherpa-onnx/c-api/c-api.h" + +static int32_t ProgressCallback(int32_t num_processed_chunks, + int32_t num_total_chunks, void *arg) { + float progress = 100.0 * num_processed_chunks / num_total_chunks; + fprintf(stderr, "progress %.2f%%\n", progress); + + // the return value is currently ignored + return 0; +} + +int main() { + // Please see the comments at the start of this file for how to download + // the .onnx file and .wav files below + const char *segmentation_model = + "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"; + + const char *embedding_extractor_model = + "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"; + + const char *wav_filename = "./0-four-speakers-zh.wav"; + + const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename); + if (wave == NULL) { + fprintf(stderr, "Failed to read %s\n", wav_filename); + return -1; + } + + SherpaOnnxOfflineSpeakerDiarizationConfig config; + memset(&config, 0, sizeof(config)); + + config.segmentation.pyannote.model = segmentation_model; + config.embedding.model = embedding_extractor_model; + + // the test wave ./0-four-speakers-zh.wav has 4 speakers, so + // we set num_clusters to 4 + // + config.clustering.num_clusters = 4; + // If you don't know the number of speakers in the test wave file, please + // use + // config.clustering.threshold = 0.5; // You need to tune this threshold + + const SherpaOnnxOfflineSpeakerDiarization *sd = + SherpaOnnxCreateOfflineSpeakerDiarization(&config); + + if (!sd) { + fprintf(stderr, "Failed to initialize offline speaker diarization\n"); + return -1; + } + + if (SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd) != + wave->sample_rate) { + fprintf( + stderr, + "Expected sample rate: %d. Actual sample rate from the wave file: %d\n", + SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd), + wave->sample_rate); + goto failed; + } + + const SherpaOnnxOfflineSpeakerDiarizationResult *result = + SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback( + sd, wave->samples, wave->num_samples, ProgressCallback, NULL); + if (!result) { + fprintf(stderr, "Failed to do speaker diarization"); + goto failed; + } + + int32_t num_segments = + SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(result); + + const SherpaOnnxOfflineSpeakerDiarizationSegment *segments = + SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result); + + for (int32_t i = 0; i != num_segments; ++i) { + fprintf(stderr, "%.3f -- %.3f speaker_%02d\n", segments[i].start, + segments[i].end, segments[i].speaker); + } + +failed: + + SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments); + SherpaOnnxOfflineSpeakerDiarizationDestroyResult(result); + SherpaOnnxDestroyOfflineSpeakerDiarization(sd); + SherpaOnnxFreeWave(wave); + + return 0; +} diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index 176557c75..322c4f79e 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -31,6 +31,10 @@ #include "sherpa-onnx/csrc/offline-tts.h" #endif +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 +#include "sherpa-onnx/csrc/offline-speaker-diarization.h" +#endif + struct SherpaOnnxOnlineRecognizer { std::unique_ptr impl; }; @@ -1670,3 +1674,144 @@ void SherpaOnnxLinearResamplerReset(SherpaOnnxLinearResampler *p) { int32_t SherpaOnnxFileExists(const char *filename) { return sherpa_onnx::FileExists(filename); } + +#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1 + +struct SherpaOnnxOfflineSpeakerDiarization { + std::unique_ptr impl; +}; + +struct SherpaOnnxOfflineSpeakerDiarizationResult { + sherpa_onnx::OfflineSpeakerDiarizationResult impl; +}; + +const SherpaOnnxOfflineSpeakerDiarization * +SherpaOnnxCreateOfflineSpeakerDiarization( + const SherpaOnnxOfflineSpeakerDiarizationConfig *config) { + sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config; + + sd_config.segmentation.pyannote.model = + SHERPA_ONNX_OR(config->segmentation.pyannote.model, ""); + sd_config.segmentation.num_threads = + SHERPA_ONNX_OR(config->segmentation.num_threads, 1); + sd_config.segmentation.debug = config->segmentation.debug; + sd_config.segmentation.provider = + SHERPA_ONNX_OR(config->segmentation.provider, "cpu"); + if (sd_config.segmentation.provider.empty()) { + sd_config.segmentation.provider = "cpu"; + } + + sd_config.embedding.model = SHERPA_ONNX_OR(config->embedding.model, ""); + sd_config.embedding.num_threads = + SHERPA_ONNX_OR(config->embedding.num_threads, 1); + sd_config.embedding.debug = config->embedding.debug; + sd_config.embedding.provider = + SHERPA_ONNX_OR(config->embedding.provider, "cpu"); + if (sd_config.embedding.provider.empty()) { + sd_config.embedding.provider = "cpu"; + } + + sd_config.clustering.num_clusters = + SHERPA_ONNX_OR(config->clustering.num_clusters, -1); + + sd_config.clustering.threshold = + SHERPA_ONNX_OR(config->clustering.threshold, 0.5); + + sd_config.min_duration_on = SHERPA_ONNX_OR(config->min_duration_on, 0.3); + + sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5); + + if (!sd_config.Validate()) { + SHERPA_ONNX_LOGE("Errors in config"); + return nullptr; + } + + SherpaOnnxOfflineSpeakerDiarization *sd = + new SherpaOnnxOfflineSpeakerDiarization; + + sd->impl = + std::make_unique(sd_config); + + if (sd_config.segmentation.debug || sd_config.embedding.debug) { + SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str()); + } + + return sd; +} + +void SherpaOnnxDestroyOfflineSpeakerDiarization( + const SherpaOnnxOfflineSpeakerDiarization *sd) { + delete sd; +} + +int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate( + const SherpaOnnxOfflineSpeakerDiarization *sd) { + return sd->impl->SampleRate(); +} + +int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers( + const SherpaOnnxOfflineSpeakerDiarizationResult *r) { + return r->impl.NumSpeakers(); +} + +int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments( + const SherpaOnnxOfflineSpeakerDiarizationResult *r) { + return r->impl.NumSegments(); +} + +const SherpaOnnxOfflineSpeakerDiarizationSegment * +SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime( + const SherpaOnnxOfflineSpeakerDiarizationResult *r) { + if (r->impl.NumSegments() == 0) { + return nullptr; + } + + auto segments = r->impl.SortByStartTime(); + + int32_t n = segments.size(); + SherpaOnnxOfflineSpeakerDiarizationSegment *ans = + new SherpaOnnxOfflineSpeakerDiarizationSegment[n]; + + for (int32_t i = 0; i != n; ++i) { + const auto &s = segments[i]; + + ans[i].start = s.Start(); + ans[i].end = s.End(); + ans[i].speaker = s.Speaker(); + } + + return ans; +} + +void SherpaOnnxOfflineSpeakerDiarizationDestroySegment( + const SherpaOnnxOfflineSpeakerDiarizationSegment *s) { + delete[] s; +} + +const SherpaOnnxOfflineSpeakerDiarizationResult * +SherpaOnnxOfflineSpeakerDiarizationProcess( + const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples, + int32_t n) { + auto ans = new SherpaOnnxOfflineSpeakerDiarizationResult; + ans->impl = sd->impl->Process(samples, n); + + return ans; +} + +void SherpaOnnxOfflineSpeakerDiarizationDestroyResult( + const SherpaOnnxOfflineSpeakerDiarizationResult *r) { + delete r; +} + +const SherpaOnnxOfflineSpeakerDiarizationResult * +SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback( + const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, SherpaOnnxOfflineSpeakerDiarizationProgressCallback callback, + void *arg) { + auto ans = new SherpaOnnxOfflineSpeakerDiarizationResult; + ans->impl = sd->impl->Process(samples, n, callback, arg); + + return ans; +} + +#endif diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 58615fe48..d378dedec 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -927,7 +927,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTts SherpaOnnxOfflineTts; SHERPA_ONNX_API SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTts( const SherpaOnnxOfflineTtsConfig *config); -// Free the pointer returned by CreateOfflineTts() +// Free the pointer returned by SherpaOnnxCreateOfflineTts() SHERPA_ONNX_API void SherpaOnnxDestroyOfflineTts(SherpaOnnxOfflineTts *tts); // Return the sample rate of the current TTS object @@ -954,6 +954,11 @@ SherpaOnnxOfflineTtsGenerateWithCallback( const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, SherpaOnnxGeneratedAudioCallback callback); +const SherpaOnnxGeneratedAudio * +SherpaOnnxOfflineTtsGenerateWithProgressCallback( + const SherpaOnnxOfflineTts *tts, const char *text, int32_t sid, float speed, + SherpaOnnxGeneratedAudioProgressCallback callback); + // Same as SherpaOnnxGeneratedAudioCallback but you can pass an additional // `void* arg` to the callback. SHERPA_ONNX_API const SherpaOnnxGeneratedAudio * @@ -1384,6 +1389,115 @@ SHERPA_ONNX_API int32_t SherpaOnnxLinearResamplerResampleGetOutputSampleRate( // Return 1 if the file exists; return 0 if the file does not exist. SHERPA_ONNX_API int32_t SherpaOnnxFileExists(const char *filename); +// ========================================================================= +// For offline speaker diarization (i.e., non-streaming speaker diarization) +// ========================================================================= +SHERPA_ONNX_API typedef struct + SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig { + const char *model; +} SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerSegmentationModelConfig { + SherpaOnnxOfflineSpeakerSegmentationPyannoteModelConfig pyannote; + int32_t num_threads; // 1 + int32_t debug; // false + const char *provider; // "cpu" +} SherpaOnnxOfflineSpeakerSegmentationModelConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxFastClusteringConfig { + // If greater than 0, then threshold is ignored. + // + // We strongly recommend that you set it if you know the number of clusters + // in advance + int32_t num_clusters; + + // distance threshold. + // + // The smaller, the more clusters it will generate. + // The larger, the fewer clusters it will generate. + float threshold; +} SherpaOnnxFastClusteringConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationConfig { + SherpaOnnxOfflineSpeakerSegmentationModelConfig segmentation; + SherpaOnnxSpeakerEmbeddingExtractorConfig embedding; + SherpaOnnxFastClusteringConfig clustering; + + // if a segment is less than this value, then it is discarded + float min_duration_on; // in seconds + + // if the gap between to segments of the same speaker is less than this value, + // then these two segments are merged into a single segment. + // We do this recursively. + float min_duration_off; // in seconds +} SherpaOnnxOfflineSpeakerDiarizationConfig; + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarization + SherpaOnnxOfflineSpeakerDiarization; + +// The users has to invoke SherpaOnnxDestroyOfflineSpeakerDiarization() +// to free the returned pointer to avoid memory leak +SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarization * +SherpaOnnxCreateOfflineSpeakerDiarization( + const SherpaOnnxOfflineSpeakerDiarizationConfig *config); + +// Free the pointer returned by SherpaOnnxCreateOfflineSpeakerDiarization() +SHERPA_ONNX_API void SherpaOnnxDestroyOfflineSpeakerDiarization( + const SherpaOnnxOfflineSpeakerDiarization *sd); + +// Expected sample rate of the input audio samples +SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate( + const SherpaOnnxOfflineSpeakerDiarization *sd); + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationResult + SherpaOnnxOfflineSpeakerDiarizationResult; + +SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationSegment { + float start; + float end; + int32_t speaker; +} SherpaOnnxOfflineSpeakerDiarizationSegment; + +SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers( + const SherpaOnnxOfflineSpeakerDiarizationResult *r); + +SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments( + const SherpaOnnxOfflineSpeakerDiarizationResult *r); + +// The user has to invoke SherpaOnnxOfflineSpeakerDiarizationDestroySegment() +// to free the returned pointer to avoid memory leak. +// +// The returned pointer is the start address of an array. +// Number of entries in the array equals to the value +// returned by SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments() +SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarizationSegment * +SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime( + const SherpaOnnxOfflineSpeakerDiarizationResult *r); + +SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationDestroySegment( + const SherpaOnnxOfflineSpeakerDiarizationSegment *s); + +typedef int32_t (*SherpaOnnxOfflineSpeakerDiarizationProgressCallback)( + int32_t num_processed_chunk, int32_t num_total_chunks, void *arg); + +// The user has to invoke SherpaOnnxOfflineSpeakerDiarizationDestroyResult() +// to free the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarizationResult * +SherpaOnnxOfflineSpeakerDiarizationProcess( + const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples, + int32_t n); + +// The user has to invoke SherpaOnnxOfflineSpeakerDiarizationDestroyResult() +// to free the returned pointer to avoid memory leak. +SHERPA_ONNX_API const SherpaOnnxOfflineSpeakerDiarizationResult * +SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback( + const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples, + int32_t n, SherpaOnnxOfflineSpeakerDiarizationProgressCallback callback, + void *arg); + +SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationDestroyResult( + const SherpaOnnxOfflineSpeakerDiarizationResult *r); + #if defined(__GNUC__) #pragma GCC diagnostic pop #endif diff --git a/sherpa-onnx/csrc/fast-clustering-config.h b/sherpa-onnx/csrc/fast-clustering-config.h index 9b190d46b..4abf2b128 100644 --- a/sherpa-onnx/csrc/fast-clustering-config.h +++ b/sherpa-onnx/csrc/fast-clustering-config.h @@ -20,8 +20,8 @@ struct FastClusteringConfig { // distance threshold. // - // The lower, the more clusters it will generate. - // The higher, the fewer clusters it will generate. + // The smaller, the more clusters it will generate. + // The larger, the fewer clusters it will generate. float threshold = 0.5; FastClusteringConfig() = default; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index aeff9b42d..4748b1cb4 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -43,6 +43,16 @@ bool OfflineSpeakerDiarizationConfig::Validate() const { return false; } + if (min_duration_on < 0) { + SHERPA_ONNX_LOGE("min_duration_on %.3f is negative", min_duration_on); + return false; + } + + if (min_duration_off < 0) { + SHERPA_ONNX_LOGE("min_duration_off %.3f is negative", min_duration_off); + return false; + } + return true; } diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc index 170973114..31cda85fc 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc @@ -7,7 +7,7 @@ #include "sherpa-onnx/csrc/wave-reader.h" static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks, - void *arg) { + void *) { float progress = 100.0 * processed_chunks / num_chunks; fprintf(stderr, "progress %.2f%%\n", progress);