Skip to content

Commit fde9a3b

Browse files
committed
fix: support sparse vector index
1 parent 68fe586 commit fde9a3b

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

pyobvector/client/index_param.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,11 @@ def _parse_kwargs(self):
134134
if 'efSearch' in params:
135135
ob_params['ef_search'] = params['efSearch']
136136

137-
if self.is_index_type_sparse_vector() and ob_params['distance'] != 'inner_product':
138-
raise ValueError("Metric type should be 'inner_product' for sparse vector index.")
137+
if self.is_index_type_sparse_vector():
138+
if ob_params['distance'] != 'inner_product':
139+
raise ValueError("Metric type should be 'inner_product' for sparse vector index.")
140+
if 'sparse_index_type' in self.kwargs:
141+
ob_params['type'] = self.kwargs['sparse_index_type']
139142
return ob_params
140143

141144
def param_str(self):

pyobvector/client/ob_vec_client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,26 @@ def __init__(
4747
message=ExceptionsMessage.ClusterVersionIsLow % ("Vector Store", "4.3.3.0"),
4848
)
4949

50+
def _is_seekdb(self) -> bool:
51+
"""Check if the database is SeekDB by querying version.
52+
53+
Returns:
54+
bool: True if database is SeekDB, False otherwise
55+
"""
56+
is_seekdb = False
57+
try:
58+
if hasattr(self, '_is_seekdb_cached'):
59+
return self._is_seekdb_cached
60+
with self.engine.connect() as conn:
61+
result = conn.execute(text("SELECT VERSION()"))
62+
version_str = [r[0] for r in result][0]
63+
is_seekdb = "SeekDB" in version_str
64+
self._is_seekdb_cached = is_seekdb
65+
logger.debug(f"Version query result: {version_str}, is_seekdb: {is_seekdb}")
66+
except Exception as e:
67+
logger.warning(f"Failed to query version: {e}")
68+
return is_seekdb
69+
5070
def _get_sparse_vector_index_params(
5171
self, vidxs: Optional[IndexParams]
5272
):
@@ -99,7 +119,11 @@ def create_table_with_index_params(
99119
create_table_sql = str(CreateTable(table).compile(self.engine))
100120
new_sql = create_table_sql[:create_table_sql.rfind(')')]
101121
for sparse_vidx in sparse_vidxs:
102-
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
122+
sparse_params = sparse_vidx._parse_kwargs()
123+
if 'type' in sparse_params:
124+
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)"
125+
else:
126+
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
103127
new_sql += "\n)"
104128
conn.execute(text(new_sql))
105129
else:

0 commit comments

Comments
 (0)