Skip to content

Commit 31f1e7a

Browse files
Merge pull request #192 from mindsdb/feature/update_kb_apis
Updated the SDK Operations for Knowledge Bases
2 parents: c8c51f2 + 0295b85 · commit 31f1e7a

File tree

2 files changed: 77 additions (+77) and 47 deletions (-47)

2 files changed: 77 additions (+77) and 47 deletions (-47)

mindsdb_sdk/knowledge_bases.py

Lines changed: 28 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44

55
import pandas as pd
66

7-
from mindsdb_sql_parser.ast.mindsdb import CreateKnowledgeBase, DropKnowledgeBase
87
from mindsdb_sql_parser.ast import Identifier, Star, Select, BinaryOperation, Constant, Insert
98

109
from mindsdb_sdk.utils.sql import dict_to_binary_op, query_to_native_query
@@ -57,9 +56,14 @@ def __init__(self, api, project, data: dict):
5756
table = Table(database, data['vector_database_table'])
5857
self.storage = table
5958

60-
self.model = None
61-
if data['embedding_model'] is not None:
62-
self.model = Model(self.project, {'name': data['embedding_model']})
59+
# models
60+
self.embedding_model = data.get('embedding_model', {})
61+
self.reranking_model = data.get('reranking_model', {})
62+
63+
# columns
64+
self.metadata_columns = data.get('metadata_columns', [])
65+
self.content_columns = data.get('content_columns', [])
66+
self.id_column = data.get('id_column', None)
6367

6468
params = data.get('params', {})
6569
if isinstance(params, str):
@@ -68,11 +72,6 @@ def __init__(self, api, project, data: dict):
6872
except json.JSONDecodeError:
6973
params = {}
7074

71-
# columns
72-
self.metadata_columns = params.pop('metadata_columns', [])
73-
self.content_columns = params.pop('content_columns', [])
74-
self.id_column = params.pop('id_column', None)
75-
7675
self.params = params
7776

