Skip to content

Commit

Permalink
C API for speaker diarization (#1402)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Oct 9, 2024
1 parent 8535b1d commit d468527
Show file tree
Hide file tree
Showing 9 changed files with 418 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,4 @@ vits-melo-tts-zh_en
sherpa-onnx-online-punct-en-2024-08-06
*.mp4
*.mp3
sherpa-onnx-pyannote-segmentation-3-0
11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
### Supported functions

|Speech recognition| Speech synthesis | Speaker verification | Speaker identification |
|------------------|------------------|----------------------|------------------------|
| ✔️ | ✔️ | ✔️ | ✔️ |
|Speech recognition| Speech synthesis |
|------------------|------------------|
| ✔️ | ✔️ |

|Speaker identification| Speaker diarization | Speaker verification |
|----------------------|-------------------- |------------------------|
| ✔️ | ✔️ | ✔️ |

| Spoken Language identification | Audio tagging | Voice activity detection |
|--------------------------------|---------------|--------------------------|
Expand Down Expand Up @@ -47,6 +51,7 @@ This repository supports running the following functions **locally**

- Speech-to-text (i.e., ASR); both streaming and non-streaming are supported
- Text-to-speech (i.e., TTS)
- Speaker diarization
- Speaker identification
- Speaker verification
- Spoken language identification
Expand Down
5 changes: 5 additions & 0 deletions c-api-examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ if(SHERPA_ONNX_ENABLE_TTS)
target_link_libraries(offline-tts-c-api sherpa-onnx-c-api cargs)
endif()

# Build the speaker diarization C API example only when the
# SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION option is enabled. The example
# links only against the sherpa-onnx C API library.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
add_executable(offline-speaker-diarization-c-api offline-speaker-diarization-c-api.c)
target_link_libraries(offline-speaker-diarization-c-api sherpa-onnx-c-api)
endif()

# Spoken language identification example; built unconditionally and linked
# against the sherpa-onnx C API library.
add_executable(spoken-language-identification-c-api spoken-language-identification-c-api.c)
target_link_libraries(spoken-language-identification-c-api sherpa-onnx-c-api)

Expand Down
131 changes: 131 additions & 0 deletions c-api-examples/offline-speaker-diarization-c-api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
// c-api-examples/offline-speaker-diarization-c-api.c
//
// Copyright (c) 2024 Xiaomi Corporation

//
// This file demonstrates how to implement speaker diarization with
// sherpa-onnx's C API.

// clang-format off
/*
Usage:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Run it
*/
// clang-format on

#include <stdio.h>
#include <string.h>

#include "sherpa-onnx/c-api/c-api.h"

// Progress reporter handed to
// SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback().
//
// Prints the percentage of processed chunks to stderr.
//
// @param num_processed_chunks Number of chunks processed so far.
// @param num_total_chunks     Total number of chunks to process.
// @param arg                  Opaque user pointer; not used here.
// @return Always 0. The return value is currently ignored by the caller.
static int32_t ProgressCallback(int32_t num_processed_chunks,
                                int32_t num_total_chunks, void *arg) {
  (void)arg;  // this example carries no user data

  // Computed in double precision, then narrowed to float for printing.
  const double percent = 100.0 * num_processed_chunks / num_total_chunks;
  const float shown = (float)percent;

  fprintf(stderr, "progress %.2f%%\n", shown);

  return 0;
}

// Runs speaker diarization on a test wave file and prints one line per
// segment ("start -- end speaker_XX") to stderr.
//
// Returns 0 on success and -1 if any step fails.
//
// Every handle that is released in the cleanup section is declared and
// initialized to NULL before the first `goto cleanup`, so the cleanup code
// never sees an indeterminate pointer, no matter where the failure occurs.
int main() {
  // Please see the comments at the start of this file for how to download
  // the .onnx file and .wav files below
  const char *segmentation_model =
      "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx";

  const char *embedding_extractor_model =
      "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx";

  const char *wav_filename = "./0-four-speakers-zh.wav";

  int exit_code = -1;
  int32_t num_segments = 0;
  const SherpaOnnxOfflineSpeakerDiarization *sd = NULL;
  const SherpaOnnxOfflineSpeakerDiarizationResult *result = NULL;
  const SherpaOnnxOfflineSpeakerDiarizationSegment *segments = NULL;

  const SherpaOnnxWave *wave = SherpaOnnxReadWave(wav_filename);
  if (wave == NULL) {
    fprintf(stderr, "Failed to read %s\n", wav_filename);
    return -1;
  }

  SherpaOnnxOfflineSpeakerDiarizationConfig config;
  memset(&config, 0, sizeof(config));

  config.segmentation.pyannote.model = segmentation_model;
  config.embedding.model = embedding_extractor_model;

  // the test wave ./0-four-speakers-zh.wav has 4 speakers, so
  // we set num_clusters to 4
  //
  config.clustering.num_clusters = 4;
  // If you don't know the number of speakers in the test wave file, please
  // use
  // config.clustering.threshold = 0.5; // You need to tune this threshold

  sd = SherpaOnnxCreateOfflineSpeakerDiarization(&config);
  if (!sd) {
    fprintf(stderr, "Failed to initialize offline speaker diarization\n");
    // Go through cleanup so that the wave read above is released.
    goto cleanup;
  }

  if (SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd) !=
      wave->sample_rate) {
    fprintf(
        stderr,
        "Expected sample rate: %d. Actual sample rate from the wave file: %d\n",
        SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(sd),
        wave->sample_rate);
    goto cleanup;
  }

  result = SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
      sd, wave->samples, wave->num_samples, ProgressCallback, NULL);
  if (!result) {
    fprintf(stderr, "Failed to do speaker diarization\n");
    goto cleanup;
  }

  num_segments =
      SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(result);

  // May be NULL when there are no segments; the loop below is then skipped
  // because num_segments is 0.
  segments = SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(result);

  for (int32_t i = 0; i != num_segments; ++i) {
    fprintf(stderr, "%.3f -- %.3f speaker_%02d\n", segments[i].start,
            segments[i].end, segments[i].speaker);
  }

  exit_code = 0;

