Skip to content

Commit 36f3d50

Browse files
committed
fix: support sparse vector index
1 parent 68fe586 commit 36f3d50

File tree

5 files changed

+35
-12
lines changed

5 files changed

+35
-12
lines changed

pyobvector/client/hybrid_search.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@ def __init__(
3030

3131
if self.ob_version < min_required_version:
3232
# For versions < 4.4.1.0, check if it's SeekDB
33-
with self.engine.connect() as conn:
34-
with conn.begin():
35-
res = conn.execute(text("SELECT version()"))
36-
version_str = [r[0] for r in res][0]
37-
if "SeekDB" in version_str:
38-
logger.info(f"SeekDB detected in version string: {version_str}, allowing hybrid search")
39-
return
33+
if self._is_seekdb():
34+
logger.info("SeekDB detected, allowing hybrid search")
35+
return
4036
raise ClusterVersionException(
4137
code=ErrorCode.NOT_SUPPORTED,
4238
message=ExceptionsMessage.ClusterVersionIsLow % ("Hybrid Search", "4.4.1.0"),

pyobvector/client/index_param.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,11 @@ def _parse_kwargs(self):
134134
if 'efSearch' in params:
135135
ob_params['ef_search'] = params['efSearch']
136136

137-
if self.is_index_type_sparse_vector() and ob_params['distance'] != 'inner_product':
138-
raise ValueError("Metric type should be 'inner_product' for sparse vector index.")
137+
if self.is_index_type_sparse_vector():
138+
if ob_params['distance'] != 'inner_product':
139+
raise ValueError("Metric type should be 'inner_product' for sparse vector index.")
140+
if 'sparse_index_type' in self.kwargs:
141+
ob_params['type'] = self.kwargs['sparse_index_type']
139142
return ob_params
140143

141144
def param_str(self):

pyobvector/client/ob_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,26 @@ def refresh_metadata(self, tables: Optional[list[str]] = None):
9393
self.metadata_obj.clear()
9494
self.metadata_obj.reflect(bind=self.engine, extend_existing=True)
9595

96+
def _is_seekdb(self) -> bool:
    """Report whether the connected server identifies itself as SeekDB.

    If ``self._is_seekdb_cached`` already exists, its value is returned
    without touching the database. Otherwise ``SELECT VERSION()`` is run
    through ``self.engine``, the first column of the first returned row is
    searched for the substring ``"SeekDB"``, and the outcome is stored in
    ``self._is_seekdb_cached`` for later calls.

    Exceptions raised during the lookup are logged as a warning instead of
    being propagated. If the lookup fails before the answer is stored, no
    cache entry is written, so the next call queries again.

    Returns:
        bool: True when the version string contains "SeekDB"; False
        otherwise, including when the version could not be determined.
    """
    is_seekdb = False
    try:
        # Reuse the answer stored by an earlier successful lookup.
        if hasattr(self, '_is_seekdb_cached'):
            return self._is_seekdb_cached
        with self.engine.connect() as connection:
            rows = connection.execute(text("SELECT VERSION()"))
            first_columns = [row[0] for row in rows]
            version_str = first_columns[0]
            is_seekdb = "SeekDB" in version_str
            self._is_seekdb_cached = is_seekdb
            logger.debug(f"Version query result: {version_str}, is_seekdb: {is_seekdb}")
    except Exception as e:
        # Best effort: a failed probe falls through to the value held in
        # ``is_seekdb`` rather than raising to the caller.
        logger.warning(f"Failed to query version: {e}")
    return is_seekdb
115+
96116
def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
97117
from_index = sql.find("FROM")
98118
assert from_index != -1

pyobvector/client/ob_vec_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,11 @@ def create_table_with_index_params(
9999
create_table_sql = str(CreateTable(table).compile(self.engine))
100100
new_sql = create_table_sql[:create_table_sql.rfind(')')]
101101
for sparse_vidx in sparse_vidxs:
102-
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
102+
sparse_params = sparse_vidx._parse_kwargs()
103+
if 'type' in sparse_params:
104+
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)"
105+
else:
106+
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
103107
new_sql += "\n)"
104108
conn.execute(text(new_sql))
105109
else:

pyobvector/client/ob_vec_json_table_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,11 +817,11 @@ def _handle_jtable_dml_select(
817817
):
818818
real_user_id = opt_user_id or self.user_id
819819

820-
table_name = ast.args['from'].this.this.this
820+
table_name = ast.args['from_'].this.this.this
821821
if not self._check_table_exists(table_name):
822822
raise ValueError(f"Table {table_name} does not exists")
823823

824-
ast.args['from'].args['this'].args['this'] = to_identifier(name=JSON_TABLE_DATA_TABLE_NAME, quoted=False)
824+
ast.args['from_'].args['this'].args['this'] = to_identifier(name=JSON_TABLE_DATA_TABLE_NAME, quoted=False)
825825

826826
col_meta = self.jmetadata.meta_cache[table_name]
827827
json_table_meta_str = []

0 commit comments

Comments
 (0)