7877
# query behavior
@@ -311,7 +310,8 @@ def get(self, name: str) -> KnowledgeBase:
311310
def create(
312311
self,
313312
name: str,
314-
model: Model = None,
313+
embedding_model: dict = None,
314+
reranking_model: dict = None,
315315
storage: Table = None,
316316
metadata_columns: list = None,
317317
content_columns: list = None,
@@ -324,7 +324,8 @@ def create(
324324
325325
>>> kb = server.knowledge_bases.create(
326326
... 'my_kb',
327-
... model=server.models.emb_model,
327+
... embedding_model={'provider': 'openai', 'model': 'text-embedding-ada-002', 'api_key': 'sk-...'},
328+
... reranking_model={'provider': 'openai', 'model': 'gpt-4', 'api_key': 'sk-...'},
328329
... storage=server.databases.pvec.tables.tbl1,
329330
... metadata_columns=['date', 'author'],
330331
... content_columns=['review', 'description'],
@@ -333,7 +334,8 @@ def create(
333334
...)
334335
335336
:param name: name of the knowledge base
336-
:param model: embedding model, optional. Default: 'sentence_transformers' will be used (defined in mindsdb server)
337+
:param embedding_model: embedding model, optional. Default: OpenAI will be the default provider
338+
:param reranking_model: reranking model, optional. Default: OpenAI will be the default provider
337339
:param storage: vector storage, optional. Default: chromadb database will be created
338340
:param metadata_columns: columns to use as metadata, optional. Default: all columns which are not content and id
339341
:param content_columns: columns to use as content, optional. Default: all columns except id column
@@ -342,30 +344,24 @@ def create(
342344
:return: created KnowledgeBase object
343345
"""
344346

345-
params_out = {}
346-
347-
if metadata_columns is not None:
348-
params_out['metadata_columns'] = metadata_columns
349-
350-
if content_columns is not None:
351-
params_out['content_columns'] = content_columns
352-
353-
if id_column is not None:
354-
params_out['id_column'] = id_column
355-
356-
if params is not None:
357-
params_out.update(params)
358-
359-
if model is not None:
360-
model = model.name
361-
362347
payload = {
363348
'name': name,
364-
'model': model,
365-
'params': params_out
366349
}
367350

368-
if storage is not None:
351+
if embedding_model:
352+
payload['embedding_model'] = embedding_model
353+
if reranking_model:
354+
payload['reranking_model'] = reranking_model
355+
if metadata_columns:
356+
payload['metadata_columns'] = metadata_columns
357+
if content_columns:
358+
payload['content_columns'] = content_columns
359+
if id_column:
360+
payload['id_column'] = id_column
361+
if params:
362+
payload['params'] = params
363+
364+
if storage:
369365
payload['storage'] = {
370366
'database': storage.db.name,
371367
'table': storage.name

tests/test_sdk.py

Lines changed: 49 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -775,7 +775,7 @@ def check_project(self, project, database):
775775

776776
self.check_project_models_versions(project, database)
777777

778-
kb = self.check_project_kb(project, model, database)
778+
kb = self.check_project_kb(project, database)
779779

780780
self.check_project_jobs(project, model, database, kb)
781781

@@ -1143,12 +1143,21 @@ def check_project_jobs(self, project, model, database, kb, mock_post):
11431143
@patch('requests.Session.post')
11441144
@patch('requests.Session.delete')
11451145
@patch('requests.Session.get')
1146-
def check_project_kb(self, project, model, database, mock_get, mock_del, mock_post, mock_put):
1146+
def check_project_kb(self, project, database, mock_get, mock_del, mock_post, mock_put):
11471147

11481148
response_mock(mock_post, pd.DataFrame([{
11491149
'NAME': 'my_kb',
11501150
'PROJECT': 'mindsdb',
1151-
'MODEL': 'openai_emb',
1151+
'EMBEDDING_MODEL': {
1152+
'PROVIDER': 'openai',
1153+
'MODEL_NAME': 'openai_emb',
1154+
'API_KEY': 'sk-...'
1155+
},
1156+
'RERANKING_MODEL': {
1157+
'PROVIDER': 'openai',
1158+
'MODEL_NAME': 'openai_rerank',
1159+
'API_KEY': 'sk-...'
1160+
},
11521161
'STORAGE': 'pvec.tbl1',
11531162
'PARAMS': {"id_column": "num"},
11541163
}]))
@@ -1157,7 +1166,16 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
11571166
'id': 1,
11581167
'name': 'my_kb',
11591168
'project_id': 1,
1160-
'embedding_model': 'openai_emb',
1169+
'embedding_model': {
1170+
'provider': 'openai',
1171+
'model_name': 'openai_emb',
1172+
'api_key': 'sk-...'
1173+
},
1174+
'reranking_model': {
1175+
'provider': 'openai',
1176+
'model_name': 'openai_rerank',
1177+
'api_key': 'sk-...'
1178+
},
11611179
'vector_database': 'pvec',
11621180
'vector_database_table': 'tbl1',
11631181
'updated_at': '2024-10-04 10:55:25.350799',
@@ -1176,8 +1194,8 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
11761194

11771195
assert kb.name == 'my_kb'
11781196

1179-
assert isinstance(kb.model, Model)
1180-
assert kb.model.name == 'openai_emb'
1197+
assert kb.embedding_model['model_name'] == 'openai_emb'
1198+
assert kb.reranking_model['model_name'] == 'openai_rerank'
11811199

11821200
assert isinstance(kb.storage, Table)
11831201
assert kb.storage.name == 'tbl1'
@@ -1189,7 +1207,8 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
11891207
str(kb)
11901208
assert kb.name == 'my_kb'
11911209
assert kb.storage.db.name == 'pvec'
1192-
assert kb.model.name == 'openai_emb'
1210+
assert kb.embedding_model['model_name'] == 'openai_emb'
1211+
assert kb.reranking_model['model_name'] == 'openai_rerank'
11931212

11941213
# --- insert ---
11951214

@@ -1232,18 +1251,36 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
12321251
# create 1
12331252
project.knowledge_bases.create(
12341253
name='kb2',
1235-
model=model,
1254+
embedding_model={
1255+
'provider': 'openai',
1256+
'model_name': 'openai_emb',
1257+
'api_key': 'sk-...'
1258+
},
1259+
reranking_model={
1260+
'provider': 'openai',
1261+
'model_name': 'openai_rerank',
1262+
'api_key': 'sk-...'
1263+
},
12361264
metadata_columns=['date', 'author'],
12371265
params={'k': 'v'}
12381266
)
12391267
args, kwargs = mock_post.call_args
12401268
assert args[0] == f'{DEFAULT_CLOUD_API_URL}/api/projects/{project.name}/knowledge_bases'
12411269
assert kwargs == {'json': {'knowledge_base': {
12421270
'name': 'kb2',
1243-
'model': model.name,
1271+
'embedding_model': {
1272+
'provider': 'openai',
1273+
'model_name': 'openai_emb',
1274+
'api_key': 'sk-...'
1275+
},
1276+
'reranking_model': {
1277+
'provider': 'openai',
1278+
'model_name': 'openai_rerank',
1279+
'api_key': 'sk-...'
1280+
},
1281+
'metadata_columns': ['date', 'author'],
12441282
'params': {
12451283
'k': 'v',
1246-
'metadata_columns': ['date', 'author']
12471284
}
12481285
}}}
12491286

@@ -1259,11 +1296,8 @@ def check_project_kb(self, project, model, database, mock_get, mock_del, mock_po
12591296
assert args[0] == f'{DEFAULT_CLOUD_API_URL}/api/projects/{project.name}/knowledge_bases'
12601297
assert kwargs == {'json': {'knowledge_base': {
12611298
'name': 'kb2',
1262-
'model': None,
1263-
'params': {
1264-
'content_columns': ['review'],
1265-
'id_column': 'num'
1266-
},
1299+
'content_columns': ['review'],
1300+
'id_column': 'num',
12671301
'storage': {
12681302
'database': database.name,
12691303
'table': 'tbl1'

0 commit comments

Comments
 (0)