From 1dac29baba623f19de901da9c122d4fbb2aeb8d8 Mon Sep 17 00:00:00 2001
From: JaySon-Huang
Date: Thu, 7 Nov 2024 21:20:39 +0800
Subject: [PATCH] Add example

Signed-off-by: JaySon-Huang
---
 django_tidb/fields/vector.py       |  9 +++---
 django_tidb/patch.py               |  9 +++---
 django_tidb/schema.py              |  5 ++-
 tests/tidb/test_tiflash_replica.py | 50 ++++++++++++++++++++++++++++++
 tests/tidb_vector/models.py        |  1 +
 5 files changed, 64 insertions(+), 10 deletions(-)
 create mode 100644 tests/tidb/test_tiflash_replica.py

diff --git a/django_tidb/fields/vector.py b/django_tidb/fields/vector.py
index 260b2e3..2b19190 100644
--- a/django_tidb/fields/vector.py
+++ b/django_tidb/fields/vector.py
@@ -143,7 +143,9 @@ class VectorIndex(Index):
     class Document(models.Model):
         content = models.TextField()
         embedding = VectorField(dimensions=3)
+
         class Meta:
+            tiflash_replica = 1  # When defining a vector index, the tiflash_replica must be non-zero
             indexes = [
                 VectorIndex(CosineDistance("embedding"), name='idx_cos'),
             ]
@@ -189,10 +191,7 @@ def create_sql(self, model, schema_editor, using="", **kwargs):
         )
         fields = None
         col_suffixes = None
-        # TODO: remove the tiflash replica setting statement from sql_template
-        # after we support `ADD_TIFLASH_ON_DEMAND` in the `CREATE VECTOR INDEX ...`
-        sql_template = """ALTER TABLE %(table)s SET TIFLASH REPLICA 1;
-        CREATE VECTOR INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s"""
+        sql_template = """CREATE VECTOR INDEX %(name)s ON %(table)s%(using)s (%(columns)s)%(extra)s"""
         return schema_editor._create_index_sql(
             model,
             fields=fields,
@@ -214,7 +213,7 @@ class DistanceBase(Func):
 
     def __init__(self, expression, vector=None, **extra):
         """
-        expression: the name of a field, or an expression returing a vector
+        expression: the name of a field, or an expression returning a vector
         vector: a vector to compare against
         """
         expressions = [expression]
diff --git a/django_tidb/patch.py b/django_tidb/patch.py
index 3712447..e66b250 100644
--- a/django_tidb/patch.py
+++ b/django_tidb/patch.py
@@ -32,14 +32,15 @@ def patch_model_functions():
 
 
 def patch_model_options():
-    # Patch `tidb_auto_id_cache` to options.DEFAULT_NAMES,
+    extra_meta_options = ("tidb_auto_id_cache", "tiflash_replica",)
+    # Patch `extra_meta_options` to options.DEFAULT_NAMES,
     # so that user can define it in model's Meta class.
-    options.DEFAULT_NAMES += ("tidb_auto_id_cache",)
+    options.DEFAULT_NAMES += extra_meta_options
     # Because Django named import DEFAULT_NAMES in migrations,
     # so we need to patch it again here.
-    # Django will record `tidb_auto_id_cache` in migration files,
+    # Django will record `extra_meta_options` in migration files,
     # and then restore it when applying migrations.
-    state.DEFAULT_NAMES += ("tidb_auto_id_cache",)
+    state.DEFAULT_NAMES += extra_meta_options
 
 
 def monkey_patch():
diff --git a/django_tidb/schema.py b/django_tidb/schema.py
index e018f7c..f0ca0cc 100644
--- a/django_tidb/schema.py
+++ b/django_tidb/schema.py
@@ -65,5 +65,8 @@ def table_sql(self, model):
         sql, params = super().table_sql(model)
         tidb_auto_id_cache = getattr(model._meta, "tidb_auto_id_cache", None)
         if tidb_auto_id_cache is not None:
-            sql += " AUTO_ID_CACHE %s" % tidb_auto_id_cache
+            sql += f" AUTO_ID_CACHE {tidb_auto_id_cache}"
+        tiflash_replica = getattr(model._meta, "tiflash_replica", None)
+        if tiflash_replica is not None:
+            sql += f" TIFLASH_REPLICA={tiflash_replica}"
         return sql, params
diff --git a/tests/tidb/test_tiflash_replica.py b/tests/tidb/test_tiflash_replica.py
new file mode 100644
index 0000000..518cc1f
--- /dev/null
+++ b/tests/tidb/test_tiflash_replica.py
@@ -0,0 +1,50 @@
+import re
+
+from django.db import models, connection
+from django.test import TransactionTestCase
+from django.test.utils import isolate_apps
+
+TIFLASH_REPLICA_PATTERN = re.compile(r"\/\*T!\[tiflash_replica\] TIFLASH_REPLICA=(\d+) \*\/")
+
+class TiDBTiFlashReplicaTests(TransactionTestCase):
+    available_apps = ["tidb"]
+
+    def get_tiflash_replica_info(self, table):
+        with connection.cursor() as cursor:
+            cursor.execute(
+                f"SHOW CREATE TABLE {table}",
+            )
+            row = cursor.fetchone()
+            if row is None:
+                return None
+            match = TIFLASH_REPLICA_PATTERN.search(row[1])
+            if match:
+                return match.groups()[0]
+            return None
+
+    @isolate_apps("tidb")
+    def test_create_table_without_tiflash_replica(self):
+        class TiFlashReplicaNode0(models.Model):
+            title = models.CharField(max_length=255)
+
+            class Meta:
+                app_label = "tidb"
+
+        with connection.schema_editor() as editor:
+            editor.create_model(TiFlashReplicaNode0)
+        self.assertIsNone(self.get_tiflash_replica_info(TiFlashReplicaNode0._meta.db_table))
+
+    @isolate_apps("tidb")
+    def test_create_table_with_tiflash_replica_1(self):
+        class TiFlashReplicaNode1(models.Model):
+            title = models.CharField(max_length=255)
+
+            class Meta:
+                app_label = "tidb"
+                tiflash_replica = 1
+
+        with connection.schema_editor() as editor:
+            editor.create_model(TiFlashReplicaNode1)
+        self.assertEqual(
+            self.get_tiflash_replica_info(TiFlashReplicaNode1._meta.db_table), "1"
+        )
diff --git a/tests/tidb_vector/models.py b/tests/tidb_vector/models.py
index c34b87f..c59ebad 100644
--- a/tests/tidb_vector/models.py
+++ b/tests/tidb_vector/models.py
@@ -23,6 +23,7 @@ class DocumentWithAnnIndex(models.Model):
     embedding = VectorField(dimensions=3)
 
     class Meta:
+        tiflash_replica = 1  # When defining a vector index, the tiflash_replica must be non-zero
         indexes = [
             VectorIndex(CosineDistance("embedding"), name="idx_cos"),
             VectorIndex(L2Distance("embedding"), name="idx_l2"),