Skip to content

Commit

Permalink
Add Go API for Keyword spotting (#1662)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj authored Dec 31, 2024
1 parent 38d64a6 commit 49154c9
Show file tree
Hide file tree
Showing 10 changed files with 268 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/test-go-package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ jobs:
run: |
gcc --version
- name: Test Keyword spotting
if: matrix.os != 'windows-latest'
shell: bash
run: |
cd go-api-examples/keyword-spotting-from-file/
./run.sh
- name: Test adding punctuation
if: matrix.os != 'windows-latest'
shell: bash
Expand Down
9 changes: 9 additions & 0 deletions .github/workflows/test-go.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ jobs:
name: ${{ matrix.os }}-libs
path: to-upload/

- name: Test Keyword spotting
shell: bash
run: |
cd scripts/go/_internal/keyword-spotting-from-file/
./run.sh
ls -lh
- name: Test non-streaming decoding files
shell: bash
run: |
Expand Down
4 changes: 4 additions & 0 deletions go-api-examples/keyword-spotting-from-file/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
module keyword-spotting-from-file

go 1.12

79 changes: 79 additions & 0 deletions go-api-examples/keyword-spotting-from-file/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package main

import (
sherpa "github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx"
"log"
)

func main() {
log.SetFlags(log.LstdFlags | log.Lmicroseconds)

config := sherpa.KeywordSpotterConfig{}

// Please download the models from
// https://github.com/k2-fsa/sherpa-onnx/releases/tag/kws-models

config.ModelConfig.Transducer.Encoder = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/encoder-epoch-12-avg-2-chunk-16-left-64.onnx"
config.ModelConfig.Transducer.Decoder = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/decoder-epoch-12-avg-2-chunk-16-left-64.onnx"
config.ModelConfig.Transducer.Joiner = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/joiner-epoch-12-avg-2-chunk-16-left-64.onnx"
config.ModelConfig.Tokens = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt"
config.KeywordsFile = "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/test_keywords.txt"
config.ModelConfig.NumThreads = 1
config.ModelConfig.Debug = 1

spotter := sherpa.NewKeywordSpotter(&config)
defer sherpa.DeleteKeywordSpotter(spotter)

wave_filename := "./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/test_wavs/3.wav"

wave := sherpa.ReadWave(wave_filename)
if wave == nil {
log.Printf("Failed to read %v\n", wave_filename)
return
}

log.Println("----------Use pre-defined keywords----------")

stream := sherpa.NewKeywordStream(spotter)
defer sherpa.DeleteOnlineStream(stream)

stream.AcceptWaveform(wave.SampleRate, wave.Samples)

for spotter.IsReady(stream) {
spotter.Decode(stream)
result := spotter.GetResult(stream)
if result.Keyword != "" {
log.Printf("Detected %v\n", result.Keyword)
}
}

log.Println("----------Use pre-defined keywords + add a new keyword----------")

stream2 := sherpa.NewKeywordStreamWithKeywords(spotter, "y ǎn y uán @演员")
defer sherpa.DeleteOnlineStream(stream2)

stream2.AcceptWaveform(wave.SampleRate, wave.Samples)

for spotter.IsReady(stream2) {
spotter.Decode(stream2)
result := spotter.GetResult(stream2)
if result.Keyword != "" {
log.Printf("Detected %v\n", result.Keyword)
}
}

log.Println("----------Use pre-defined keywords + add 2 new keywords----------")

stream3 := sherpa.NewKeywordStreamWithKeywords(spotter, "y ǎn y uán @演员/zh ī m íng @知名")
defer sherpa.DeleteOnlineStream(stream3)

stream3.AcceptWaveform(wave.SampleRate, wave.Samples)

for spotter.IsReady(stream3) {
spotter.Decode(stream3)
result := spotter.GetResult(stream3)
if result.Keyword != "" {
log.Printf("Detected %v\n", result.Keyword)
}
}
}
13 changes: 13 additions & 0 deletions go-api-examples/keyword-spotting-from-file/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/usr/bin/env bash

set -ex

if [ ! -f ./sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01/tokens.txt ]; then
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
tar xvf sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
rm sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2
fi

go mod tidy
go build
./keyword-spotting-from-file
1 change: 1 addition & 0 deletions scripts/go/_internal/keyword-spotting-from-file/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
keyword-spotting-from-file
5 changes: 5 additions & 0 deletions scripts/go/_internal/keyword-spotting-from-file/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module keyword-spotting-from-file

go 1.12

replace github.com/k2-fsa/sherpa-onnx-go/sherpa_onnx => ../
1 change: 1 addition & 0 deletions scripts/go/_internal/keyword-spotting-from-file/main.go
1 change: 1 addition & 0 deletions scripts/go/_internal/keyword-spotting-from-file/run.sh
148 changes: 148 additions & 0 deletions scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1385,3 +1385,151 @@ func (punc *OfflinePunctuation) AddPunct(text string) string {

return text_with_punct
}

// Configuration for the online/streaming recognizer.
type KeywordSpotterConfig struct {
FeatConfig FeatureConfig
ModelConfig OnlineModelConfig
MaxActivePaths int
KeywordsFile string
KeywordsScore float32
KeywordsThreshold float32
KeywordsBuf string
KeywordsBufSize int
}

type KeywordSpotterResult struct {
Keyword string
}

type KeywordSpotter struct {
impl *C.struct_SherpaOnnxKeywordSpotter
}

// Free the internal pointer inside the recognizer to avoid memory leak.
func DeleteKeywordSpotter(spotter *KeywordSpotter) {
C.SherpaOnnxDestroyKeywordSpotter(spotter.impl)
spotter.impl = nil
}

// The user is responsible to invoke [DeleteKeywordSpotter]() to free
// the returned spotter to avoid memory leak
func NewKeywordSpotter(config *KeywordSpotterConfig) *KeywordSpotter {
c := C.struct_SherpaOnnxKeywordSpotterConfig{}
c.feat_config.sample_rate = C.int(config.FeatConfig.SampleRate)
c.feat_config.feature_dim = C.int(config.FeatConfig.FeatureDim)

c.model_config.transducer.encoder = C.CString(config.ModelConfig.Transducer.Encoder)
defer C.free(unsafe.Pointer(c.model_config.transducer.encoder))

c.model_config.transducer.decoder = C.CString(config.ModelConfig.Transducer.Decoder)
defer C.free(unsafe.Pointer(c.model_config.transducer.decoder))

c.model_config.transducer.joiner = C.CString(config.ModelConfig.Transducer.Joiner)
defer C.free(unsafe.Pointer(c.model_config.transducer.joiner))

c.model_config.paraformer.encoder = C.CString(config.ModelConfig.Paraformer.Encoder)
defer C.free(unsafe.Pointer(c.model_config.paraformer.encoder))

c.model_config.paraformer.decoder = C.CString(config.ModelConfig.Paraformer.Decoder)
defer C.free(unsafe.Pointer(c.model_config.paraformer.decoder))

c.model_config.zipformer2_ctc.model = C.CString(config.ModelConfig.Zipformer2Ctc.Model)
defer C.free(unsafe.Pointer(c.model_config.zipformer2_ctc.model))

c.model_config.tokens = C.CString(config.ModelConfig.Tokens)
defer C.free(unsafe.Pointer(c.model_config.tokens))

c.model_config.num_threads = C.int(config.ModelConfig.NumThreads)

c.model_config.provider = C.CString(config.ModelConfig.Provider)
defer C.free(unsafe.Pointer(c.model_config.provider))

c.model_config.debug = C.int(config.ModelConfig.Debug)

c.model_config.model_type = C.CString(config.ModelConfig.ModelType)
defer C.free(unsafe.Pointer(c.model_config.model_type))

c.model_config.modeling_unit = C.CString(config.ModelConfig.ModelingUnit)
defer C.free(unsafe.Pointer(c.model_config.modeling_unit))

c.model_config.bpe_vocab = C.CString(config.ModelConfig.BpeVocab)
defer C.free(unsafe.Pointer(c.model_config.bpe_vocab))

c.model_config.tokens_buf = C.CString(config.ModelConfig.TokensBuf)
defer C.free(unsafe.Pointer(c.model_config.tokens_buf))

c.model_config.tokens_buf_size = C.int(config.ModelConfig.TokensBufSize)

c.max_active_paths = C.int(config.MaxActivePaths)

c.keywords_file = C.CString(config.KeywordsFile)
defer C.free(unsafe.Pointer(c.keywords_file))

c.keywords_score = C.float(config.KeywordsScore)

c.keywords_threshold = C.float(config.KeywordsThreshold)

c.keywords_buf = C.CString(config.KeywordsBuf)
defer C.free(unsafe.Pointer(c.keywords_buf))

c.keywords_buf_size = C.int(config.KeywordsBufSize)

spotter := &KeywordSpotter{}
spotter.impl = C.SherpaOnnxCreateKeywordSpotter(&c)

return spotter
}

// The user is responsible to invoke [DeleteOnlineStream]() to free
// the returned stream to avoid memory leak
func NewKeywordStream(spotter *KeywordSpotter) *OnlineStream {
stream := &OnlineStream{}
stream.impl = C.SherpaOnnxCreateKeywordStream(spotter.impl)
return stream
}

// The user is responsible to invoke [DeleteOnlineStream]() to free
// the returned stream to avoid memory leak
func NewKeywordStreamWithKeywords(spotter *KeywordSpotter, keywords string) *OnlineStream {
stream := &OnlineStream{}

s := C.CString(keywords)
defer C.free(unsafe.Pointer(s))

stream.impl = C.SherpaOnnxCreateKeywordStreamWithKeywords(spotter.impl, s)
return stream
}

// Check whether the stream has enough feature frames for decoding.
// Return true if this stream is ready for decoding. Return false otherwise.
//
// You will usually use it like below:
//
// for spotter.IsReady(s) {
// spotter.Decode(s)
// }
func (spotter *KeywordSpotter) IsReady(s *OnlineStream) bool {
return C.SherpaOnnxIsKeywordStreamReady(spotter.impl, s.impl) == 1
}

// Decode the stream. Before calling this function, you have to ensure
// that spotter.IsReady(s) returns true. Otherwise, you will be SAD.
//
// You usually use it like below:
//
// for spotter.IsReady(s) {
// spotter.Decode(s)
// }
func (spotter *KeywordSpotter) Decode(s *OnlineStream) {
C.SherpaOnnxDecodeKeywordStream(spotter.impl, s.impl)
}

// Get the current result of stream since the last invoke of Reset()
func (spotter *KeywordSpotter) GetResult(s *OnlineStream) *KeywordSpotterResult {
p := C.SherpaOnnxGetKeywordResult(spotter.impl, s.impl)
defer C.SherpaOnnxDestroyKeywordResult(p)
result := &KeywordSpotterResult{}
result.Keyword = C.GoString(p.keyword)
return result
}

0 comments on commit 49154c9

Please sign in to comment.