Skip to content

Commit

Permalink
Add speaker diarization demo for HarmonyOS
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Dec 10, 2024
1 parent baac5df commit 351c194
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,170 @@ static Napi::Array OfflineSpeakerDiarizationProcessWrapper(
return ans;
}

// Progress payload handed from the diarization worker to the JavaScript
// progress callback. One instance is heap-allocated per progress report and
// deleted by InvokeJsCallback after it has been delivered.
struct SpeakerDiarizationCallbackData {
  // Number of chunks processed so far.
  int32_t num_processed_chunks = 0;
  // Total number of chunks to process.
  int32_t num_total_chunks = 0;
};

// see
// https://github.com/nodejs/node-addon-examples/blob/main/src/6-threadsafe-function/typed_threadsafe_function/node-addon-api/clock.cc
// Delivers one progress report to the JavaScript callback.
//
// @param env      JavaScript environment; the call is skipped when it is null.
// @param callback JavaScript function to invoke; skipped when it is null.
// @param context  Reference whose value is used as the receiver (`this`) of
//                 the call.
// @param data     Progress payload. Ownership is taken here: it is always
//                 deleted, whether or not the callback was invoked.
static void InvokeJsCallback(Napi::Env env, Napi::Function callback,
                             Napi::Reference<Napi::Value> *context,
                             SpeakerDiarizationCallbackData *data) {
  const bool can_invoke = (env != nullptr) && (callback != nullptr);

  if (can_invoke) {
    // The callback receives (numProcessedChunks, numTotalChunks).
    Napi::Number processed = Napi::Number::New(env, data->num_processed_chunks);
    Napi::Number total = Napi::Number::New(env, data->num_total_chunks);

    callback.Call(context->Value(), {processed, total});
  }

  delete data;
}

// Typed thread-safe function used to report diarization progress to
// JavaScript. Its context type is Napi::Reference<Napi::Value>, each queued
// item is a SpeakerDiarizationCallbackData, and InvokeJsCallback is the
// function that hands a queued item to the JavaScript callback.
using TSFN = Napi::TypedThreadSafeFunction<Napi::Reference<Napi::Value>,
SpeakerDiarizationCallbackData,
InvokeJsCallback>;

class SpeakerDiarizationProcessWorker : public Napi::AsyncWorker {
public:
SpeakerDiarizationProcessWorker(const Napi::Env &env, TSFN tsfn,
const SherpaOnnxOfflineSpeakerDiarization *sd,
std::vector<float> samples)
: tsfn_(tsfn),
Napi::AsyncWorker{env, "SpeakerDiarizationProcessAsyncWorker"},
deferred_(env),
sd_(sd),
samples_(std::move(samples)) {}

Napi::Promise Promise() { return deferred_.Promise(); }

protected:
void Execute() override {
auto callback = [](int32_t num_processed_chunks, int32_t num_total_chunks,
void *arg) -> int32_t {
auto _this = reinterpret_cast<SpeakerDiarizationProcessWorker *>(arg);

auto data = new SpeakerDiarizationCallbackData;
data->num_processed_chunks = num_processed_chunks;
data->num_total_chunks = num_total_chunks;

_this->tsfn_.NonBlockingCall(data);

return 0;
};

r_ = SherpaOnnxOfflineSpeakerDiarizationProcessWithCallback(
sd_, samples_.data(), samples_.size(), callback, this);

tsfn_.Release();
}

void OnOK() override {
Napi::Env env = deferred_.Env();

int32_t num_segments =
SherpaOnnxOfflineSpeakerDiarizationResultGetNumSegments(r_);

const SherpaOnnxOfflineSpeakerDiarizationSegment *segments =
SherpaOnnxOfflineSpeakerDiarizationResultSortByStartTime(r_);

Napi::Array ans = Napi::Array::New(env, num_segments);

for (int32_t i = 0; i != num_segments; ++i) {
Napi::Object obj = Napi::Object::New(env);

obj.Set(Napi::String::New(env, "start"), segments[i].start);
obj.Set(Napi::String::New(env, "end"), segments[i].end);
obj.Set(Napi::String::New(env, "speaker"), segments[i].speaker);

ans.Set(i, obj);
}

SherpaOnnxOfflineSpeakerDiarizationDestroySegment(segments);
SherpaOnnxOfflineSpeakerDiarizationDestroyResult(r_);

deferred_.Resolve(ans);
}

private:
TSFN tsfn_;
Napi::Promise::Deferred deferred_;
const SherpaOnnxOfflineSpeakerDiarization *sd_;
std::vector<float> samples_;
const SherpaOnnxOfflineSpeakerDiarizationResult *r_;
};

