diff --git a/fastembed.go b/fastembed.go index e3b6e20..362074e 100644 --- a/fastembed.go +++ b/fastembed.go @@ -148,7 +148,7 @@ func (f *FlagEmbedding) onnxEmbed(input []string) ([]([]float32), error) { inputTypeIdsFlat = append(inputTypeIdsFlat, inputTypeIds...) } - inputShape := ort.NewShape(int64(len(inputs)), int64(f.maxLength)) + inputShape := ort.NewShape(int64(len(inputs)), int64(encodings[0].Len())) inputTensorID, err := ort.NewTensor(inputShape, inputIdsFlat) if err != nil { @@ -175,7 +175,7 @@ func (f *FlagEmbedding) onnxEmbed(input []string) ([]([]float32), error) { return nil, err } - outputShape := ort.NewShape(int64(len(inputs)), int64(f.maxLength), int64(modelInfo.Dim)) + outputShape := ort.NewShape(int64(len(inputs)), int64(int64(encodings[0].Len())), int64(modelInfo.Dim)) outputTensor, err := ort.NewEmptyTensor[float32](outputShape) if err != nil { return nil, err