diff --git a/scripts/go/sherpa_onnx.go b/scripts/go/sherpa_onnx.go index 8055380c6..45bf714bb 100644 --- a/scripts/go/sherpa_onnx.go +++ b/scripts/go/sherpa_onnx.go @@ -1283,7 +1283,7 @@ func (sd *OfflineSpeakerDiarization) SetConfig(config *OfflineSpeakerDiarization c.clustering.num_clusters = C.int(config.Clustering.NumClusters) c.clustering.threshold = C.float(config.Clustering.Threshold) - SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c) + C.SherpaOnnxOfflineSpeakerDiarizationSetConfig(sd.impl, &c) } type OfflineSpeakerDiarizationSegment struct { @@ -1317,3 +1317,51 @@ func (sd *OfflineSpeakerDiarization) Process(samples []float32) []OfflineSpeaker return ans } + +// ============================================================ +// For punctuation +// ============================================================ +type OfflinePunctuationModelConfig struct { + Ct_transformer string + Num_threads C.int + Debug C.int // true to print debug information of the model + Provider string +} + +type OfflinePunctuationConfig struct { + Model OfflinePunctuationModelConfig +} + +type OfflinePunctuation struct { + impl *C.struct_SherpaOnnxOfflinePunctuation +} + +func NewOfflinePunctuation(config *OfflinePunctuationConfig) *OfflinePunctuation { + cfg := C.struct_SherpaOnnxOfflinePunctuationConfig{} + cfg.model.ct_transformer = C.CString(config.Model.Ct_transformer) + defer C.free(unsafe.Pointer(cfg.model.ct_transformer)) + + cfg.model.num_threads = config.Model.Num_threads + cfg.model.debug = config.Model.Debug + cfg.model.provider = C.CString(config.Model.Provider) + defer C.free(unsafe.Pointer(cfg.model.provider)) + + punc := &OfflinePunctuation{} + punc.impl = C.SherpaOnnxCreateOfflinePunctuation(&cfg) + + return punc +} + +func DeleteOfflinePunc(punc *OfflinePunctuation) { + C.SherpaOnnxDestroyOfflinePunctuation(punc.impl) + punc.impl = nil +} + +func (punc *OfflinePunctuation) AddPunct(text string) string { + p := C.SherpaOfflinePunctuationAddPunct(punc.impl, C.CString(text)) + defer C.free(unsafe.Pointer(p)) + + text_with_punct := C.GoString(p) + + return text_with_punct +}