cleanup:

  // The three destroy functions below use delete/delete[] internally, so
  // passing NULL for a handle that was never created is a no-op.
  SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments);
  SherpaOnnxOfflineSpeakerDiarizationDestroyResult(result);
  SherpaOnnxDestroyOfflineSpeakerDiarization(sd);
  SherpaOnnxFreeWave(wave);

  return exit_code;
}
145 changes: 145 additions & 0 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
#include "sherpa-onnx/csrc/offline-tts.h"
#endif

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#endif

struct SherpaOnnxOnlineRecognizer {
std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
};
Expand Down Expand Up @@ -1670,3 +1674,144 @@ void SherpaOnnxLinearResamplerReset(SherpaOnnxLinearResampler *p) {
int32_t SherpaOnnxFileExists(const char *filename) {
return sherpa_onnx::FileExists(filename);
}

#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1

// Handle type returned by SherpaOnnxCreateOfflineSpeakerDiarization().
// It wraps the C++ implementation object; the wrapped object is released
// together with the handle when the handle is deleted.
struct SherpaOnnxOfflineSpeakerDiarization {
std::unique_ptr<sherpa_onnx::OfflineSpeakerDiarization> impl;
};

// Handle type returned by the SherpaOnnxOfflineSpeakerDiarizationProcess*()
// functions. It stores the C++ diarization result by value.
struct SherpaOnnxOfflineSpeakerDiarizationResult {
sherpa_onnx::OfflineSpeakerDiarizationResult impl;
};

// Creates an offline speaker diarization object from a C config.
//
// Each field of the C config is copied into the C++ config through
// SHERPA_ONNX_OR (defined elsewhere in this file; presumably it selects the
// second argument when the first one is zero/NULL -- confirm against the
// macro definition). An empty provider string is replaced with "cpu".
//
// @param config  Configuration to use. Passing NULL is reported as an error.
// @return A handle that must be released with
//         SherpaOnnxDestroyOfflineSpeakerDiarization(), or nullptr if
//         config is NULL or fails validation.
const SherpaOnnxOfflineSpeakerDiarization *
SherpaOnnxCreateOfflineSpeakerDiarization(
    const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
  if (!config) {
    SHERPA_ONNX_LOGE("config is a null pointer");
    return nullptr;
  }

  sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;

  // Segmentation model
  sd_config.segmentation.pyannote.model =
      SHERPA_ONNX_OR(config->segmentation.pyannote.model, "");
  sd_config.segmentation.num_threads =
      SHERPA_ONNX_OR(config->segmentation.num_threads, 1);
  sd_config.segmentation.debug = config->segmentation.debug;
  sd_config.segmentation.provider =
      SHERPA_ONNX_OR(config->segmentation.provider, "cpu");
  if (sd_config.segmentation.provider.empty()) {
    sd_config.segmentation.provider = "cpu";
  }

  // Speaker embedding extractor
  sd_config.embedding.model = SHERPA_ONNX_OR(config->embedding.model, "");
  sd_config.embedding.num_threads =
      SHERPA_ONNX_OR(config->embedding.num_threads, 1);
  sd_config.embedding.debug = config->embedding.debug;
  sd_config.embedding.provider =
      SHERPA_ONNX_OR(config->embedding.provider, "cpu");
  if (sd_config.embedding.provider.empty()) {
    sd_config.embedding.provider = "cpu";
  }

  // Clustering
  sd_config.clustering.num_clusters =
      SHERPA_ONNX_OR(config->clustering.num_clusters, -1);

  sd_config.clustering.threshold =
      SHERPA_ONNX_OR(config->clustering.threshold, 0.5);

  sd_config.min_duration_on = SHERPA_ONNX_OR(config->min_duration_on, 0.3);

  sd_config.min_duration_off = SHERPA_ONNX_OR(config->min_duration_off, 0.5);

  if (!sd_config.Validate()) {
    SHERPA_ONNX_LOGE("Errors in config");
    return nullptr;
  }

  // Hold the handle in a unique_ptr until everything has succeeded so that
  // it is not leaked if constructing the implementation throws.
  auto sd = std::make_unique<SherpaOnnxOfflineSpeakerDiarization>();

  sd->impl =
      std::make_unique<sherpa_onnx::OfflineSpeakerDiarization>(sd_config);

  if (sd_config.segmentation.debug || sd_config.embedding.debug) {
    SHERPA_ONNX_LOGE("%s\n", sd_config.ToString().c_str());
  }

  // Ownership passes to the caller.
  return sd.release();
}

// Releases a handle obtained from SherpaOnnxCreateOfflineSpeakerDiarization().
//
// @param sd  Handle to free. A null pointer is accepted (delete on a null
//            pointer does nothing).
void SherpaOnnxDestroyOfflineSpeakerDiarization(
const SherpaOnnxOfflineSpeakerDiarization *sd) {
delete sd;
}

// Returns the sample rate reported by the underlying diarization object.
//
// @param sd  Handle created by SherpaOnnxCreateOfflineSpeakerDiarization().
//            Must not be null; it is dereferenced without a check.
int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
    const SherpaOnnxOfflineSpeakerDiarization *sd) {
  auto *diarizer = sd->impl.get();
  return diarizer->SampleRate();
}