// JavaScript entry point for asynchronous speaker diarization.
//
// Expected JavaScript arguments:
//   0: external pointer to an offline speaker diarization object
//   1: typed array with the audio samples (read as Float32Array)
//   2: progress callback, invoked with (numProcessedChunks, numTotalChunks)
//
// Returns the promise of the queued SpeakerDiarizationProcessWorker. On
// invalid arguments a JavaScript TypeError is thrown and an empty object is
// returned.
static Napi::Object OfflineSpeakerDiarizationProcessAsyncWrapper(
    const Napi::CallbackInfo &info) {
  Napi::Env env = info.Env();

  if (info.Length() != 3) {
    std::ostringstream os;
    os << "Expect only 3 arguments. Given: " << info.Length();

    Napi::TypeError::New(env, os.str()).ThrowAsJavaScriptException();

    return {};
  }

  if (!info[0].IsExternal()) {
    Napi::TypeError::New(
        env, "Argument 0 should be an offline speaker diarization pointer.")
        .ThrowAsJavaScriptException();

    return {};
  }

  const SherpaOnnxOfflineSpeakerDiarization *sd =
      info[0].As<Napi::External<SherpaOnnxOfflineSpeakerDiarization>>().Data();

  if (!info[1].IsTypedArray()) {
    Napi::TypeError::New(env, "Argument 1 should be a typed array")
        .ThrowAsJavaScriptException();

    return {};
  }

  if (!info[2].IsFunction()) {
    Napi::TypeError::New(env, "Argument 2 should be a function")
        .ThrowAsJavaScriptException();

    return {};
  }

  Napi::Function cb = info[2].As<Napi::Function>();

  // Keeps the JavaScript `this` alive for the progress callback; freed by the
  // finalizer passed to TSFN::New below.
  auto context =
      new Napi::Reference<Napi::Value>(Napi::Persistent(info.This()));

  TSFN tsfn = TSFN::New(
      env,
      cb,                                    // JavaScript function called asynchronously
      "SpeakerDiarizationProcessAsyncFunc",  // Name
      0,                                     // Unlimited queue
      1,                                     // Only one thread will use this initially
      context,
      [](Napi::Env, void *, Napi::Reference<Napi::Value> *ctx) { delete ctx; });

  Napi::Float32Array samples = info[1].As<Napi::Float32Array>();

#if __OHOS__
  // On this platform the element length is divided by sizeof(float) to get
  // the number of samples.
  int32_t num_samples =
      static_cast<int32_t>(samples.ElementLength() / sizeof(float));
#else
  int32_t num_samples = static_cast<int32_t>(samples.ElementLength());
#endif

  // Copy the samples exactly once: the JavaScript buffer must not be read
  // from the worker, so the worker gets its own vector, which is then moved
  // (not copied again) into it.
  std::vector<float> v(samples.Data(), samples.Data() + num_samples);

  SpeakerDiarizationProcessWorker *worker =
      new SpeakerDiarizationProcessWorker(env, tsfn, sd, std::move(v));
  worker->Queue();
  return worker->Promise();
}

