Skip to content

Commit

Permalink
support reading rule FST for Android TTS (k2-fsa#410)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Nov 6, 2023
1 parent a4b4c77 commit 9f8ec11
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 25 deletions.
21 changes: 17 additions & 4 deletions .github/workflows/apk-tts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ jobs:
with:
fetch-depth: 0

- name: ccache
uses: hendrikmuhs/[email protected]
with:
key: ${{ matrix.os }}-android

- name: Display NDK HOME
shell: bash
run: |
Expand Down Expand Up @@ -61,6 +66,10 @@ jobs:
- name: build APK
shell: bash
run: |
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
cmake --version
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
./build-apk-tts.sh
Expand All @@ -70,12 +79,14 @@ jobs:
ls -lh ./apks/
du -h -d1 .
# - uses: actions/upload-artifact@v3
# with:
# name: tts-apk
# path: ./apks/*.apk
- uses: actions/upload-artifact@v3
if: false
with:
name: tts-apk
path: ./apks/*.apk

- name: Publish to huggingface
if: true
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
uses: nick-fields/retry@v2
Expand All @@ -92,7 +103,9 @@ jobs:
git clone https://huggingface.co/csukuangfj/sherpa-onnx-apk huggingface
cd huggingface
git fetch
git pull
git merge -m "merge remote" --ff origin main
mkdir -p tts
cp -v ../apks/*.apk ./tts/
Expand Down
10 changes: 10 additions & 0 deletions .github/workflows/apk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0

- name: ccache
uses: hendrikmuhs/[email protected]
with:
key: ${{ matrix.os }}-android

- name: Display NDK HOME
shell: bash
run: |
Expand All @@ -37,6 +43,10 @@ jobs:
- name: build APK
shell: bash
run: |
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
cmake --version
export ANDROID_NDK=$ANDROID_NDK_LATEST_HOME
./build-apk-vad.sh
./build-apk-two-pass.sh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@ class MainActivity : AppCompatActivity() {
fun initTts() {
var modelDir :String?
var modelName :String?
var ruleFsts: String?

// The purpose of such a design is to make the CI test easier
// Please see
// https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/apk/generate-tts-apk-script.py
modelDir = null
modelName = null
ruleFsts = null

// Example 1:
// modelDir = "vits-vctk"
Expand All @@ -116,7 +118,12 @@ class MainActivity : AppCompatActivity() {
// modelDir = "vits-piper-en_US-lessac-medium"
// modelName = "en_US-lessac-medium.onnx"

val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!)!!
// Example 3:
// modelDir = "vits-zh-aishell3"
// modelName = "vits-aishell3.onnx"
// ruleFsts = "vits-zh-aishell3/rule.fst"

val config = getOfflineTtsConfig(modelDir = modelDir!!, modelName = modelName!!, ruleFsts = ruleFsts ?: "")!!
tts = OfflineTts(assetManager = application.assets, config = config)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ data class OfflineTtsModelConfig(

/**
 * Top-level configuration handed to the native offline TTS engine.
 *
 * NOTE: the native code reads [ruleFsts] by field name and type through JNI,
 * so renaming the property or changing its type requires a matching change
 * on the native side.
 *
 * @property model Model-related settings.
 * @property ruleFsts Comma-separated rule FST file(s) used to build text
 *   normalizers on the native side. An empty string (the default) means no
 *   rule FST is loaded. On Android the files are presumably resolved against
 *   the APK assets - confirm against the native loader.
 */
data class OfflineTtsConfig(
var model: OfflineTtsModelConfig,
var ruleFsts: String = "",
)

class GeneratedAudio(
Expand Down Expand Up @@ -116,7 +117,7 @@ class OfflineTts(
// please refer to
// https://k2-fsa.github.io/sherpa/onnx/tts/pretrained_models/index.html
// to download models
fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig? {
fun getOfflineTtsConfig(modelDir: String, modelName: String, ruleFsts: String): OfflineTtsConfig? {
return OfflineTtsConfig(
model = OfflineTtsModelConfig(
vits = OfflineTtsVitsModelConfig(
Expand All @@ -125,8 +126,9 @@ fun getOfflineTtsConfig(modelDir: String, modelName: String): OfflineTtsConfig?
tokens = "$modelDir/tokens.txt"
),
numThreads = 2,
debug = false,
debug = true,
provider = "cpu",
)
),
ruleFsts=ruleFsts,
)
}
16 changes: 8 additions & 8 deletions cmake/kaldifst.cmake
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
function(download_kaldifst)
include(FetchContent)

set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.8.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.8.tar.gz")
set(kaldifst_HASH "SHA256=94613923568ef9a240ba1059b8b9dfe3082daad794934635d99e66248a6687b5")
set(kaldifst_URL "https://github.com/k2-fsa/kaldifst/archive/refs/tags/v1.7.9.tar.gz")
set(kaldifst_URL2 "https://huggingface.co/csukuangfj/kaldi-hmm-gmm-cmake-deps/resolve/main/kaldifst-1.7.9.tar.gz")
set(kaldifst_HASH "SHA256=8c653021491dca54c38ab659565edfab391418a79ae87099257863cd5664dd39")

# If you don't have access to the Internet,
# please pre-download kaldifst
set(possible_file_locations
$ENV{HOME}/Downloads/kaldifst-1.7.8.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.8.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.8.tar.gz
/tmp/kaldifst-1.7.8.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.8.tar.gz
$ENV{HOME}/Downloads/kaldifst-1.7.9.tar.gz
${PROJECT_SOURCE_DIR}/kaldifst-1.7.9.tar.gz
${PROJECT_BINARY_DIR}/kaldifst-1.7.9.tar.gz
/tmp/kaldifst-1.7.9.tar.gz
/star-fj/fangjun/download/github/kaldifst-1.7.9.tar.gz
)

foreach(f IN LISTS possible_file_locations)
Expand Down
8 changes: 7 additions & 1 deletion scripts/apk/build-apk-tts.sh.in
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Inside the $ANDROID_NDK directory, you can find a binary ndk-build
# and some other files like the file "build/cmake/android.toolchain.cmake"

set -e
set -ex

log() {
# This function is from espnet
Expand Down Expand Up @@ -43,6 +43,7 @@ wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/$model_name
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/lexicon.txt
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/tokens.txt
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/MODEL_CARD 2>/dev/null || true
wget -qq https://huggingface.co/csukuangfj/$model_dir/resolve/main/rule.fst 2>/dev/null || true

popd
# Now we are at the project root directory
Expand All @@ -51,6 +52,11 @@ git checkout .
pushd android/SherpaOnnxTts/app/src/main/java/com/k2fsa/sherpa/onnx
sed -i.bak s/"modelDir = null"/"modelDir = \"$model_dir\""/ ./MainActivity.kt
sed -i.bak s/"modelName = null"/"modelName = \"$model_name\""/ ./MainActivity.kt
{% if tts_model.rule_fsts %}
rule_fsts={{ tts_model.rule_fsts }}
sed -i.bak s%"ruleFsts = null"%"ruleFsts = \"$rule_fsts\""% ./MainActivity.kt
{% endif %}

git diff
popd

Expand Down
59 changes: 56 additions & 3 deletions scripts/apk/generate-tts-apk-script.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python3

import argparse
from dataclasses import dataclass
from typing import List, Optional

import jinja2
from typing import List
import argparse


def get_args():
Expand All @@ -29,12 +29,65 @@ class TtsModel:
model_dir: str
model_name: str
lang: str  # en, zh, fr, de, etc.
# Path of the rule FST relative to the app assets, e.g.
# "vits-zh-aishell3/rule.fst"; None when the model ships without one.
# The default must be falsy: the build-script template only emits the
# `ruleFsts` substitution when this value is truthy. The previous default
# `(None,)` was a non-empty tuple, so models without a rule FST still took
# that branch and got a bogus value substituted into MainActivity.kt.
rule_fsts: Optional[str] = None


def get_all_models() -> List[TtsModel]:
return [
# Chinese
TtsModel(
model_dir="vits-zh-aishell3",
model_name="vits-aishell3.onnx",
lang="zh",
rule_fsts="vits-zh-aishell3/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-doom",
model_name="doom.onnx",
lang="zh",
rule_fsts="vits-zh-hf-doom/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-echo",
model_name="echo.onnx",
lang="zh",
rule_fsts="vits-zh-hf-echo/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-zenyatta",
model_name="zenyatta.onnx",
lang="zh",
rule_fsts="vits-zh-hf-zenyatta/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-abyssinvoker",
model_name="abyssinvoker.onnx",
lang="zh",
rule_fsts="vits-zh-hf-abyssinvoker/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-keqing",
model_name="keqing.onnx",
lang="zh",
rule_fsts="vits-zh-hf-keqing/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-eula",
model_name="eula.onnx",
lang="zh",
rule_fsts="vits-zh-hf-eula/rule.fst",
),
TtsModel(
model_dir="vits-zh-hf-bronya",
model_name="bronya.onnx",
lang="zh",
rule_fsts="vits-zh-hf-bronya/rule.fst",
),
TtsModel(
model_dir="vits-zh-aishell3", model_name="vits-aishell3.onnx", lang="zh"
model_dir="vits-zh-hf-theresa",
model_name="theresa.onnx",
lang="zh",
rule_fsts="vits-zh-hf-theresa/rule.fst",
),
# English (US)
# fmt: off
Expand Down
14 changes: 11 additions & 3 deletions sherpa-onnx/csrc/lexicon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,14 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(

std::vector<int64_t> ans;

auto sil = token2id_.at("sil");
auto eos = token2id_.at("eos");
int32_t sil = -1;
int32_t eos = -1;
if (token2id_.count("sil")) {
sil = token2id_.at("sil");
eos = token2id_.at("eos");
} else {
sil = 0;
}

ans.push_back(sil);

Expand All @@ -216,7 +222,9 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
}
ans.push_back(sil);
ans.push_back(eos);
if (eos != -1) {
ans.push_back(eos);
}
return ans;
}

Expand Down
16 changes: 14 additions & 2 deletions sherpa-onnx/csrc/offline-tts-vits-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
#include <vector>

#if __ANDROID_API__ >= 9
#include <strstream>

#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/text-utils.h"

namespace sherpa_onnx {
Expand Down Expand Up @@ -52,7 +54,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
model_->Punctuations(), model_->Language(), config.model.debug,
model_->IsPiper()) {
if (!config.rule_fsts.empty()) {
SHERPA_ONNX_LOGE("TODO(fangjun): Implement rule FST for Android");
std::vector<std::string> files;
SplitStringToVector(config.rule_fsts, ",", false, &files);
tn_list_.reserve(files.size());
for (const auto &f : files) {
if (config.model.debug) {
SHERPA_ONNX_LOGE("rule fst: %s", f.c_str());
}
auto buf = ReadFile(mgr, f);
std::istrstream is(buf.data(), buf.size());
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(is));
}
}
}
#endif
Expand Down
7 changes: 7 additions & 0 deletions sherpa-onnx/jni/jni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
ans.model.provider = p;
env->ReleaseStringUTFChars(s, p);

// for ruleFsts
fid = env->GetFieldID(cls, "ruleFsts", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);

return ans;
}

Expand Down

0 comments on commit 9f8ec11

Please sign in to comment.