// Returns the number of speakers stored in a diarization result.
//
// @param r  Result handle. Must not be null; it is dereferenced without a
//           check.
int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers(
    const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
  const auto &result = r->impl;
  return result.NumSpeakers();
}

// Returns the number of segments stored in a diarization result.
//
// @param r  Result handle. Must not be null; it is dereferenced without a
//           check.
int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(
    const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
  const auto &result = r->impl;
  return result.NumSegments();
}

// Copies the segments of a diarization result, in the order produced by the
// result's SortByStartTime(), into a newly allocated C array.
//
// @param r  Result handle. Must not be null; it is dereferenced without a
//           check.
// @return nullptr if the result has no segments. Otherwise an array
//         allocated with new[]; release it with
//         SherpaOnnxOfflineSpeakerDiarizationDestroySegment(). The number of
//         entries is obtained separately via
//         SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments().
const SherpaOnnxOfflineSpeakerDiarizationSegment *
SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(
    const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
  const auto &result = r->impl;

  // Nothing to copy: report "no segments" as a null array.
  if (result.NumSegments() == 0) {
    return nullptr;
  }

  auto sorted = result.SortByStartTime();
  const int32_t count = sorted.size();

  auto *out = new SherpaOnnxOfflineSpeakerDiarizationSegment[count];

  // Flatten each C++ segment into its plain C counterpart.
  for (int32_t k = 0; k != count; ++k) {
    const auto &src = sorted[k];
    auto &dst = out[k];

    dst.start = src.Start();
    dst.end = src.End();
    dst.speaker = src.Speaker();
  }

  return out;
}

