From 94b26ff07c1b6275d1830cd2987081a0bdbedacb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sat, 12 Oct 2024 13:03:48 +0800 Subject: [PATCH] Android JNI support for speaker diarization (#1421) --- .../csrc/offline-speaker-diarization-impl.cc | 14 ++++++++++++++ .../csrc/offline-speaker-diarization-impl.h | 10 ++++++++++ ...ffline-speaker-diarization-pyannote-impl.h | 16 ++++++++++++++++ .../csrc/offline-speaker-diarization.cc | 6 ++++++ .../csrc/offline-speaker-diarization.h | 10 ++++++++++ ...ine-speaker-segmentation-pyannote-model.cc | 18 ++++++++++++++++++ ...line-speaker-segmentation-pyannote-model.h | 10 ++++++++++ .../sherpa-onnx-vad-microphone-offline-asr.cc | 2 +- sherpa-onnx/jni/audio-tagging.cc | 1 + sherpa-onnx/jni/keyword-spotter.cc | 2 ++ sherpa-onnx/jni/offline-punctuation.cc | 2 ++ sherpa-onnx/jni/offline-recognizer.cc | 2 ++ .../jni/offline-speaker-diarization.cc | 19 ++++++++++++++++++- sherpa-onnx/jni/offline-tts.cc | 1 + sherpa-onnx/jni/online-recognizer.cc | 1 + .../jni/speaker-embedding-extractor.cc | 1 + .../jni/spoken-language-identification.cc | 1 + sherpa-onnx/jni/voice-activity-detector.cc | 2 ++ 18 files changed, 116 insertions(+), 2 deletions(-) diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc index e41a7767a..15c3a2eb4 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.cc @@ -23,4 +23,18 @@ OfflineSpeakerDiarizationImpl::Create( return nullptr; } +#if __ANDROID_API__ >= 9 +std::unique_ptr +OfflineSpeakerDiarizationImpl::Create( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) { + if (!config.segmentation.pyannote.model.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a speaker segmentation model."); + + return nullptr; +} +#endif + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h index 3aed9d72f..41f0e1e2f 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-impl.h @@ -8,6 +8,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/offline-speaker-diarization.h" namespace sherpa_onnx { @@ -16,6 +21,11 @@ class OfflineSpeakerDiarizationImpl { static std::unique_ptr Create( const OfflineSpeakerDiarizationConfig &config); +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config); +#endif + virtual ~OfflineSpeakerDiarizationImpl() = default; virtual int32_t SampleRate() const = 0; diff --git a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h index 0c70f0bc6..aaedc3be0 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h @@ -10,6 +10,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "Eigen/Dense" #include "sherpa-onnx/csrc/fast-clustering.h" #include "sherpa-onnx/csrc/math.h" @@ -65,6 +70,17 @@ class OfflineSpeakerDiarizationPyannoteImpl Init(); } +#if __ANDROID_API__ >= 9 + OfflineSpeakerDiarizationPyannoteImpl( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) + : config_(config), + segmentation_model_(mgr, config_.segmentation), + embedding_extractor_(mgr, config_.embedding), + clustering_(std::make_unique(config_.clustering)) { + Init(); + } +#endif + int32_t SampleRate() const override { const auto &meta_data = segmentation_model_.GetModelMetaData(); diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.cc b/sherpa-onnx/csrc/offline-speaker-diarization.cc index 00733bfb2..f34ea4e0e 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.cc +++ b/sherpa-onnx/csrc/offline-speaker-diarization.cc @@ -73,6 +73,12 @@ OfflineSpeakerDiarization::OfflineSpeakerDiarization( const OfflineSpeakerDiarizationConfig &config) : impl_(OfflineSpeakerDiarizationImpl::Create(config)) {} +#if __ANDROID_API__ >= 9 +OfflineSpeakerDiarization::OfflineSpeakerDiarization( + AAssetManager *mgr, const OfflineSpeakerDiarizationConfig &config) + : impl_(OfflineSpeakerDiarizationImpl::Create(mgr, config)) {} +#endif + OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default; int32_t OfflineSpeakerDiarization::SampleRate() const { diff --git a/sherpa-onnx/csrc/offline-speaker-diarization.h b/sherpa-onnx/csrc/offline-speaker-diarization.h index 376e5f975..4a517fbb2 100644 --- a/sherpa-onnx/csrc/offline-speaker-diarization.h +++ b/sherpa-onnx/csrc/offline-speaker-diarization.h @@ -9,6 +9,11 @@ #include #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "sherpa-onnx/csrc/fast-clustering-config.h" #include "sherpa-onnx/csrc/offline-speaker-diarization-result.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" @@ -57,6 +62,11 @@ class OfflineSpeakerDiarization { explicit OfflineSpeakerDiarization( const OfflineSpeakerDiarizationConfig &config); +#if __ANDROID_API__ >= 9 + OfflineSpeakerDiarization(AAssetManager *mgr, + const OfflineSpeakerDiarizationConfig &config); +#endif + ~OfflineSpeakerDiarization(); // Expected sample rate of the input audio samples diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc index 3f3323698..e3768dcf4 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc @@ -24,6 +24,17 @@ class OfflineSpeakerSegmentationPyannoteModel::Impl { Init(buf.data(), buf.size()); } +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.pyannote.model); + Init(buf.data(), buf.size()); + } +#endif + const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() const { return meta_data_; @@ -92,6 +103,13 @@ OfflineSpeakerSegmentationPyannoteModel:: const OfflineSpeakerSegmentationModelConfig &config) : impl_(std::make_unique(config)) {} +#if __ANDROID_API__ >= 9 +OfflineSpeakerSegmentationPyannoteModel:: + OfflineSpeakerSegmentationPyannoteModel( + AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + OfflineSpeakerSegmentationPyannoteModel:: ~OfflineSpeakerSegmentationPyannoteModel() = default; diff --git a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h index b504c373f..6b835763b 100644 --- a/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h +++ b/sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h @@ -6,6 +6,11 @@ #include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + #include "onnxruntime_cxx_api.h" // NOLINT #include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h" #include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h" @@ -17,6 +22,11 @@ class OfflineSpeakerSegmentationPyannoteModel { explicit OfflineSpeakerSegmentationPyannoteModel( const OfflineSpeakerSegmentationModelConfig &config); +#if __ANDROID_API__ >= 9 + OfflineSpeakerSegmentationPyannoteModel( + AAssetManager *mgr, const OfflineSpeakerSegmentationModelConfig &config); +#endif + ~OfflineSpeakerSegmentationPyannoteModel(); const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData() diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone-offline-asr.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone-offline-asr.cc index c90c29c52..df3e250a5 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone-offline-asr.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone-offline-asr.cc @@ -211,7 +211,7 @@ to download models for offline ASR. } while (!vad->Empty()) { - auto &segment = vad->Front(); + const auto &segment = vad->Front(); auto s = recognizer.CreateStream(); s->AcceptWaveform(sample_rate, segment.samples.data(), segment.samples.size()); diff --git a/sherpa-onnx/jni/audio-tagging.cc b/sherpa-onnx/jni/audio-tagging.cc index ff8db0089..7ad6e7d53 100644 --- a/sherpa-onnx/jni/audio-tagging.cc +++ b/sherpa-onnx/jni/audio-tagging.cc @@ -70,6 +70,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_AudioTagging_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif diff --git a/sherpa-onnx/jni/keyword-spotter.cc b/sherpa-onnx/jni/keyword-spotter.cc index ca0c229c2..4ac80a294 100644 --- a/sherpa-onnx/jni/keyword-spotter.cc +++ b/sherpa-onnx/jni/keyword-spotter.cc @@ -115,10 +115,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_KeywordSpotter_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetKwsConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + auto kws = new sherpa_onnx::KeywordSpotter( #if __ANDROID_API__ >= 9 mgr, diff --git a/sherpa-onnx/jni/offline-punctuation.cc b/sherpa-onnx/jni/offline-punctuation.cc index 5056a3ac4..efe03cac0 100644 --- a/sherpa-onnx/jni/offline-punctuation.cc +++ b/sherpa-onnx/jni/offline-punctuation.cc @@ -53,10 +53,12 @@ Java_com_k2fsa_sherpa_onnx_OfflinePunctuation_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetOfflinePunctuationConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + auto model = new sherpa_onnx::OfflinePunctuation( #if __ANDROID_API__ >= 9 mgr, diff --git a/sherpa-onnx/jni/offline-recognizer.cc b/sherpa-onnx/jni/offline-recognizer.cc index 8c1265bba..5e4b359b6 100644 --- a/sherpa-onnx/jni/offline-recognizer.cc +++ b/sherpa-onnx/jni/offline-recognizer.cc @@ -233,10 +233,12 @@ Java_com_k2fsa_sherpa_onnx_OfflineRecognizer_newFromAsset(JNIEnv *env, AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetOfflineConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + auto model = new sherpa_onnx::OfflineRecognizer( #if __ANDROID_API__ >= 9 mgr, diff --git a/sherpa-onnx/jni/offline-speaker-diarization.cc b/sherpa-onnx/jni/offline-speaker-diarization.cc index e82962c80..ba4e14bc3 100644 --- a/sherpa-onnx/jni/offline-speaker-diarization.cc +++ b/sherpa-onnx/jni/offline-speaker-diarization.cc @@ -101,7 +101,24 @@ SHERPA_ONNX_EXTERN_C JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineSpeakerDiarization_newFromAsset( JNIEnv *env, jobject /*obj*/, jobject asset_manager, jobject _config) { - return 0; +#if __ANDROID_API__ >= 9 + AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); + if (!mgr) { + SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; + } +#endif + + auto config = sherpa_onnx::GetOfflineSpeakerDiarizationConfig(env, _config); + SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + + auto sd = new sherpa_onnx::OfflineSpeakerDiarization( +#if __ANDROID_API__ >= 9 + mgr, +#endif + config); + + return (jlong)sd; } SHERPA_ONNX_EXTERN_C diff --git a/sherpa-onnx/jni/offline-tts.cc b/sherpa-onnx/jni/offline-tts.cc index 43a93e0e0..4d67afc27 100644 --- a/sherpa-onnx/jni/offline-tts.cc +++ b/sherpa-onnx/jni/offline-tts.cc @@ -105,6 +105,7 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_OfflineTts_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetOfflineTtsConfig(env, _config); diff --git a/sherpa-onnx/jni/online-recognizer.cc b/sherpa-onnx/jni/online-recognizer.cc index 1793cf73b..dbe205c4e 100644 --- a/sherpa-onnx/jni/online-recognizer.cc +++ b/sherpa-onnx/jni/online-recognizer.cc @@ -267,6 +267,7 @@ Java_com_k2fsa_sherpa_onnx_OnlineRecognizer_newFromAsset(JNIEnv *env, AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetConfig(env, _config); diff --git a/sherpa-onnx/jni/speaker-embedding-extractor.cc b/sherpa-onnx/jni/speaker-embedding-extractor.cc index b1190bffc..33d630ee6 100644 --- a/sherpa-onnx/jni/speaker-embedding-extractor.cc +++ b/sherpa-onnx/jni/speaker-embedding-extractor.cc @@ -45,6 +45,7 @@ Java_com_k2fsa_sherpa_onnx_SpeakerEmbeddingExtractor_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetSpeakerEmbeddingExtractorConfig(env, _config); diff --git a/sherpa-onnx/jni/spoken-language-identification.cc b/sherpa-onnx/jni/spoken-language-identification.cc index 278c6adbf..fcb6f228a 100644 --- a/sherpa-onnx/jni/spoken-language-identification.cc +++ b/sherpa-onnx/jni/spoken-language-identification.cc @@ -62,6 +62,7 @@ Java_com_k2fsa_sherpa_onnx_SpokenLanguageIdentification_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif diff --git a/sherpa-onnx/jni/voice-activity-detector.cc b/sherpa-onnx/jni/voice-activity-detector.cc index 319edd09b..a30423f70 100644 --- a/sherpa-onnx/jni/voice-activity-detector.cc +++ b/sherpa-onnx/jni/voice-activity-detector.cc @@ -71,10 +71,12 @@ JNIEXPORT jlong JNICALL Java_com_k2fsa_sherpa_onnx_Vad_newFromAsset( AAssetManager *mgr = AAssetManager_fromJava(env, asset_manager); if (!mgr) { SHERPA_ONNX_LOGE("Failed to get asset manager: %p", mgr); + return 0; } #endif auto config = sherpa_onnx::GetVadModelConfig(env, _config); SHERPA_ONNX_LOGE("config:\n%s", config.ToString().c_str()); + auto model = new sherpa_onnx::VoiceActivityDetector( #if __ANDROID_API__ >= 9 mgr,