diff --git a/.github/workflows/test-go-package.yaml b/.github/workflows/test-go-package.yaml index 9074449b4..649735699 100644 --- a/.github/workflows/test-go-package.yaml +++ b/.github/workflows/test-go-package.yaml @@ -68,6 +68,13 @@ jobs: run: | gcc --version + - name: Test Keyword spotting + if: matrix.os != 'windows-latest' + shell: bash + run: | + cd go-api-examples/keyword-spotting-from-file/ + ./run.sh + - name: Test adding punctuation if: matrix.os != 'windows-latest' shell: bash diff --git a/.github/workflows/test-go.yaml b/.github/workflows/test-go.yaml index 9c9951111..eaf561818 100644 --- a/.github/workflows/test-go.yaml +++ b/.github/workflows/test-go.yaml @@ -134,6 +134,15 @@ jobs: name: ${{ matrix.os }}-libs path: to-upload/ + - name: Test Keyword spotting + shell: bash + run: | + cd scripts/go/_internal/keyword-spotting-from-file/ + + ./run.sh + + ls -lh + - name: Test non-streaming decoding files shell: bash run: | diff --git a/go-api-examples/keyword-spotting-from-file/go.mod b/go-api-examples/keyword-spotting-from-file/go.mod new file mode 100644 index 000000000..dbd349a5e --- /dev/null +++ b/go-api-examples/keyword-spotting-from-file/go.mod @@ -0,0 +1,4 @@ +module keyword-spotting-from-file + +go 1.12 + diff --git a/go-api-examples/keyword-spotting-from-file/main.go b/go-api-examples/keyword-spotting-from-file/main.go new file mode 100644 index 000000000..cf6ffa84e --- /dev/null +++ b/go-api-examples/keyword-spotting-from-file/main.go @@ -0,0 +1,79 @@ +package main + +import ( + sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx" + "log" +) + +func main() { + log.SetFlags(log.LstdFlags | log.Lmicroseconds) + + config := sherpa.KeywordSpotterConfig{} + + // Please download the models from + // https://github.com/k2-fsa/sherpa-onnx/releases/tag/kws-models + + config.ModelConfig.Transducer.Encoder = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx" + config.ModelConfig.Transducer.Decoder = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx" + config.ModelConfig.Transducer.Joiner = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx" + config.ModelConfig.Tokens = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt" + config.KeywordsFile = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt" + config.ModelConfig.NumThreads = 1 + config.ModelConfig.Debug = 1 + + spotter := sherpa.NewKeywordSpotter(&config) + defer sherpa.DeleteKeywordSpotter(spotter) + + wave_filename := "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav" + + wave := sherpa.ReadWave(wave_filename) + if wave == nil { + log.Printf("Failed to read %v\n", wave_filename) + return + } + + log.Println("----------Use pre-defined keywords----------") + + stream := sherpa.NewKeywordStream(spotter) + defer sherpa.DeleteOnlineStream(stream) + + stream.AcceptWaveform(wave.SampleRate, wave.Samples) + + for spotter.IsReady(stream) { + spotter.Decode(stream) + result := spotter.GetResult(stream) + if result.Keyword != "" { + log.Printf("Detected %v\n", result.Keyword) + } + } + + log.Println("----------Use pre-defined keywords + add a new keyword----------") + + stream2 := sherpa.NewKeywordStreamWithKeywords(spotter, "y ǎn y uán @演员") + defer sherpa.DeleteOnlineStream(stream2) + + stream2.AcceptWaveform(wave.SampleRate, wave.Samples) + + for spotter.IsReady(stream2) { + spotter.Decode(stream2) + result := spotter.GetResult(stream2) + if result.Keyword != "" { + log.Printf("Detected %v\n", result.Keyword) + } + } + + log.Println("----------Use pre-defined keywords + add 2 new keywords----------") + + stream3 := sherpa.NewKeywordStreamWithKeywords(spotter, "y ǎn y uán @演员/zh ī m íng @知名") + defer sherpa.DeleteOnlineStream(stream3) + + stream3.AcceptWaveform(wave.SampleRate, wave.Samples) + + for spotter.IsReady(stream3) { + spotter.Decode(stream3) + result := spotter.GetResult(stream3) + if result.Keyword != "" { + log.Printf("Detected %v\n", result.Keyword) + } + } +} diff --git a/go-api-examples/keyword-spotting-from-file/run.sh b/go-api-examples/keyword-spotting-from-file/run.sh new file mode 100755 index 000000000..89411f47a --- /dev/null +++ b/go-api-examples/keyword-spotting-from-file/run.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -ex + +if [ ! -f ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt ]; then + curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 + tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 + rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2 +fi + +go mod tidy +go build +./keyword-spotting-from-file diff --git a/scripts/go/_internal/keyword-spotting-from-file/.gitignore b/scripts/go/_internal/keyword-spotting-from-file/.gitignore new file mode 100644 index 000000000..2c433c5c7 --- /dev/null +++ b/scripts/go/_internal/keyword-spotting-from-file/.gitignore @@ -0,0 +1 @@ +keyword-spotting-from-file diff --git a/scripts/go/_internal/keyword-spotting-from-file/go.mod b/scripts/go/_internal/keyword-spotting-from-file/go.mod new file mode 100644 index 000000000..9cdbc6321 --- /dev/null +++ b/scripts/go/_internal/keyword-spotting-from-file/go.mod @@ -0,0 +1,5 @@ +module keyword-spotting-from-file + +go 1.12 + +replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../ diff --git a/scripts/go/_internal/keyword-spotting-from-file/main.go b/scripts/go/_internal/keyword-spotting-from-file/main.go new file mode 120000 index 000000000..f17d55363 --- /dev/null +++ b/scripts/go/_internal/keyword-spotting-from-file/main.go @@ -0,0 +1 @@ +../../../../go-api-examples/keyword-spotting-from-file/main.go \ No newline at end of file diff --git a/scripts/go/_internal/keyword-spotting-from-file/run.sh b/scripts/go/_internal/keyword-spotting-from-file/run.sh new file mode 120000 index 000000000..a9bb15f88 --- /dev/null +++ b/scripts/go/_internal/keyword-spotting-from-file/run.sh @@ -0,0 +1 @@ +../../../../go-api-examples/keyword-spotting-from-file/run.sh \ No newline at end of file diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index cde6513da..17afd32f2 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -1385,3 +1385,151 @@ func (punc *OfflinePunctuation) AddPunct(text string) string { return text_with_punct } + +// Configuration for the online/streaming recognizer. +type KeywordSpotterConfig struct { + FeatConfig FeatureConfig + ModelConfig OnlineModelConfig + MaxActivePaths int + KeywordsFile string + KeywordsScore float32 + KeywordsThreshold float32 + KeywordsBuf string + KeywordsBufSize int +} + +type KeywordSpotterResult struct { + Keyword string +} + +type KeywordSpotter struct { + impl *C.struct_SherpaOnnxKeywordSpotter +} + +// Free the internal pointer inside the recognizer to avoid memory leak. +func DeleteKeywordSpotter(spotter *KeywordSpotter) { + C.SherpaOnnxDestroyKeywordSpotter(spotter.impl) + spotter.impl = nil +} + +// The user is responsible to invoke [DeleteKeywordSpotter]() to free +// the returned spotter to avoid memory leak +func NewKeywordSpotter(config *KeywordSpotterConfig) *KeywordSpotter { + c := C.struct_SherpaOnnxKeywordSpotterConfig{} + c.feat_config.sample_rate = C.int(config.FeatConfig.SampleRate) + c.feat_config.feature_dim = C.int(config.FeatConfig.FeatureDim) + + c.model_config.transducer.encoder = C.CString(config.ModelConfig.Transducer.Encoder) + defer C.free(unsafe.Pointer(c.model_config.transducer.encoder)) + + c.model_config.transducer.decoder = C.CString(config.ModelConfig.Transducer.Decoder) + defer C.free(unsafe.Pointer(c.model_config.transducer.decoder)) + + c.model_config.transducer.joiner = C.CString(config.ModelConfig.Transducer.Joiner) + defer C.free(unsafe.Pointer(c.model_config.transducer.joiner)) + + c.model_config.paraformer.encoder = C.CString(config.ModelConfig.Paraformer.Encoder) + defer C.free(unsafe.Pointer(c.model_config.paraformer.encoder)) + + c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder) + defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder)) + + c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model) + defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model)) + + c.model_config.tokens = C.CString(config.ModelConfig.Tokens) + defer C.free(unsafe.Pointer(c.model_config.tokens)) + + c.model_config.num_threads = C.int(config.ModelConfig.NumThreads) + + c.model_config.provider = C.CString(config.ModelConfig.Provider) + defer C.free(unsafe.Pointer(c.model_config.provider)) + + c.model_config.debug = C.int(config.ModelConfig.Debug) + + c.model_config.model_type = C.CString(config.ModelConfig.ModelType) + defer C.free(unsafe.Pointer(c.model_config.model_type)) + + c.model_config.modeling_unit = C.CString(config.ModelConfig.ModelingUnit) + defer C.free(unsafe.Pointer(c.model_config.modeling_unit)) + + c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab) + defer C.free(unsafe.Pointer(c.model_config.bpe_vocab)) + + c.model_config.tokens_buf = C.CString(config.ModelConfig.TokensBuf) + defer C.free(unsafe.Pointer(c.model_config.tokens_buf)) + + c.model_config.tokens_buf_size = C.int(config.ModelConfig.TokensBufSize) + + c.max_active_paths = C.int(config.MaxActivePaths) + + c.keywords_file = C.CString(config.KeywordsFile) + defer C.free(unsafe.Pointer(c.keywords_file)) + + c.keywords_score = C.float(config.KeywordsScore) + + c.keywords_threshold = C.float(config.KeywordsThreshold) + + c.keywords_buf = C.CString(config.KeywordsBuf) + defer C.free(unsafe.Pointer(c.keywords_buf)) + + c.keywords_buf_size = C.int(config.KeywordsBufSize) + + spotter := &KeywordSpotter{} + spotter.impl = C.SherpaOnnxCreateKeywordSpotter(&c) + + return spotter +} + +// The user is responsible to invoke [DeleteOnlineStream]() to free +// the returned stream to avoid memory leak +func NewKeywordStream(spotter *KeywordSpotter) *OnlineStream { + stream := &OnlineStream{} + stream.impl = C.SherpaOnnxCreateKeywordStream(spotter.impl) + return stream +} + +// The user is responsible to invoke [DeleteOnlineStream]() to free +// the returned stream to avoid memory leak +func NewKeywordStreamWithKeywords(spotter *KeywordSpotter, keywords string) *OnlineStream { + stream := &OnlineStream{} + + s := C.CString(keywords) + defer C.free(unsafe.Pointer(s)) + + stream.impl = C.SherpaOnnxCreateKeywordStreamWithKeywords(spotter.impl, s) + return stream +} + +// Check whether the stream has enough feature frames for decoding. +// Return true if this stream is ready for decoding. Return false otherwise. +// +// You will usually use it like below: +// +// for spotter.IsReady(s) { +// spotter.Decode(s) +// } +func (spotter *KeywordSpotter) IsReady(s *OnlineStream) bool { + return C.SherpaOnnxIsKeywordStreamReady(spotter.impl, s.impl) == 1 +} + +// Decode the stream. Before calling this function, you have to ensure +// that spotter.IsReady(s) returns true. Otherwise, you will be SAD. +// +// You usually use it like below: +// +// for spotter.IsReady(s) { +// spotter.Decode(s) +// } +func (spotter *KeywordSpotter) Decode(s *OnlineStream) { + C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl) +} + +// Get the current result of stream since the last invoke of Reset() +func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult { + p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl) + defer C.SherpaOnnxDestroyKeywordResult(p) + result := &KeywordSpotterResult{} + result.Keyword = C.GoString(p.keyword) + return result +}