From 67bf47058c6248dcbb3fee76662c4d64a88fdfac Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 18:50:24 +0800 Subject: [PATCH 1/5] Support ADD_TIFLASH_ON_DEMAND option Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 92 ++++++++++++++++++++++++++++- tidb_vector/sqlalchemy/__init__.py | 3 +- tidb_vector/sqlalchemy/index.py | 29 +++++++++ 3 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 tidb_vector/sqlalchemy/index.py diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index b59c87b..4cffaf4 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -1,9 +1,10 @@ import pytest import numpy as np +import sqlalchemy from sqlalchemy import URL, create_engine, Column, Integer, select from sqlalchemy.orm import declarative_base, sessionmaker from sqlalchemy.exc import OperationalError -from tidb_vector.sqlalchemy import VectorType, VectorAdaptor +from tidb_vector.sqlalchemy import VectorType, VectorAdaptor, VectorIndex import tidb_vector from ..config import TestConfig @@ -385,3 +386,92 @@ def test_index_and_search(self): ) assert len(items) == 2 assert items[0].distance == 0.0 + + +class TestSQLAlchemyVectorIndex: + + def setup_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + Item2Model.__table__.create(bind=engine) + + def teardown_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + + def test_create_vector_index_statement(self): + from sqlalchemy.sql.ddl import CreateIndex + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + ) + compiled = CreateIndex(l2_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + ) + compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + + def test_query_with_index(self): + # indexes + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + ) + l2_index.create(engine) + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + ) + cos_index.create(engine) + + self.check_indexes( + Item2Model.__table__, ["idx_embedding_l2", "idx_embedding_cos"] + ) + + with Session() as session: + session.add_all( + [Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])] + ) + session.commit() + + # l2 distance + result_l2 = session.scalars( + select(Item2Model).filter( + Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_l2) == 2 + + distance_l2 = Item2Model.embedding.l2_distance([1, 2, 3]) + items_l2 = ( + session.query(Item2Model.id, distance_l2.label("distance")) + .order_by(distance_l2) + .limit(5) + .all() + ) + assert len(items_l2) == 2 + assert items_l2[0].distance == 0.0 + + # cosine distance + result_cos = session.scalars( + select(Item2Model).filter( + Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_cos) == 2 + + distance_cos = Item2Model.embedding.cosine_distance([1, 2, 3]) + items_cos = ( + session.query(Item2Model.id, distance_cos.label("distance")) + .order_by(distance_cos) + .limit(5) + .all() + ) + assert len(items_cos) == 2 + assert items_cos[0].distance == 0.0 + + # drop indexes + l2_index.drop(engine) + cos_index.drop(engine) diff --git a/tidb_vector/sqlalchemy/__init__.py b/tidb_vector/sqlalchemy/__init__.py index 17579f2..eda69ce 100644 --- a/tidb_vector/sqlalchemy/__init__.py +++ b/tidb_vector/sqlalchemy/__init__.py @@ -1,4 +1,5 @@ from .vector_type import VectorType from .adaptor import VectorAdaptor +from .index import VectorIndex -__all__ = ["VectorType", "VectorAdaptor"] +__all__ = ["VectorType", "VectorAdaptor", "VectorIndex"] diff --git a/tidb_vector/sqlalchemy/index.py b/tidb_vector/sqlalchemy/index.py new file mode 100644 index 0000000..67ec14d --- /dev/null +++ b/tidb_vector/sqlalchemy/index.py @@ -0,0 +1,29 @@ +from typing import Optional, Any + +import sqlalchemy + +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import Index + +class VectorIndex(Index): + def __init__( + self, + name: Optional[str], + *expressions, # _DDLColumnArgument + _table: Optional[Any] = None, + **dialect_kw: Any, + ): + super().__init__(name, *expressions, unique=False, _table=_table, **dialect_kw) + self.dialect_options["mysql"]["prefix"] = "VECTOR" + # add tiflash automatically when creating vector index + self.dialect_options["mysql"]["add_tiflash_on_demand"] = True + +# VectorIndex.argument_for("mysql", "add_tiflash_on_demand", None) + +@compiles(sqlalchemy.schema.CreateIndex) +def compile_create_vector_index(create_index_elem: sqlalchemy.sql.ddl.CreateIndex, compiler: sqlalchemy.sql.compiler.DDLCompiler, **kw): + text = compiler.visit_create_index(create_index_elem, **kw) + index_elem = create_index_elem.element + if index_elem.dialect_options.get("mysql", {}).get("add_tiflash_on_demand"): + text += " ADD_TIFLASH_ON_DEMAND" + return text From 521837f86b924c8e1209e8b4f4a905265be1bcdf Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 18:59:04 +0800 Subject: [PATCH 2/5] normal index Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 4cffaf4..393179f 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -413,6 +413,11 @@ def test_create_vector_index_statement(self): compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + # non-vector index + normal_index = sqlalchemy.schema.Index("idx_unique", Item2Model.__table__.c.id, unique=True) + compiled = CreateIndex(normal_index).compile(dialect=engine.dialect) + assert compiled.string == "CREATE UNIQUE INDEX idx_unique ON sqlalchemy_item2 (id)" + def test_query_with_index(self): # indexes l2_index = VectorIndex( From a6157ad67b25def4a3877ab530b9cdc286512367 Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 19:53:48 +0800 Subject: [PATCH 3/5] create table with tiflash replica attribute Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 55 +++++++++++++++++++++++++++-- tidb_vector/sqlalchemy/index.py | 16 ++++----- 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 393179f..047ed33 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -1,3 +1,4 @@ +from locale import normalize import pytest import numpy as np import sqlalchemy @@ -39,6 +40,9 @@ class Item2Model(Base): id = Column(Integer, primary_key=True) embedding = Column(VectorType(dim=3)) + # __table_args__ = { + # "mysql_tiflash_replica": "1", + # } class TestSQLAlchemy: def setup_class(self): @@ -397,6 +401,51 @@ def setup_class(self): def teardown_class(self): Item2Model.__table__.drop(bind=engine, checkfirst=True) + def test_create_table_statement(self): + # Define a table using `sqlalchemy.schema.Table` + tbl = sqlalchemy.schema.Table( + 'mytable', + Base.metadata, + Column('id', Integer), + mysql_tiflash_replica='1', + ) + compiled = CreateTable(tbl).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE mytable (id INTEGER)" + + # Define a table with tiflash replica using `sqlalchemy.schema.Table` + tbl = sqlalchemy.schema.Table( + 'mytable', + Base.metadata, + Column('id', Integer), + mysql_tiflash_replica='1', + ) + from sqlalchemy.sql.ddl import CreateTable + compiled = CreateTable(tbl).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE mytable (id INTEGER)TIFLASH_REPLICA=1" + + # Define a table inheriting from `Base` + class Item3Model(Base): + __tablename__ = "sqlalchemy_item3" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + compiled = CreateTable(Item3Model.__table__).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE sqlalchemy_item3 (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))" + + # Define a table inheriting from `Base` with tiflash replica using `__table_args__` + class Item3Model(Base): + __tablename__ = "sqlalchemy_item3" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + __table_args__ = { + "mysql_tiflash_replica": "1", + } + compiled = CreateTable(Item3Model.__table__).compile(dialect=engine.dialect) + normalized = compiled.string.replace("\n", "").replace("\t", "").strip() + assert normalized == "CREATE TABLE sqlalchemy_item3 (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))TIFLASH_REPLICA=1" + def test_create_vector_index_statement(self): from sqlalchemy.sql.ddl import CreateIndex l2_index = VectorIndex( @@ -404,20 +453,20 @@ def test_create_vector_index_statement(self): sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), ) compiled = CreateIndex(l2_index).compile(dialect=engine.dialect) - assert compiled.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding)))" cos_index = VectorIndex( "idx_embedding_cos", sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), ) compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) - assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding))) ADD_TIFLASH_ON_DEMAND" + assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding)))" # non-vector index normal_index = sqlalchemy.schema.Index("idx_unique", Item2Model.__table__.c.id, unique=True) compiled = CreateIndex(normal_index).compile(dialect=engine.dialect) assert compiled.string == "CREATE UNIQUE INDEX idx_unique ON sqlalchemy_item2 (id)" - + def test_query_with_index(self): # indexes l2_index = VectorIndex( diff --git a/tidb_vector/sqlalchemy/index.py b/tidb_vector/sqlalchemy/index.py index 67ec14d..2a9cbd3 100644 --- a/tidb_vector/sqlalchemy/index.py +++ b/tidb_vector/sqlalchemy/index.py @@ -15,15 +15,15 @@ def __init__( ): super().__init__(name, *expressions, unique=False, _table=_table, **dialect_kw) self.dialect_options["mysql"]["prefix"] = "VECTOR" - # add tiflash automatically when creating vector index - self.dialect_options["mysql"]["add_tiflash_on_demand"] = True # VectorIndex.argument_for("mysql", "add_tiflash_on_demand", None) -@compiles(sqlalchemy.schema.CreateIndex) -def compile_create_vector_index(create_index_elem: sqlalchemy.sql.ddl.CreateIndex, compiler: sqlalchemy.sql.compiler.DDLCompiler, **kw): - text = compiler.visit_create_index(create_index_elem, **kw) - index_elem = create_index_elem.element - if index_elem.dialect_options.get("mysql", {}).get("add_tiflash_on_demand"): - text += " ADD_TIFLASH_ON_DEMAND" +# Table.argument_for("mysql", "tiflash", None) + +@compiles(sqlalchemy.schema.CreateTable) +def compile_create_table(create_table_elem: sqlalchemy.sql.ddl.CreateTable, compiler: sqlalchemy.sql.compiler.DDLCompiler, **kw): + text = compiler.visit_create_table(create_table_elem, **kw) + # table_elem = create_table_elem.element + # if table_elem.dialect_options.get("mysql", {}).get("tiflash_replica"): + # text += " TIFLASH_REPLICA = 1" return text From d114f1218366f698225eef02b96b764328855d8c Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 19:54:11 +0800 Subject: [PATCH 4/5] Remove useless code Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 3 --- tidb_vector/sqlalchemy/index.py | 12 ------------ 2 files changed, 15 deletions(-) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 047ed33..c97abad 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -40,9 +40,6 @@ class Item2Model(Base): id = Column(Integer, primary_key=True) embedding = Column(VectorType(dim=3)) - # __table_args__ = { - # "mysql_tiflash_replica": "1", - # } class TestSQLAlchemy: def setup_class(self): diff --git a/tidb_vector/sqlalchemy/index.py b/tidb_vector/sqlalchemy/index.py index 2a9cbd3..24ef082 100644 --- a/tidb_vector/sqlalchemy/index.py +++ b/tidb_vector/sqlalchemy/index.py @@ -15,15 +15,3 @@ def __init__( ): super().__init__(name, *expressions, unique=False, _table=_table, **dialect_kw) self.dialect_options["mysql"]["prefix"] = "VECTOR" - -# VectorIndex.argument_for("mysql", "add_tiflash_on_demand", None) - -# Table.argument_for("mysql", "tiflash", None) - -@compiles(sqlalchemy.schema.CreateTable) -def compile_create_table(create_table_elem: sqlalchemy.sql.ddl.CreateTable, compiler: sqlalchemy.sql.compiler.DDLCompiler, **kw): - text = compiler.visit_create_table(create_table_elem, **kw) - # table_elem = create_table_elem.element - # if table_elem.dialect_options.get("mysql", {}).get("tiflash_replica"): - # text += " TIFLASH_REPLICA = 1" - return text From 7f4a7d29c2cb4f8c0108115c70b26b9ea211f5d9 Mon Sep 17 00:00:00 2001 From: JaySon-Huang Date: Thu, 7 Nov 2024 20:00:01 +0800 Subject: [PATCH 5/5] update test Signed-off-by: JaySon-Huang --- tests/sqlalchemy/test_sqlalchemy.py | 58 ++++++++++++++++------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index c97abad..3b5e9c8 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -389,6 +389,16 @@ def test_index_and_search(self): assert items[0].distance == 0.0 +class Item3Model(Base): + __tablename__ = "sqlalchemy_item3" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + + __table_args__ = { + "mysql_tiflash_replica": "1", + } + + class TestSQLAlchemyVectorIndex: def setup_class(self): @@ -423,44 +433,44 @@ def test_create_table_statement(self): assert normalized == "CREATE TABLE mytable (id INTEGER)TIFLASH_REPLICA=1" # Define a table inheriting from `Base` - class Item3Model(Base): - __tablename__ = "sqlalchemy_item3" + class TableModel(Base): + __tablename__ = "test_tbl" id = Column(Integer, primary_key=True) embedding = Column(VectorType(dim=3)) - compiled = CreateTable(Item3Model.__table__).compile(dialect=engine.dialect) + compiled = CreateTable(TableModel.__table__).compile(dialect=engine.dialect) normalized = compiled.string.replace("\n", "").replace("\t", "").strip() - assert normalized == "CREATE TABLE sqlalchemy_item3 (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))" + assert normalized == "CREATE TABLE test_tbl (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))" # Define a table inheriting from `Base` with tiflash replica using `__table_args__` - class Item3Model(Base): - __tablename__ = "sqlalchemy_item3" + class TableModel(Base): + __tablename__ = "test_tbl" id = Column(Integer, primary_key=True) embedding = Column(VectorType(dim=3)) __table_args__ = { "mysql_tiflash_replica": "1", } - compiled = CreateTable(Item3Model.__table__).compile(dialect=engine.dialect) + compiled = CreateTable(TableModel.__table__).compile(dialect=engine.dialect) normalized = compiled.string.replace("\n", "").replace("\t", "").strip() - assert normalized == "CREATE TABLE sqlalchemy_item3 (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))TIFLASH_REPLICA=1" + assert normalized == "CREATE TABLE test_tbl (id INTEGER NOT NULL AUTO_INCREMENT, embedding VECTOR(3), PRIMARY KEY (id))TIFLASH_REPLICA=1" def test_create_vector_index_statement(self): from sqlalchemy.sql.ddl import CreateIndex l2_index = VectorIndex( "idx_embedding_l2", - sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + sqlalchemy.func.vec_l2_distance(Item3Model.__table__.c.embedding), ) compiled = CreateIndex(l2_index).compile(dialect=engine.dialect) assert compiled.string == "CREATE VECTOR INDEX idx_embedding_l2 ON sqlalchemy_item2 ((vec_l2_distance(embedding)))" cos_index = VectorIndex( "idx_embedding_cos", - sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + sqlalchemy.func.vec_cosine_distance(Item3Model.__table__.c.embedding), ) compiled = CreateIndex(cos_index).compile(dialect=engine.dialect) assert compiled.string == "CREATE VECTOR INDEX idx_embedding_cos ON sqlalchemy_item2 ((vec_cosine_distance(embedding)))" # non-vector index - normal_index = sqlalchemy.schema.Index("idx_unique", Item2Model.__table__.c.id, unique=True) + normal_index = sqlalchemy.schema.Index("idx_unique", Item3Model.__table__.c.id, unique=True) compiled = CreateIndex(normal_index).compile(dialect=engine.dialect) assert compiled.string == "CREATE UNIQUE INDEX idx_unique ON sqlalchemy_item2 (id)" @@ -468,36 +478,32 @@ def test_query_with_index(self): # indexes l2_index = VectorIndex( "idx_embedding_l2", - sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + sqlalchemy.func.vec_l2_distance(Item3Model.__table__.c.embedding), ) l2_index.create(engine) cos_index = VectorIndex( "idx_embedding_cos", - sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + sqlalchemy.func.vec_cosine_distance(Item3Model.__table__.c.embedding), ) cos_index.create(engine) - self.check_indexes( - Item2Model.__table__, ["idx_embedding_l2", "idx_embedding_cos"] - ) - with Session() as session: session.add_all( - [Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])] + [Item3Model(embedding=[1, 2, 3]), Item3Model(embedding=[1, 2, 3.2])] ) session.commit() # l2 distance result_l2 = session.scalars( - select(Item2Model).filter( - Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 + select(Item3Model).filter( + Item3Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 ) ).all() assert len(result_l2) == 2 - distance_l2 = Item2Model.embedding.l2_distance([1, 2, 3]) + distance_l2 = Item3Model.embedding.l2_distance([1, 2, 3]) items_l2 = ( - session.query(Item2Model.id, distance_l2.label("distance")) + session.query(Item3Model.id, distance_l2.label("distance")) .order_by(distance_l2) .limit(5) .all() @@ -507,15 +513,15 @@ def test_query_with_index(self): # cosine distance result_cos = session.scalars( - select(Item2Model).filter( - Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 + select(Item3Model).filter( + Item3Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 ) ).all() assert len(result_cos) == 2 - distance_cos = Item2Model.embedding.cosine_distance([1, 2, 3]) + distance_cos = Item3Model.embedding.cosine_distance([1, 2, 3]) items_cos = ( - session.query(Item2Model.id, distance_cos.label("distance")) + session.query(Item3Model.id, distance_cos.label("distance")) .order_by(distance_cos) .limit(5) .all()