Skip to content

Commit 0274034

Browse files
committed
Merge branch 'feat/averiewang/support-new-model' into 'main' (merge request !61)
feat: support new embedding model
2 parents c31e01e + 4d21683 commit 0274034

File tree

8 files changed

+26
-7
lines changed

8 files changed

+26
-7
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Changelog
22

3+
## v1.4.6
4+
* 支持BAAI/bge-m3新模型,也支持创建embedding collection时使用string直接设置模型
35

46
## v1.4.5
57
* 更换依赖cgo的分词包,为纯go实现的分词包,以更好的支持跨平台编译

example/embedding_demo/main.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func (d *EmbeddingDemo) CreateDBAndCollection(ctx context.Context, database, col
109109
index.FilterIndex = append(index.FilterIndex, tcvectordb.FilterIndex{FieldName: "bookName", FieldType: tcvectordb.String, IndexType: tcvectordb.FILTER})
110110
index.FilterIndex = append(index.FilterIndex, tcvectordb.FilterIndex{FieldName: "page", FieldType: tcvectordb.Uint64, IndexType: tcvectordb.FILTER})
111111

112-
ebd := &tcvectordb.Embedding{VectorField: "vector", Field: "text", Model: tcvectordb.BGE_BASE_ZH}
112+
ebd := &tcvectordb.Embedding{VectorField: "vector", Field: "text", ModelName: "bge-base-zh"}
113113
// 第二步:创建 Collection
114114
// 创建支持 Embedding 的 Collection
115115
db.WithTimeout(time.Second * 30)

tcvectordb/base_collection.go

+10-4
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ func (i *implementerCollection) CreateCollection(ctx context.Context, name strin
145145
req.Embedding.Field = param.Embedding.Field
146146
req.Embedding.VectorField = param.Embedding.VectorField
147147
req.Embedding.Model = string(param.Embedding.Model)
148+
if param.Embedding.ModelName != "" {
149+
req.Embedding.Model = param.Embedding.ModelName
150+
}
148151
}
149152
if param.TtlConfig != nil {
150153
req.TtlConfig = new(collection.TtlConfig)
@@ -315,6 +318,7 @@ func (i *implementerCollection) toCollection(collectionItem *collection.Describe
315318
coll.Embedding.Field = collectionItem.Embedding.Field
316319
coll.Embedding.VectorField = collectionItem.Embedding.VectorField
317320
coll.Embedding.Model = EmbeddingModel(collectionItem.Embedding.Model)
321+
coll.Embedding.ModelName = collectionItem.Embedding.Model
318322
coll.Embedding.Enabled = collectionItem.Embedding.Status == "enabled"
319323
}
320324
if collectionItem.TtlConfig != nil {
@@ -459,10 +463,12 @@ func (c *Collection) WithTimeout(t time.Duration) {
459463
}
460464

461465
type Embedding struct {
462-
Field string `json:"field,omitempty"`
463-
VectorField string `json:"vectorField,omitempty"`
464-
Model EmbeddingModel `json:"model,omitempty"`
465-
Enabled bool `json:"enabled,omitempty"` // 返回数据
466+
Field string `json:"field,omitempty"`
467+
VectorField string `json:"vectorField,omitempty"`
468+
// Deprecated: Use ModelName instead.
469+
Model EmbeddingModel `json:"model,omitempty"`
470+
ModelName string `json:"modelName,omitempty"`
471+
Enabled bool `json:"enabled,omitempty"` // 返回数据
466472
}
467473

468474
type IndexStatus struct {

tcvectordb/consts.go

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ const (
7171
E5_LARGE_V2 EmbeddingModel = "e5-large-v2"
7272
// TEXT2VEC_LARGE_CHINESE 1024
7373
TEXT2VEC_LARGE_CHINESE EmbeddingModel = "text2vec-large-chinese"
74+
// BAAI_BGE_M3 1024
75+
BAAI_BGE_M3 EmbeddingModel = "BAAI/bge-m3"
7476
)
7577

7678
type ReadConsistency string

tcvectordb/rpc_base_collection.go

+4
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ func (r *rpcImplementerCollection) CreateCollection(ctx context.Context, name st
101101
VectorField: param.Embedding.VectorField,
102102
ModelName: string(param.Embedding.Model),
103103
}
104+
if param.Embedding.ModelName != "" {
105+
req.EmbeddingParams.ModelName = param.Embedding.ModelName
106+
}
104107
}
105108
if param.TtlConfig != nil {
106109
req.TtlConfig = &olama.TTLConfig{
@@ -246,6 +249,7 @@ func (r *rpcImplementerCollection) toCollection(collectionItem *olama.CreateColl
246249
coll.Embedding.Field = collectionItem.EmbeddingParams.Field
247250
coll.Embedding.VectorField = collectionItem.EmbeddingParams.VectorField
248251
coll.Embedding.Model = EmbeddingModel(collectionItem.EmbeddingParams.ModelName)
252+
coll.Embedding.ModelName = collectionItem.EmbeddingParams.ModelName
249253
coll.Embedding.Enabled = false
250254
}
251255
if collectionItem.TtlConfig != nil {

tcvectordb/version.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818

1919
package tcvectordb
2020

21-
const SDKVersion = "v1.4.5"
21+
const SDKVersion = "v1.4.6"

test/embedding_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func TestCreateCollectionWithEmbedding(t *testing.T) {
3737
Embedding: &tcvectordb.Embedding{
3838
Field: "segment",
3939
VectorField: "vector",
40-
Model: tcvectordb.BGE_BASE_ZH,
40+
ModelName: "BAAI/bge-m3",
4141
},
4242
}
4343

test/normal_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,8 @@ func TestJson(t *testing.T) {
426426
temp1 := temp["shardNum"]
427427
println(fmt.Sprintf("%T, %v", temp1, temp1))
428428
}
429+
430+
func TestEmeb(t *testing.T) {
431+
model := "model_bge"
432+
println(tcvectordb.EmbeddingModel(model))
433+
}

0 commit comments

Comments
 (0)