Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

C API for speaker diarization #1402

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,4 @@ vits-melo-tts-zh_en
sherpa-onnx-online-punct-en-2024-08-06
*.mp4
*.mp3
sherpa-onnx-pyannote-segmentation-3-0
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
### Supported functions

|Speech recognition| Speech synthesis | Speaker verification | Speaker identification |
|------------------|------------------|----------------------|------------------------|
| ✔️ | ✔️ | ✔️ | ✔️ |
|Speech recognition| Speech synthesis |
|------------------|------------------|
| ✔️ | ✔️ |

|Speaker identification| Speaker diarization | Speaker identification |
|----------------------|-------------------- |------------------------|
| ✔️ | ✔️ | ✔️ |

| Spoken Language identification | Audio tagging | Voice activity detection |
|--------------------------------|---------------|--------------------------|
Expand Down Expand Up @@ -47,6 +51,7 @@ This repository supports running the following functions **locally**

- Speech-to-text (i.e., ASR); both streaming and non-streaming are supported
- Text-to-speech (i.e., TTS)
- Speaker diarization
- Speaker identification
- Speaker verification
- Spoken language identification
Expand Down
5 changes: 5 additions & 0 deletions c-api-examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ if(SHERPA_ONNX_ENABLE_TTS)
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
endif()

if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
add_executable(offline-speaker-diarization-c-api offline-speaker-diarization-c-api.c)
target_link_libraries(offline-speaker-diarization-c-api sherpa-onnx-c-api)
endif()

add_executable(spoken-language-identification-c-api spoken-language-identification-c-api.c)
target_link_libraries(spoken-language-identification-c-api sherpa-onnx-c-api)

Expand Down
131 changes: 131 additions & 0 deletions c-api-examples/offline-speaker-diarization-c-api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// c-api-examples/offline-sepaker-diarization-c-api.c
//
// Copyright (c) 2024 Xiaomi Corporation

//
// This file demonstrates how to implement speaker diarization with
// sherpa-onnx's C API.

// clang-format off
/*
Usage:

Step 1: Download a speaker segmentation model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2

Step 2: Download a speaker embedding extractor model

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx

Step 3. Download test wave files

Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example

wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav

Step 4. Run it

*/
// clang-format on

#include <stdio.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

static int32_t ProgressCallback(int32_t num_processed_chunks,
int32_t num_total_chunks, void *arg) {
float progress = 100.0 * num_processed_chunks / num_total_chunks;
fprintf(stderr, "progress %.2f%%\n", progress);

// the return value is currently ignored
return 0;
}

int main() {
// Please see the comments at the start of this file for how to download
// the .onnx file and .wav files below
const char *segmentation_model =
"./sherpa-onnx-pyannote-segmentation-3-0/model.onnx";

const char *embedding_extractor_model =
"./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";

const char *wav_filename = "./0-four-speakers-zh.wav";

const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
if (wave == NULL) {
fprintf(stderr, "Failed to read %s\n", wav_filename);
return -1;
}

SherpaOnnxOfflineSpeakerDiarizationConfig config;
memset(&config, 0, sizeof(config));

config.segmentation.pyannote.model = segmentation_model;
config.embedding.model = embedding_extractor_model;

// the test wave ./0-four-speakers-zh.wav has 4 speakers, so
// we set num_clusters to 4
//
config.clustering.num_clusters = 4;
// If you don't know the number of speakers in the test wave file, please
// use
// config.clustering.threshold = 0.5; // You need to tune this threshold

const SherpaOnnxOfflineSpeakerDiarization *sd =
SherpaOnnxCreateOfflineSpeakerDiarization(&config);

if (!sd) {
fprintf(stderr, "Failed to initialize offline speaker diarization\n");
return -1;
}

if (SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd) !=
wave->sample_rate) {
fprintf(
stderr,
"Expected sample rate: %d. Actual sample rate from the wave file: %d\n",
SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd),
wave->sample_rate);
goto failed;
}

const SherpaOnnxOfflineSpeakerDiarizationResult *result =
SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
sd, wave->samples, wave->num_samples, ProgressCallback, NULL);
if (!result) {
fprintf(stderr, "Failed to do speaker diarization");
goto failed;
}

int32_t num_segments =
SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(result);

const SherpaOnnxOfflineSpeakerDiarizationSegment *segments =
SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result);

for (int32_t i = 0; i != num_segments; ++i) {
fprintf(stderr, "%.3f -- %.3f speaker_%02d\n", segments[i].start,
segments[i].end, segments[i].speaker);
}

failed:

SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments);
SherpaOnnxOfflineSpeakerDiarizationDestroyResult(result);
SherpaOnnxDestroyOfflineSpeakerDiarization(sd);
SherpaOnnxFreeWave(wave);

return 0;
}
145 changes: 145 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#endif

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#endif

struct SherpaOnnxOnlineRecognizer {
std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
};
Expand Down Expand Up @@ -1670,3 +1674,144 @@ void SherpaOnnxLinearResamplerReset(SherpaOnnxLinearResampler *p) {
int32_t SherpaOnnxFileExists(const char *filename) {
return sherpa_onnx::FileExists(filename);
}

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1

struct SherpaOnnxOfflineSpeakerDiarization {
std::unique_ptr<sherpa_onnx::OfflineSpeakerDiarization> impl;
};

struct SherpaOnnxOfflineSpeakerDiarizationResult {
sherpa_onnx::OfflineSpeakerDiarizationResult impl;
};

const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarization(
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;

sd_config.segmentation.pyannote.model =
SHERPA_ONNX_OR(config->segmentation.pyannote.model, "");
sd_config.segmentation.num_threads =
SHERPA_ONNX_OR(config->segmentation.num_threads, 1);
sd_config.segmentation.debug = config->segmentation.debug;
sd_config.segmentation.provider =
SHERPA_ONNX_OR(config->segmentation.provider, "cpu");
if (sd_config.segmentation.provider.empty()) {
sd_config.segmentation.provider = "cpu";
}

sd_config.embedding.model = SHERPA_ONNX_OR(config->embedding.model, "");
sd_config.embedding.num_threads =
SHERPA_ONNX_OR(config->embedding.num_threads, 1);
sd_config.embedding.debug = config->embedding.debug;
sd_config.embedding.provider =
SHERPA_ONNX_OR(config->embedding.provider, "cpu");
if (sd_config.embedding.provider.empty()) {
sd_config.embedding.provider = "cpu";
}

sd_config.clustering.num_clusters =
SHERPA_ONNX_OR(config->clustering.num_clusters, -1);

sd_config.clustering.threshold =
SHERPA_ONNX_OR(config->clustering.threshold, 0.5);

sd_config.min_duration_on = SHERPA_ONNX_OR(config->min_duration_on, 0.3);

sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5);

if (!sd_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in config");
return nullptr;
}

SherpaOnnxOfflineSpeakerDiarization *sd =
new SherpaOnnxOfflineSpeakerDiarization;

sd->impl =
std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(sd_config);

if (sd_config.segmentation.debug || sd_config.embedding.debug) {
SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str());
}

return sd;
}

void SherpaOnnxDestroyOfflineSpeakerDiarization(
const SherpaOnnxOfflineSpeakerDiarization *sd) {
delete sd;
}

int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
const SherpaOnnxOfflineSpeakerDiarization *sd) {
return sd->impl->SampleRate();
}

int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
return r->impl.NumSpeakers();
}

int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
return r->impl.NumSegments();
}

const SherpaOnnxOfflineSpeakerDiarizationSegment *
SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
if (r->impl.NumSegments() == 0) {
return nullptr;
}

auto segments = r->impl.SortByStartTime();

int32_t n = segments.size();
SherpaOnnxOfflineSpeakerDiarizationSegment *ans =
new SherpaOnnxOfflineSpeakerDiarizationSegment[n];

for (int32_t i = 0; i != n; ++i) {
const auto &s = segments[i];

ans[i].start = s.Start();
ans[i].end = s.End();
ans[i].speaker = s.Speaker();
}

return ans;
}

void SherpaOnnxOfflineSpeakerDiarizationDestroySegment(
const SherpaOnnxOfflineSpeakerDiarizationSegment *s) {
delete[] s;
}

const SherpaOnnxOfflineSpeakerDiarizationResult *
SherpaOnnxOfflineSpeakerDiarizationProcess(
const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples,
int32_t n) {
auto ans = new SherpaOnnxOfflineSpeakerDiarizationResult;
ans->impl = sd->impl->Process(samples, n);

return ans;
}

void SherpaOnnxOfflineSpeakerDiarizationDestroyResult(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
delete r;
}

const SherpaOnnxOfflineSpeakerDiarizationResult *
SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples,
int32_t n, SherpaOnnxOfflineSpeakerDiarizationProgressCallback callback,
void *arg) {
auto ans = new SherpaOnnxOfflineSpeakerDiarizationResult;
ans->impl = sd->impl->Process(samples, n, callback, arg);

return ans;
}

#endif
Loading
Loading