diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index b59c87b..3b5e9c8 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -1,9 +1,11 @@ +from locale import normalize import pytest import numpy as np +import sqlalchemy from sqlalchemy import URL, create_engine, Column, Integer, select from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.exc import OperationalError -from tidb_vector.sqlalchemy import VectorType, VectorAdaptor +from tidb_vector.sqlalchemy import VectorType, VectorAdaptor, VectorIndex import tidb_vector from ..config import TestConfig @@ -385,3 +387,148 @@ def test_index_and_search(self): ) assert len(items) == 2 assert items[0].distance == 0.0 + + +class Item3Model(Base): + __tablename__ = "sqlalchemy_item3" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + + __table_args__ = { + "mysql_tiflash_replica": "1", + } + + +class TestSQLAlchemyVectorIndex: + + def setup_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + Item2Model.__table__.create(bind=engine) + + def teardown_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + + def test_create_table_statement(self): + # Define a table using `sqlalchemy.schema.Table` + tbl = sqlalchemy.schema.Table( + 'mytable', + Base.metadata, + Column('id', Integer), + mysql_tiflash_replica='1', + ) + compiled = CreateTable(tbl).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE mytable (id INTEGER)" + + # Define a table with tiflash replica using `sqlalchemy.schema.Table` + tbl = sqlalchemy.schema.Table( + 'mytable', + Base.metadata, + Column('id', Integer), + mysql_tiflash_replica='1', + ) + from sqlalchemy.sql.ddl import CreateTable + compiled = CreateTable(tbl).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE mytable (id INTEGER)TIFLASH_REPLICA=1" + + # Define a table inheriting from `Base` + class TableModel(Base): + __tablename__ = "test_tbl" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + compiled = CreateTable(TableModel.__table__).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE test_tbl (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))" + + # Define a table inheriting from `Base` with tiflash replica using `__table_args__` + class TableModel(Base): + __tablename__ = "test_tbl" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + __table_args__ = { + "mysql_tiflash_replica": "1", + } + compiled = CreateTable(TableModel.__table__).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE test_tbl (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))TIFLASH_REPLICA=1" + + def test_create_vector_index_statement(self): + from sqlalchemy.sql.ddl import CreateIndex + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item3Model.__table__.c.embedding), + ) + compiled = CreateIndex(l2_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding)))" + + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item3Model.__table__.c.embedding), + ) + compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding)))" + + # non-vector index + normal_index = sqlalchemy.schema.Index("idx_unique", Item3Model.__table__.c.id, unique=True) + compiled = CreateIndex(normal_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE UNIQUE INDEX idx_unique ON sqlalchemy_item2 (id)" + + def test_query_with_index(self): + # indexes + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item3Model.__table__.c.embedding), + ) + l2_index.create(engine) + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item3Model.__table__.c.embedding), + ) + cos_index.create(engine) + + with Session() as session: + session.add_all( + [Item3Model(embedding=[1, 2, 3]), Item3Model(embedding=[1, 2, 3.2])] + ) + session.commit() + + # l2 distance + result_l2 = session.scalars( + select(Item3Model).filter( + Item3Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_l2) == 2 + + distance_l2 = Item3Model.embedding.l2_distance([1, 2, 3]) + items_l2 = ( + session.query(Item3Model.id, distance_l2.label("distance")) + .order_by(distance_l2) + .limit(5) + .all() + ) + assert len(items_l2) == 2 + assert items_l2[0].distance == 0.0 + + # cosine distance + result_cos = session.scalars( + select(Item3Model).filter( + Item3Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_cos) == 2 + + distance_cos = Item3Model.embedding.cosine_distance([1, 2, 3]) + items_cos = ( + session.query(Item3Model.id, distance_cos.label("distance")) + .order_by(distance_cos) + .limit(5) + .all() + ) + assert len(items_cos) == 2 + assert items_cos[0].distance == 0.0 + + # drop indexes + l2_index.drop(engine) + cos_index.drop(engine) diff --git a/tidb_vector/sqlalchemy/__init__.py b/tidb_vector/sqlalchemy/__init__.py index 17579f2..eda69ce 100644 --- a/tidb_vector/sqlalchemy/__init__.py +++ b/tidb_vector/sqlalchemy/__init__.py @@ -1,4 +1,5 @@ from .vector_type import VectorType from .adaptor import VectorAdaptor +from .index import VectorIndex -__all__ = ["VectorType", "VectorAdaptor"] +__all__ = ["VectorType", "VectorAdaptor", "VectorIndex"] diff --git a/tidb_vector/sqlalchemy/index.py b/tidb_vector/sqlalchemy/index.py new file mode 100644 index 0000000..24ef082 --- /dev/null +++ b/tidb_vector/sqlalchemy/index.py @@ -0,0 +1,17 @@ +from typing import Optional, Any + +import sqlalchemy + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import Index + +class VectorIndex(Index): + def __init__( + self, + name: Optional[str], + *expressions, # _DDLColumnArgument + _table: Optional[Any] = None, + **dialect_kw: Any, + ): + super().__init__(name, *expressions, unique=False, _table=_table, **dialect_kw) + self.dialect_options["mysql"]["prefix"] = "VECTOR"