Skip to content

Commit 597138d

Browse files
fxied unit tests for KB
1 parent 4875774 commit 597138d

File tree

1 file changed

+48
-10
lines changed

1 file changed

+48
-10
lines changed

tests/test_sdk.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ def check_project(self, project, database):
764764

765765
self.check_project_models_versions(project, database)
766766

767-
kb = self.check_project_kb(project, model, database)
767+
kb = self.check_project_kb(project, database)
768768

769769
self.check_project_jobs(project, model, database, kb)
770770

@@ -1132,12 +1132,21 @@ def check_project_jobs(self, project, model, database, kb, mock_post):
11321132
@patch('requests.Session.post')
11331133
@patch('requests.Session.delete')
11341134
@patch('requests.Session.get')
1135-
def check_project_kb(self, project, model, database, mock_get, mock_del, mock_post, mock_put):
1135+
def check_project_kb(self, project, database, mock_get, mock_del, mock_post, mock_put):
11361136

11371137
response_mock(mock_post, pd.DataFrame([{
11381138
'NAME': 'my_kb',
11391139
'PROJECT': 'mindsdb',
1140-
'MODEL': 'openai_emb',
1140+
'EMBEDDING_MODEL': {
1141+
'PROVIDER': 'openai',
1142+
'MODEL_NAME': 'openai_emb',
1143+
'API_KEY': 'sk-...'
1144+
},
1145+
'RERANKING_MODEL': {
1146+
'PROVIDER': 'openai',
1147+
'MODEL_NAME': 'openai_rerank',
1148+
'API_KEY': 'sk-...'
1149+
},
11411150
'STORAGE': 'pvec.tbl1',
11421151
'PARAMS': {"id_column": "num"},
11431152
}]))
@@ -1146,7 +1155,16 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
11461155
'id': 1,
11471156
'name': 'my_kb',
11481157
'project_id': 1,
1149-
'embedding_model': 'openai_emb',
1158+
'embedding_model': {
1159+
'provider': 'openai',
1160+
'model_name': 'openai_emb',
1161+
'api_key': 'sk-...'
1162+
},
1163+
'reranking_model': {
1164+
'provider': 'openai',
1165+
'model_name': 'openai_rerank',
1166+
'api_key': 'sk-...'
1167+
},
11501168
'vector_database': 'pvec',
11511169
'vector_database_table': 'tbl1',
11521170
'updated_at': '2024-10-04 10:55:25.350799',
@@ -1165,8 +1183,8 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
11651183

11661184
assert kb.name == 'my_kb'
11671185

1168-
assert isinstance(kb.model, Model)
1169-
assert kb.model.name == 'openai_emb'
1186+
assert kb.embedding_model['model_name'] == 'openai_emb'
1187+
assert kb.reranking_model['model_name'] == 'openai_rerank'
11701188

11711189
assert isinstance(kb.storage, Table)
11721190
assert kb.storage.name == 'tbl1'
@@ -1178,7 +1196,8 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
11781196
str(kb)
11791197
assert kb.name == 'my_kb'
11801198
assert kb.storage.db.name == 'pvec'
1181-
assert kb.model.name == 'openai_emb'
1199+
assert kb.embedding_model['model_name'] == 'openai_emb'
1200+
assert kb.reranking_model['model_name'] == 'openai_rerank'
11821201

11831202
# --- insert ---
11841203

@@ -1221,15 +1240,33 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
12211240
# create 1
12221241
project.knowledge_bases.create(
12231242
name='kb2',
1224-
model=model,
1243+
embedding_model={
1244+
'provider': 'openai',
1245+
'model_name': 'openai_emb',
1246+
'api_key': 'sk-...'
1247+
},
1248+
reranking_model={
1249+
'provider': 'openai',
1250+
'model_name': 'openai_rerank',
1251+
'api_key': 'sk-...'
1252+
},
12251253
metadata_columns=['date', 'author'],
12261254
params={'k': 'v'}
12271255
)
12281256
args, kwargs = mock_post.call_args
12291257
assert args[0] == f'{DEFAULT_CLOUD_API_URL}/api/projects/{project.name}/knowledge_bases'
12301258
assert kwargs == {'json': {'knowledge_base': {
12311259
'name': 'kb2',
1232-
'model': model.name,
1260+
'embedding_model': {
1261+
'provider': 'openai',
1262+
'model_name': 'openai_emb',
1263+
'api_key': 'sk-...'
1264+
},
1265+
'reranking_model': {
1266+
'provider': 'openai',
1267+
'model_name': 'openai_rerank',
1268+
'api_key': 'sk-...'
1269+
},
12331270
'params': {
12341271
'k': 'v',
12351272
'metadata_columns': ['date', 'author']
@@ -1248,7 +1285,8 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
12481285
assert args[0] == f'{DEFAULT_CLOUD_API_URL}/api/projects/{project.name}/knowledge_bases'
12491286
assert kwargs == {'json': {'knowledge_base': {
12501287
'name': 'kb2',
1251-
'model': None,
1288+
'embedding_model': None,
1289+
'reranking_model': None,
12521290
'params': {
12531291
'content_columns': ['review'],
12541292
'id_column': 'num'

0 commit comments

Comments
 (0)