static void OfflineSpeakerDiarizationSetConfigWrapper(
const Napi::CallbackInfo &info) {
Napi::Env env = info.Env();
Expand Down Expand Up @@ -313,7 +477,7 @@ static void OfflineSpeakerDiarizationSetConfigWrapper(
return;
}

Napi::Object o = info[0].As<Napi::Object>();
Napi::Object o = info[1].As<Napi::Object>();

SherpaOnnxOfflineSpeakerDiarizationConfig c;
memset(&c, 0, sizeof(c));
Expand All @@ -334,6 +498,10 @@ void InitNonStreamingSpeakerDiarization(Napi::Env env, Napi::Object exports) {
Napi::String::New(env, "offlineSpeakerDiarizationProcess"),
Napi::Function::New(env, OfflineSpeakerDiarizationProcessWrapper));

exports.Set(
Napi::String::New(env, "offlineSpeakerDiarizationProcessAsync"),
Napi::Function::New(env, OfflineSpeakerDiarizationProcessAsyncWrapper));

exports.Set(
Napi::String::New(env, "offlineSpeakerDiarizationSetConfig"),
Napi::Function::New(env, OfflineSpeakerDiarizationSetConfigWrapper));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ struct TtsCallbackData {

// see
// https://github.com/nodejs/node-addon-examples/blob/main/src/6-threadsafe-function/typed_threadsafe_function/node-addon-api/clock.cc
void InvokeJsCallback(Napi::Env env, Napi::Function callback,
Napi::Reference<Napi::Value> *context,
TtsCallbackData *data) {
static void InvokeJsCallback(Napi::Env env, Napi::Function callback,
Napi::Reference<Napi::Value> *context,
TtsCallbackData *data) {
if (env != nullptr) {
if (callback != nullptr) {
Napi::ArrayBuffer arrayBuffer =
Expand Down Expand Up @@ -580,7 +580,6 @@ static Napi::Object OfflineTtsGenerateAsyncWrapper(
context,
[](Napi::Env, void *, Napi::Reference<Napi::Value> *ctx) { delete ctx; });

const SherpaOnnxGeneratedAudio *audio;
TtsGenerateWorker *worker = new TtsGenerateWorker(
env, tsfn, tts, text, speed, sid, enable_external_buffer);
worker->Queue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,5 +65,6 @@ export const speakerEmbeddingManagerGetAllSpeakers: (handle: object) => Array<st

export const createOfflineSpeakerDiarization: (config: object, mgr?: object) => object;
export const getOfflineSpeakerDiarizationSampleRate: (handle: object) => number;
export const offlineSpeakerDiarizationProcess: (handle: object, samples: Float32Array) => object;
export const offlineSpeakerDiarizationProcess: (handle: object, input: object) => object;
export const offlineSpeakerDiarizationProcessAsync: (handle: object, input: object, callback: object) => object;
export const offlineSpeakerDiarizationSetConfig: (handle: object, config: object) => void;
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {
createOfflineSpeakerDiarization,
getOfflineSpeakerDiarizationSampleRate,
offlineSpeakerDiarizationProcess,
offlineSpeakerDiarizationProcessAsync,
offlineSpeakerDiarizationSetConfig,
} from 'libsherpa_onnx.so';

Expand Down Expand Up @@ -32,9 +33,12 @@ export class OfflineSpeakerDiarizationConfig {
}

export class OfflineSpeakerDiarizationSegment {
public start: number = 0; // in seconds
public end: number = 0; // in seconds
public speaker: number = 0; // ID of the speaker; count from 0
// in seconds
public start: number = 0;
// in seconds
public end: number = 0;
// ID of the speaker; count from 0
public speaker: number = 0;
}

export class OfflineSpeakerDiarization {
Expand Down Expand Up @@ -67,6 +71,12 @@ export class OfflineSpeakerDiarization {
return offlineSpeakerDiarizationProcess(this.handle, samples) as OfflineSpeakerDiarizationSegment[];
}

/**
 * Runs speaker diarization asynchronously through the native
 * offlineSpeakerDiarizationProcessAsync binding.
 *
 * @param samples  Audio samples to process.
 * @param callback Progress callback; receives the number of processed chunks
 *                 and the total number of chunks.
 * @returns A promise for the resulting list of segments.
 */
processAsync(
  samples: Float32Array,
  callback: (numProcessedChunks: number, numTotalChunks: number) => void
): Promise<OfflineSpeakerDiarizationSegment[]> {
  const pending = offlineSpeakerDiarizationProcessAsync(this.handle, samples, callback);
  return pending as Promise<OfflineSpeakerDiarizationSegment[]>;
}

setConfig(config: OfflineSpeakerDiarizationConfig) {
offlineSpeakerDiarizationSetConfig(this.handle, config);
this.config.clustering = config.clustering;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"main": "",
"author": "",
"license": "",
"dependencies": {}
"dependencies": {
"sherpa_onnx": "1.10.33"
}
}

Loading

0 comments on commit 351c194

Please sign in to comment.