diff --git a/pyobvector/client/hybrid_search.py b/pyobvector/client/hybrid_search.py index 572000f..83693d8 100644 --- a/pyobvector/client/hybrid_search.py +++ b/pyobvector/client/hybrid_search.py @@ -30,13 +30,9 @@ def __init__( if self.ob_version < min_required_version: # For versions < 4.4.1.0, check if it's SeekDB - with self.engine.connect() as conn: - with conn.begin(): - res = conn.execute(text("SELECT version()")) - version_str = [r[0] for r in res][0] - if "SeekDB" in version_str: - logger.info(f"SeekDB detected in version string: {version_str}, allowing hybrid search") - return + if self._is_seekdb(): + logger.info("SeekDB detected, allowing hybrid search") + return raise ClusterVersionException( code=ErrorCode.NOT_SUPPORTED, message=ExceptionsMessage.ClusterVersionIsLow % ("Hybrid Search", "4.4.1.0"), diff --git a/pyobvector/client/index_param.py b/pyobvector/client/index_param.py index e06badc..111a6e0 100644 --- a/pyobvector/client/index_param.py +++ b/pyobvector/client/index_param.py @@ -134,8 +134,11 @@ def _parse_kwargs(self): if 'efSearch' in params: ob_params['ef_search'] = params['efSearch'] - if self.is_index_type_sparse_vector() and ob_params['distance'] != 'inner_product': - raise ValueError("Metric type should be 'inner_product' for sparse vector index.") + if self.is_index_type_sparse_vector(): + if ob_params['distance'] != 'inner_product': + raise ValueError("Metric type should be 'inner_product' for sparse vector index.") + if 'sparse_index_type' in self.kwargs: + ob_params['type'] = self.kwargs['sparse_index_type'] return ob_params def param_str(self): diff --git a/pyobvector/client/ob_client.py b/pyobvector/client/ob_client.py index eb322cd..b40ade6 100644 --- a/pyobvector/client/ob_client.py +++ b/pyobvector/client/ob_client.py @@ -93,6 +93,26 @@ def refresh_metadata(self, tables: Optional[list[str]] = None): self.metadata_obj.clear() self.metadata_obj.reflect(bind=self.engine, extend_existing=True) + def _is_seekdb(self) -> bool: + """Check if the database is SeekDB by querying version. + + Returns: + bool: True if database is SeekDB, False otherwise + """ + is_seekdb = False + try: + if hasattr(self, '_is_seekdb_cached'): + return self._is_seekdb_cached + with self.engine.connect() as conn: + result = conn.execute(text("SELECT VERSION()")) + version_str = [r[0] for r in result][0] + is_seekdb = "SeekDB" in version_str + self._is_seekdb_cached = is_seekdb + logger.debug(f"Version query result: {version_str}, is_seekdb: {is_seekdb}") + except Exception as e: + logger.warning(f"Failed to query version: {e}") + return is_seekdb + def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str): from_index = sql.find("FROM") assert from_index != -1 diff --git a/pyobvector/client/ob_vec_client.py b/pyobvector/client/ob_vec_client.py index 2796bd8..a95a93e 100644 --- a/pyobvector/client/ob_vec_client.py +++ b/pyobvector/client/ob_vec_client.py @@ -99,7 +99,11 @@ def create_table_with_index_params( create_table_sql = str(CreateTable(table).compile(self.engine)) new_sql = create_table_sql[:create_table_sql.rfind(')')] for sparse_vidx in sparse_vidxs: - new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)" + sparse_params = sparse_vidx._parse_kwargs() + if 'type' in sparse_params: + new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)" + else: + new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)" new_sql += "\n)" conn.execute(text(new_sql)) else: diff --git a/pyobvector/client/ob_vec_json_table_client.py b/pyobvector/client/ob_vec_json_table_client.py index 0f30c62..d5d9675 100644 --- a/pyobvector/client/ob_vec_json_table_client.py +++ b/pyobvector/client/ob_vec_json_table_client.py @@ -817,11 +817,11 @@ def _handle_jtable_dml_select( ): real_user_id = opt_user_id or self.user_id - table_name = ast.args['from'].this.this.this + table_name = ast.args['from_'].this.this.this if not self._check_table_exists(table_name): raise ValueError(f"Table {table_name} does not exists") - ast.args['from'].args['this'].args['this'] = to_identifier(name=JSON_TABLE_DATA_TABLE_NAME, quoted=False) + ast.args['from_'].args['this'].args['this'] = to_identifier(name=JSON_TABLE_DATA_TABLE_NAME, quoted=False) col_meta = self.jmetadata.meta_cache[table_name] json_table_meta_str = []