Skip to content

Commit

Permalink
Add Go API for offline punctuation models (#1434)
Browse files Browse the repository at this point in the history
It is contributed by a community user 
from [our QQ group](https://k2-fsa.github.io/sherpa/social-groups.html#qq).
  • Loading branch information
csukuangfj authored Oct 16, 2024
1 parent 77dd5f7 commit 593b967
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization
c.clustering.num_clusters = C.int(config.Clustering.NumClusters)
c.clustering.threshold = C.float(config.Clustering.Threshold)

SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c)
C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c)
}

type OfflineSpeakerDiarizationSegment struct {
Expand Down Expand Up @@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker

return ans
}

// ============================================================
// For punctuation
// ============================================================
type OfflinePunctuationModelConfig struct {
Ct_transformer string
Num_threads C.int
Debug C.int // true to print debug information of the model
Provider string
}

type OfflinePunctuationConfig struct {
Model OfflinePunctuationModelConfig
}

type OfflinePunctuation struct {
impl *C.struct_SherpaOnnxOfflinePunctuation
}

func NewOfflinePunctuation(config *OfflinePunctuationConfig) *OfflinePunctuation {
cfg := C.struct_SherpaOnnxOfflinePunctuationConfig{}
cfg.model.ct_transformer = C.CString(config.Model.Ct_transformer)
defer C.free(unsafe.Pointer(cfg.model.ct_transformer))

cfg.model.num_threads = config.Model.Num_threads
cfg.model.debug = config.Model.Debug
cfg.model.provider = C.CString(config.Model.Provider)
defer C.free(unsafe.Pointer(cfg.model.provider))

punc := &OfflinePunctuation{}
punc.impl = C.SherpaOnnxCreateOfflinePunctuation(&cfg)

return punc
}

func DeleteOfflinePunc(punc *OfflinePunctuation) {
C.SherpaOnnxDestroyOfflinePunctuation(punc.impl)
punc.impl = nil
}

func (punc *OfflinePunctuation) AddPunct(text string) string {
p := C.SherpaOfflinePunctuationAddPunct(punc.impl, C.CString(text))
defer C.free(unsafe.Pointer(p))

text_with_punct := C.GoString(p)

return text_with_punct
}

0 comments on commit 593b967

Please sign in to comment.