// Releases a segment array returned by
// SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime().
//
// @param s  Array to free. A null pointer is accepted (delete[] on a null
//           pointer does nothing), which covers the "no segments" case.
void SherpaOnnxOfflineSpeakerDiarizationDestroySegment(
const SherpaOnnxOfflineSpeakerDiarizationSegment *s) {
delete[] s;
}

// Runs speaker diarization over a buffer of audio samples.
//
// @param sd       Handle created by
//                 SherpaOnnxCreateOfflineSpeakerDiarization(). Must not be
//                 null; it is dereferenced without a check.
// @param samples  Pointer to the audio samples.
// @param n        Number of samples in `samples`.
// @return A heap-allocated result; release it with
//         SherpaOnnxOfflineSpeakerDiarizationDestroyResult().
const SherpaOnnxOfflineSpeakerDiarizationResult *
SherpaOnnxOfflineSpeakerDiarizationProcess(
    const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples,
    int32_t n) {
  SherpaOnnxOfflineSpeakerDiarizationResult *result =
      new SherpaOnnxOfflineSpeakerDiarizationResult;

  // Store the C++ result by value inside the C handle.
  result->impl = sd->impl->Process(samples, n);

  return result;
}

// Releases a result returned by one of the
// SherpaOnnxOfflineSpeakerDiarizationProcess*() functions.
//
// @param r  Result to free. A null pointer is accepted (delete on a null
//           pointer does nothing).
void SherpaOnnxOfflineSpeakerDiarizationDestroyResult(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
delete r;
}

// Runs speaker diarization over a buffer of audio samples, forwarding a
// progress callback and its user argument to the underlying implementation.
//
// @param sd        Handle created by
//                  SherpaOnnxCreateOfflineSpeakerDiarization(). Must not be
//                  null; it is dereferenced without a check.
// @param samples   Pointer to the audio samples.
// @param n         Number of samples in `samples`.
// @param callback  Progress callback passed through to Process().
// @param arg       Opaque pointer passed through to Process() alongside the
//                  callback.
// @return A heap-allocated result; release it with
//         SherpaOnnxOfflineSpeakerDiarizationDestroyResult().
const SherpaOnnxOfflineSpeakerDiarizationResult *
SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
    const SherpaOnnxOfflineSpeakerDiarization *sd, const float *samples,
    int32_t n, SherpaOnnxOfflineSpeakerDiarizationProgressCallback callback,
    void *arg) {
  SherpaOnnxOfflineSpeakerDiarizationResult *result =
      new SherpaOnnxOfflineSpeakerDiarizationResult;

  // Store the C++ result by value inside the C handle.
  result->impl = sd->impl->Process(samples, n, callback, arg);

  return result;
}

#endif
Loading

0 comments on commit d468527

Please sign in to comment.