Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions pyobvector/client/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,9 @@ def __init__(

if self.ob_version < min_required_version:
# For versions < 4.4.1.0, check if it's SeekDB
with self.engine.connect() as conn:
with conn.begin():
res = conn.execute(text("SELECT version()"))
version_str = [r[0] for r in res][0]
if "SeekDB" in version_str:
logger.info(f"SeekDB detected in version string: {version_str}, allowing hybrid search")
return
if self._is_seekdb():
logger.info("SeekDB detected, allowing hybrid search")
return
raise ClusterVersionException(
code=ErrorCode.NOT_SUPPORTED,
message=ExceptionsMessage.ClusterVersionIsLow % ("Hybrid Search", "4.4.1.0"),
Expand Down
7 changes: 5 additions & 2 deletions pyobvector/client/index_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,11 @@ def _parse_kwargs(self):
if 'efSearch' in params:
ob_params['ef_search'] = params['efSearch']

if self.is_index_type_sparse_vector() and ob_params['distance'] != 'inner_product':
raise ValueError("Metric type should be 'inner_product' for sparse vector index.")
if self.is_index_type_sparse_vector():
if ob_params['distance'] != 'inner_product':
raise ValueError("Metric type should be 'inner_product' for sparse vector index.")
if 'sparse_index_type' in self.kwargs:
ob_params['type'] = self.kwargs['sparse_index_type']
return ob_params

def param_str(self):
Expand Down
20 changes: 20 additions & 0 deletions pyobvector/client/ob_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,26 @@ def refresh_metadata(self, tables: Optional[list[str]] = None):
self.metadata_obj.clear()
self.metadata_obj.reflect(bind=self.engine, extend_existing=True)

def _is_seekdb(self) -> bool:
    """Report whether the connected database identifies itself as SeekDB.

    The decision is based on the string returned by ``SELECT VERSION()``.
    After the first successful query the answer is stored on the instance
    as ``_is_seekdb_cached`` and returned directly on later calls. If the
    query (or the handling of its result) raises, a warning is logged and
    the value computed so far is returned; a failure that happens before
    the cache attribute is assigned leaves nothing cached, so the next
    call queries again.

    Returns:
        bool: True if the version string contains "SeekDB", False otherwise
    """
    detected = False
    try:
        # Fast path: reuse the answer stored by an earlier successful call.
        if hasattr(self, '_is_seekdb_cached'):
            return self._is_seekdb_cached

        with self.engine.connect() as connection:
            rows = connection.execute(text("SELECT VERSION()"))
            first_column = [row[0] for row in rows]
            version_str = first_column[0]
            detected = "SeekDB" in version_str
            self._is_seekdb_cached = detected
            logger.debug(f"Version query result: {version_str}, is_seekdb: {detected}")
    except Exception as e:
        # Best-effort detection: a failure here must not abort the caller.
        logger.warning(f"Failed to query version: {e}")
    return detected

def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str):
from_index = sql.find("FROM")
assert from_index != -1
Expand Down
6 changes: 5 additions & 1 deletion pyobvector/client/ob_vec_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def create_table_with_index_params(
create_table_sql = str(CreateTable(table).compile(self.engine))
new_sql = create_table_sql[:create_table_sql.rfind(')')]
for sparse_vidx in sparse_vidxs:
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
sparse_params = sparse_vidx._parse_kwargs()
if 'type' in sparse_params:
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)"
else:
new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)"
new_sql += "\n)"
conn.execute(text(new_sql))
else:
Expand Down
5 changes: 3 additions & 2 deletions pyobvector/client/ob_vec_json_table_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,11 +817,12 @@ def _handle_jtable_dml_select(
):
real_user_id = opt_user_id or self.user_id

table_name = ast.args['from'].this.this.this
from_key = 'from_' if 'from_' in ast.args else 'from'
table_name = ast.args[from_key].this.this.this
if not self._check_table_exists(table_name):
raise ValueError(f"Table {table_name} does not exists")

ast.args['from'].args['this'].args['this'] = to_identifier(name=JSON_TABLE_DATA_TABLE_NAME, quoted=False)
ast.args[from_key].args['this'].args['this'] = to_identifier(name=JSON_TABLE_DATA_TABLE_NAME, quoted=False)

col_meta = self.jmetadata.meta_cache[table_name]
json_table_meta_str = []
Expand Down