From 4bbc435e3693033f626438bf973dee4bb772bf82 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:11:53 +0000 Subject: [PATCH 1/4] Initial plan From c34f785d367c0a1a56d0600a518755da9d563b67 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 29 Oct 2025 15:19:14 +0000 Subject: [PATCH 2/4] Remove import-time dedent calls from cache backend Co-authored-by: adamchainz <857609+adamchainz@users.noreply.github.com> --- src/django_mysql/cache.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/django_mysql/cache.py b/src/django_mysql/cache.py index 0f8f1da9..bbe22a98 100644 --- a/src/django_mysql/cache.py +++ b/src/django_mysql/cache.py @@ -6,7 +6,6 @@ import zlib from collections.abc import Iterable from random import random -from textwrap import dedent from time import time from typing import Any, Callable, Literal, cast @@ -102,17 +101,15 @@ class MySQLCache(BaseDatabaseCache): # 1970) FOREVER_TIMEOUT = BIGINT_UNSIGNED_MAX >> 1 - create_table_sql = dedent( - """\ - CREATE TABLE `{table_name}` ( - cache_key varchar(255) CHARACTER SET utf8 COLLATE utf8_bin - NOT NULL PRIMARY KEY, - value longblob NOT NULL, - value_type char(1) CHARACTER SET latin1 COLLATE latin1_bin - NOT NULL DEFAULT 'p', - expires BIGINT UNSIGNED NOT NULL - ); - """ + create_table_sql = ( + "CREATE TABLE `{table_name}` (\n" + " cache_key varchar(255) CHARACTER SET utf8 COLLATE utf8_bin\n" + " NOT NULL PRIMARY KEY,\n" + " value longblob NOT NULL,\n" + " value_type char(1) CHARACTER SET latin1 COLLATE latin1_bin\n" + " NOT NULL DEFAULT 'p',\n" + " expires BIGINT UNSIGNED NOT NULL\n" + ");\n" ) @classmethod From 2af3a00fb11c7622b5cc279d34abf5e8db8fe38a Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 29 Oct 2025 16:38:29 +0000 Subject: [PATCH 3/4] Replace collapse_spaces calls with plain strings Co-authored-by: adamchainz <857609+adamchainz@users.noreply.github.com> --- src/django_mysql/cache.py | 113 ++++++----------- .../management/commands/cull_mysql_caches.py | 9 +- .../commands/mysql_cache_migration.py | 7 +- src/django_mysql/models/expressions.py | 117 ++++-------------- src/django_mysql/models/transforms.py | 13 +- src/django_mysql/utils.py | 5 - 6 files changed, 69 insertions(+), 195 deletions(-) diff --git a/src/django_mysql/cache.py b/src/django_mysql/cache.py index bbe22a98..00ccf2af 100644 --- a/src/django_mysql/cache.py +++ b/src/django_mysql/cache.py @@ -14,7 +14,7 @@ from django.utils.encoding import force_bytes from django.utils.module_loading import import_string -from django_mysql.utils import collapse_spaces, get_list_sql +from django_mysql.utils import get_list_sql _EncodedKeyType = Literal["i", "p", "z"] @@ -159,13 +159,10 @@ def get( value, value_type = row return self.decode(value, value_type) - _get_query = collapse_spaces( - """ - SELECT value, value_type - FROM {table} - WHERE cache_key = %s AND - expires >= %s - """ + _get_query = ( + "SELECT value, value_type " + "FROM {table} " + "WHERE cache_key = %s AND expires >= %s" ) def get_many( @@ -196,13 +193,10 @@ def get_many( return data - _get_many_query = collapse_spaces( - """ - SELECT cache_key, value, value_type - FROM {table} - WHERE cache_key IN {list_sql} AND - expires >= %s - """ + _get_many_query = ( + "SELECT cache_key, value, value_type " + "FROM {table} " + "WHERE cache_key IN {list_sql} AND expires >= %s" ) def set( @@ -258,15 +252,13 @@ def _base_set( insert_id = cursor.lastrowid return insert_id != 444 - _set_many_query = collapse_spaces( - """ - INSERT INTO {table} (cache_key, value, value_type, expires) - VALUES {{VALUES_CLAUSE}} - ON DUPLICATE KEY UPDATE - value=VALUES(value), - value_type=VALUES(value_type), - expires=VALUES(expires) - """ + _set_many_query = ( + "INSERT INTO {table} (cache_key, value, value_type, expires) " + "VALUES {{VALUES_CLAUSE}} " + "ON DUPLICATE KEY UPDATE " + "value=VALUES(value), " + "value_type=VALUES(value_type), " + "expires=VALUES(expires)" ) _set_query = _set_many_query.replace("{{VALUES_CLAUSE}}", "(%s, %s, %s, %s)") @@ -274,22 +266,15 @@ def _base_set( # Uses the IFNULL / LEAST / LAST_INSERT_ID trick to communicate the special # value of 444 back to the client (LAST_INSERT_ID is otherwise 0, since # there is no AUTO_INCREMENT column) - _add_query = collapse_spaces( - """ - INSERT INTO {table} (cache_key, value, value_type, expires) - VALUES (%s, %s, %s, %s) - ON DUPLICATE KEY UPDATE - value=IF(expires > @tmp_now:=%s, value, VALUES(value)), - value_type=IF(expires > @tmp_now, value_type, VALUES(value_type)), - expires=IF( - expires > @tmp_now, - IFNULL( - LEAST(LAST_INSERT_ID(444), NULL), - expires - ), - VALUES(expires) - ) - """ + _add_query = ( + "INSERT INTO {table} (cache_key, value, value_type, expires) " + "VALUES (%s, %s, %s, %s) " + "ON DUPLICATE KEY UPDATE " + "value=IF(expires > @tmp_now:=%s, value, VALUES(value)), " + "value_type=IF(expires > @tmp_now, value_type, VALUES(value_type)), " + "expires=IF(expires > @tmp_now, " + "IFNULL(LEAST(LAST_INSERT_ID(444), NULL), expires), " + "VALUES(expires))" ) def set_many( @@ -329,12 +314,7 @@ def delete(self, key: str, version: int | None = None) -> None: with connections[db].cursor() as cursor: cursor.execute(self._delete_query.format(table=table), (key,)) - _delete_query = collapse_spaces( - """ - DELETE FROM {table} - WHERE cache_key = %s - """ - ) + _delete_query = "DELETE FROM {table} WHERE cache_key = %s" def delete_many(self, keys: Iterable[str], version: int | None = None) -> None: made_keys = [self.make_key(key, version=version) for key in keys] @@ -352,12 +332,7 @@ def delete_many(self, keys: Iterable[str], version: int | None = None) -> None: made_keys, ) - _delete_many_query = collapse_spaces( - """ - DELETE FROM {table} - WHERE cache_key IN {list_sql} - """ - ) + _delete_many_query = "DELETE FROM {table} WHERE cache_key IN {list_sql}" def has_key(self, key: str, version: int | None = None) -> bool: key = self.make_key(key, version=version) @@ -370,11 +345,9 @@ def has_key(self, key: str, version: int | None = None) -> bool: cursor.execute(self._has_key_query.format(table=table), (key, self._now())) return cursor.fetchone() is not None - _has_key_query = collapse_spaces( - """ - SELECT 1 FROM {table} - WHERE cache_key = %s and expires > %s - """ + _has_key_query = ( + "SELECT 1 FROM {table} " + "WHERE cache_key = %s and expires > %s" ) def incr(self, key: str, delta: int = 1, version: int | None = None) -> int: @@ -409,17 +382,10 @@ def _base_delta( # Looks a bit tangled to turn the blob back into an int for updating, but # it works. Stores the new value for insert_id() with LAST_INSERT_ID - _delta_query = collapse_spaces( - """ - UPDATE {table} - SET value = LAST_INSERT_ID( - CAST(value AS SIGNED INTEGER) - {operation} - %s - ) - WHERE cache_key = %s AND - value_type = 'i' - """ + _delta_query = ( + "UPDATE {table} " + "SET value = LAST_INSERT_ID(CAST(value AS SIGNED INTEGER) {operation} %s) " + "WHERE cache_key = %s AND value_type = 'i'" ) def clear(self) -> None: @@ -442,13 +408,10 @@ def touch( ) return affected_rows > 0 - _touch_query = collapse_spaces( - """ - UPDATE {table} - SET expires = %s - WHERE cache_key = %s AND - expires >= %s - """ + _touch_query = ( + "UPDATE {table} " + "SET expires = %s " + "WHERE cache_key = %s AND expires >= %s" ) def validate_key(self, key: str) -> None: diff --git a/src/django_mysql/management/commands/cull_mysql_caches.py b/src/django_mysql/management/commands/cull_mysql_caches.py index a6af604b..cfc27e41 100644 --- a/src/django_mysql/management/commands/cull_mysql_caches.py +++ b/src/django_mysql/management/commands/cull_mysql_caches.py @@ -8,17 +8,14 @@ from django.core.management import BaseCommand, CommandError from django_mysql.cache import MySQLCache -from django_mysql.utils import collapse_spaces class Command(BaseCommand): args = "" - help = collapse_spaces( - """ - Runs cache.cull() on all your MySQLCache caches, or only those - specified aliases. - """ + help = ( + "Runs cache.cull() on all your MySQLCache caches, or only those " + "specified aliases." ) def add_arguments(self, parser: argparse.ArgumentParser) -> None: diff --git a/src/django_mysql/management/commands/mysql_cache_migration.py b/src/django_mysql/management/commands/mysql_cache_migration.py index afe90fae..3133c24e 100644 --- a/src/django_mysql/management/commands/mysql_cache_migration.py +++ b/src/django_mysql/management/commands/mysql_cache_migration.py @@ -8,17 +8,12 @@ from django.core.management import BaseCommand, CommandError from django_mysql.cache import MySQLCache -from django_mysql.utils import collapse_spaces class Command(BaseCommand): args = "" - help = collapse_spaces( - """ - Outputs a migration that will create a table. - """ - ) + help = "Outputs a migration that will create a table." def add_arguments(self, parser: argparse.ArgumentParser) -> None: parser.add_argument( diff --git a/src/django_mysql/models/expressions.py b/src/django_mysql/models/expressions.py index fea27d8f..948c2a31 100644 --- a/src/django_mysql/models/expressions.py +++ b/src/django_mysql/models/expressions.py @@ -8,8 +8,6 @@ from django.db.models.expressions import BaseExpression from django.db.models.sql.compiler import SQLCompiler -from django_mysql.utils import collapse_spaces - class TwoSidedExpression(BaseExpression): def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None: @@ -52,18 +50,8 @@ class AppendListF(TwoSidedExpression): # comma and 'value' # N.B. using MySQL side variables to avoid repeat calculation of # expression[s] - sql_expression = collapse_spaces( - """ - CONCAT_WS( - ',', - IF( - (@tmp_f:=%s) > '', - @tmp_f, - NULL - ), - %s - ) - """ + sql_expression = ( + "CONCAT_WS(',', IF((@tmp_f:=%s) > '', @tmp_f, NULL), %s)" ) def as_sql( @@ -86,18 +74,8 @@ class AppendLeftListF(TwoSidedExpression): # comma and 'value' # N.B. using MySQL side variables to avoid repeat calculation of # expression[s] - sql_expression = collapse_spaces( - """ - CONCAT_WS( - ',', - %s, - IF( - (@tmp_f:=%s) > '', - @tmp_f, - NULL - ) - ) - """ + sql_expression = ( + "CONCAT_WS(',', %s, IF((@tmp_f:=%s) > '', @tmp_f, NULL))" ) def as_sql( @@ -115,22 +93,11 @@ def as_sql( class PopListF(BaseExpression): - sql_expression = collapse_spaces( - """ - SUBSTRING( - @tmp_f:=%s, - 1, - IF( - LOCATE(',', @tmp_f), - ( - CHAR_LENGTH(@tmp_f) - - CHAR_LENGTH(SUBSTRING_INDEX(@tmp_f, ',', -1)) - - 1 - ), - 0 - ) - ) - """ + sql_expression = ( + "SUBSTRING(@tmp_f:=%s, 1, " + "IF(LOCATE(',', @tmp_f), " + "(CHAR_LENGTH(@tmp_f) - CHAR_LENGTH(SUBSTRING_INDEX(@tmp_f, ',', -1)) - 1), " + "0))" ) def __init__(self, lhs: BaseExpression) -> None: @@ -155,14 +122,8 @@ def as_sql( class PopLeftListF(BaseExpression): - sql_expression = collapse_spaces( - """ - IF( - (@tmp_c:=LOCATE(',', @tmp_f:=%s)) > 0, - SUBSTRING(@tmp_f, @tmp_c + 1), - '' - ) - """ + sql_expression = ( + "IF((@tmp_c:=LOCATE(',', @tmp_f:=%s)) > 0, SUBSTRING(@tmp_f, @tmp_c + 1), '')" ) def __init__(self, lhs: BaseExpression) -> None: @@ -207,18 +168,9 @@ class AddSetF(TwoSidedExpression): # comma and 'value' # N.B. using MySQL side variables to avoid repeat calculation of # expression[s] - sql_expression = collapse_spaces( - """ - IF( - FIND_IN_SET(@tmp_val:=%s, @tmp_f:=%s), - @tmp_f, - CONCAT_WS( - ',', - IF(CHAR_LENGTH(@tmp_f), @tmp_f, NULL), - @tmp_val - ) - ) - """ + sql_expression = ( + "IF(FIND_IN_SET(@tmp_val:=%s, @tmp_f:=%s), @tmp_f, " + "CONCAT_WS(',', IF(CHAR_LENGTH(@tmp_f), @tmp_f, NULL), @tmp_val))" ) def as_sql( @@ -241,37 +193,16 @@ class RemoveSetF(TwoSidedExpression): # that element. # There are some tricks going on - e.g. LEAST to evaluate a sub expression # but not use it in the output of CONCAT_WS - sql_expression = collapse_spaces( - """ - IF( - @tmp_pos:=FIND_IN_SET(%s, @tmp_f:=%s), - CONCAT_WS( - ',', - LEAST( - @tmp_len:=( - CHAR_LENGTH(@tmp_f) - - CHAR_LENGTH(REPLACE(@tmp_f, ',', '')) + - IF(CHAR_LENGTH(@tmp_f), 1, 0) - ), - NULL - ), - CASE WHEN - (@tmp_before:=SUBSTRING_INDEX(@tmp_f, ',', @tmp_pos - 1)) - = '' - THEN NULL - ELSE @tmp_before - END, - CASE WHEN - (@tmp_after:= - SUBSTRING_INDEX(@tmp_f, ',', - (@tmp_len - @tmp_pos))) - = '' - THEN NULL - ELSE @tmp_after - END - ), - @tmp_f - ) - """ + sql_expression = ( + "IF(@tmp_pos:=FIND_IN_SET(%s, @tmp_f:=%s), " + "CONCAT_WS(',', " + "LEAST(@tmp_len:=(CHAR_LENGTH(@tmp_f) - CHAR_LENGTH(REPLACE(@tmp_f, ',', '')) + " + "IF(CHAR_LENGTH(@tmp_f), 1, 0)), NULL), " + "CASE WHEN (@tmp_before:=SUBSTRING_INDEX(@tmp_f, ',', @tmp_pos - 1)) = '' " + "THEN NULL ELSE @tmp_before END, " + "CASE WHEN (@tmp_after:=SUBSTRING_INDEX(@tmp_f, ',', - (@tmp_len - @tmp_pos))) = '' " + "THEN NULL ELSE @tmp_after END), " + "@tmp_f)" ) def as_sql( diff --git a/src/django_mysql/models/transforms.py b/src/django_mysql/models/transforms.py index 5796e601..5f5ae2c8 100644 --- a/src/django_mysql/models/transforms.py +++ b/src/django_mysql/models/transforms.py @@ -7,22 +7,15 @@ from django.db.models import IntegerField, Transform from django.db.models.sql.compiler import SQLCompiler -from django_mysql.utils import collapse_spaces - class SetLength(Transform): lookup_name = "len" output_field = IntegerField() # No str.count equivalent in MySQL :( - expr = collapse_spaces( - """ - ( - CHAR_LENGTH(%s) - - CHAR_LENGTH(REPLACE(%s, ',', '')) + - IF(CHAR_LENGTH(%s), 1, 0) - ) - """ + expr = ( + "(CHAR_LENGTH(%s) - CHAR_LENGTH(REPLACE(%s, ',', '')) + " + "IF(CHAR_LENGTH(%s), 1, 0))" ) def as_sql( diff --git a/src/django_mysql/utils.py b/src/django_mysql/utils.py index 1ab3c4f2..84b6e505 100644 --- a/src/django_mysql/utils.py +++ b/src/django_mysql/utils.py @@ -126,11 +126,6 @@ def settings_to_cmd_args(settings_dict: dict[str, Any]) -> list[str]: return args -def collapse_spaces(string: str) -> str: - bits = string.replace("\n", " ").split(" ") - return " ".join(filter(None, bits)) - - def index_name( model: type[Model], *field_names: str, using: str = DEFAULT_DB_ALIAS ) -> str: From c9c88fdeccb5b0688438fa766445ff3b832379a6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 30 Oct 2025 00:19:04 +0000 Subject: [PATCH 4/4] Preserve indentation in SQL strings with fmt off/on Co-authored-by: adamchainz <857609+adamchainz@users.noreply.github.com> --- src/django_mysql/cache.py | 69 +++++++++++++---- src/django_mysql/models/expressions.py | 103 ++++++++++++++++++++----- src/django_mysql/models/transforms.py | 9 ++- 3 files changed, 146 insertions(+), 35 deletions(-) diff --git a/src/django_mysql/cache.py b/src/django_mysql/cache.py index 00ccf2af..f260b9c5 100644 --- a/src/django_mysql/cache.py +++ b/src/django_mysql/cache.py @@ -101,6 +101,7 @@ class MySQLCache(BaseDatabaseCache): # 1970) FOREVER_TIMEOUT = BIGINT_UNSIGNED_MAX >> 1 + # fmt: off create_table_sql = ( "CREATE TABLE `{table_name}` (\n" " cache_key varchar(255) CHARACTER SET utf8 COLLATE utf8_bin\n" @@ -111,6 +112,7 @@ class MySQLCache(BaseDatabaseCache): " expires BIGINT UNSIGNED NOT NULL\n" ");\n" ) + # fmt: on @classmethod def _now(cls) -> int: @@ -159,11 +161,14 @@ def get( value, value_type = row return self.decode(value, value_type) + # fmt: off _get_query = ( "SELECT value, value_type " "FROM {table} " - "WHERE cache_key = %s AND expires >= %s" + "WHERE cache_key = %s AND " + "expires >= %s" ) + # fmt: on def get_many( self, keys: Iterable[str], version: int | None = None @@ -193,11 +198,14 @@ def get_many( return data + # fmt: off _get_many_query = ( "SELECT cache_key, value, value_type " "FROM {table} " - "WHERE cache_key IN {list_sql} AND expires >= %s" + "WHERE cache_key IN {list_sql} AND " + "expires >= %s" ) + # fmt: on def set( self, @@ -252,30 +260,39 @@ def _base_set( insert_id = cursor.lastrowid return insert_id != 444 + # fmt: off _set_many_query = ( "INSERT INTO {table} (cache_key, value, value_type, expires) " "VALUES {{VALUES_CLAUSE}} " "ON DUPLICATE KEY UPDATE " - "value=VALUES(value), " - "value_type=VALUES(value_type), " - "expires=VALUES(expires)" + "value=VALUES(value), " + "value_type=VALUES(value_type), " + "expires=VALUES(expires)" ) + # fmt: on _set_query = _set_many_query.replace("{{VALUES_CLAUSE}}", "(%s, %s, %s, %s)") # Uses the IFNULL / LEAST / LAST_INSERT_ID trick to communicate the special # value of 444 back to the client (LAST_INSERT_ID is otherwise 0, since # there is no AUTO_INCREMENT column) + # fmt: off _add_query = ( "INSERT INTO {table} (cache_key, value, value_type, expires) " "VALUES (%s, %s, %s, %s) " "ON DUPLICATE KEY UPDATE " - "value=IF(expires > @tmp_now:=%s, value, VALUES(value)), " - "value_type=IF(expires > @tmp_now, value_type, VALUES(value_type)), " - "expires=IF(expires > @tmp_now, " - "IFNULL(LEAST(LAST_INSERT_ID(444), NULL), expires), " - "VALUES(expires))" + "value=IF(expires > @tmp_now:=%s, value, VALUES(value)), " + "value_type=IF(expires > @tmp_now, value_type, VALUES(value_type)), " + "expires=IF(" + "expires > @tmp_now, " + "IFNULL(" + "LEAST(LAST_INSERT_ID(444), NULL), " + "expires" + "), " + "VALUES(expires)" + ")" ) + # fmt: on def set_many( self, @@ -314,7 +331,12 @@ def delete(self, key: str, version: int | None = None) -> None: with connections[db].cursor() as cursor: cursor.execute(self._delete_query.format(table=table), (key,)) - _delete_query = "DELETE FROM {table} WHERE cache_key = %s" + # fmt: off + _delete_query = ( + "DELETE FROM {table} " + "WHERE cache_key = %s" + ) + # fmt: on def delete_many(self, keys: Iterable[str], version: int | None = None) -> None: made_keys = [self.make_key(key, version=version) for key in keys] @@ -332,7 +354,12 @@ def delete_many(self, keys: Iterable[str], version: int | None = None) -> None: made_keys, ) - _delete_many_query = "DELETE FROM {table} WHERE cache_key IN {list_sql}" + # fmt: off + _delete_many_query = ( + "DELETE FROM {table} " + "WHERE cache_key IN {list_sql}" + ) + # fmt: on def has_key(self, key: str, version: int | None = None) -> bool: key = self.make_key(key, version=version) @@ -345,10 +372,12 @@ def has_key(self, key: str, version: int | None = None) -> bool: cursor.execute(self._has_key_query.format(table=table), (key, self._now())) return cursor.fetchone() is not None + # fmt: off _has_key_query = ( "SELECT 1 FROM {table} " "WHERE cache_key = %s and expires > %s" ) + # fmt: on def incr(self, key: str, delta: int = 1, version: int | None = None) -> int: return self._base_delta(key, delta, version, "+") @@ -382,11 +411,18 @@ def _base_delta( # Looks a bit tangled to turn the blob back into an int for updating, but # it works. Stores the new value for insert_id() with LAST_INSERT_ID + # fmt: off _delta_query = ( "UPDATE {table} " - "SET value = LAST_INSERT_ID(CAST(value AS SIGNED INTEGER) {operation} %s) " - "WHERE cache_key = %s AND value_type = 'i'" + "SET value = LAST_INSERT_ID(" + "CAST(value AS SIGNED INTEGER) " + "{operation} " + "%s" + ") " + "WHERE cache_key = %s AND " + "value_type = 'i'" ) + # fmt: on def clear(self) -> None: db = router.db_for_write(self.cache_model_class) @@ -408,11 +444,14 @@ def touch( ) return affected_rows > 0 + # fmt: off _touch_query = ( "UPDATE {table} " "SET expires = %s " - "WHERE cache_key = %s AND expires >= %s" + "WHERE cache_key = %s AND " + "expires >= %s" ) + # fmt: on def validate_key(self, key: str) -> None: """ diff --git a/src/django_mysql/models/expressions.py b/src/django_mysql/models/expressions.py index 948c2a31..f94a8adc 100644 --- a/src/django_mysql/models/expressions.py +++ b/src/django_mysql/models/expressions.py @@ -50,9 +50,19 @@ class AppendListF(TwoSidedExpression): # comma and 'value' # N.B. using MySQL side variables to avoid repeat calculation of # expression[s] + # fmt: off sql_expression = ( - "CONCAT_WS(',', IF((@tmp_f:=%s) > '', @tmp_f, NULL), %s)" + "CONCAT_WS(" + "',', " + "IF(" + "(@tmp_f:=%s) > '', " + "@tmp_f, " + "NULL" + "), " + "%s" + ")" ) + # fmt: on def as_sql( self, @@ -74,9 +84,19 @@ class AppendLeftListF(TwoSidedExpression): # comma and 'value' # N.B. using MySQL side variables to avoid repeat calculation of # expression[s] + # fmt: off sql_expression = ( - "CONCAT_WS(',', %s, IF((@tmp_f:=%s) > '', @tmp_f, NULL))" + "CONCAT_WS(" + "',', " + "%s, " + "IF(" + "(@tmp_f:=%s) > '', " + "@tmp_f, " + "NULL" + ")" + ")" ) + # fmt: on def as_sql( self, @@ -93,12 +113,23 @@ def as_sql( class PopListF(BaseExpression): + # fmt: off sql_expression = ( - "SUBSTRING(@tmp_f:=%s, 1, " - "IF(LOCATE(',', @tmp_f), " - "(CHAR_LENGTH(@tmp_f) - CHAR_LENGTH(SUBSTRING_INDEX(@tmp_f, ',', -1)) - 1), " - "0))" + "SUBSTRING(" + "@tmp_f:=%s, " + "1, " + "IF(" + "LOCATE(',', @tmp_f), " + "(" + "CHAR_LENGTH(@tmp_f) - " + "CHAR_LENGTH(SUBSTRING_INDEX(@tmp_f, ',', -1)) - " + "1" + "), " + "0" + ")" + ")" ) + # fmt: on def __init__(self, lhs: BaseExpression) -> None: super().__init__() @@ -122,9 +153,15 @@ def as_sql( class PopLeftListF(BaseExpression): + # fmt: off sql_expression = ( - "IF((@tmp_c:=LOCATE(',', @tmp_f:=%s)) > 0, SUBSTRING(@tmp_f, @tmp_c + 1), '')" + "IF(" + "(@tmp_c:=LOCATE(',', @tmp_f:=%s)) > 0, " + "SUBSTRING(@tmp_f, @tmp_c + 1), " + "''" + ")" ) + # fmt: on def __init__(self, lhs: BaseExpression) -> None: super().__init__() @@ -168,10 +205,19 @@ class AddSetF(TwoSidedExpression): # comma and 'value' # N.B. using MySQL side variables to avoid repeat calculation of # expression[s] + # fmt: off sql_expression = ( - "IF(FIND_IN_SET(@tmp_val:=%s, @tmp_f:=%s), @tmp_f, " - "CONCAT_WS(',', IF(CHAR_LENGTH(@tmp_f), @tmp_f, NULL), @tmp_val))" + "IF(" + "FIND_IN_SET(@tmp_val:=%s, @tmp_f:=%s), " + "@tmp_f, " + "CONCAT_WS(" + "',', " + "IF(CHAR_LENGTH(@tmp_f), @tmp_f, NULL), " + "@tmp_val" + ")" + ")" ) + # fmt: on def as_sql( self, @@ -193,17 +239,38 @@ class RemoveSetF(TwoSidedExpression): # that element. # There are some tricks going on - e.g. LEAST to evaluate a sub expression # but not use it in the output of CONCAT_WS + # fmt: off sql_expression = ( - "IF(@tmp_pos:=FIND_IN_SET(%s, @tmp_f:=%s), " - "CONCAT_WS(',', " - "LEAST(@tmp_len:=(CHAR_LENGTH(@tmp_f) - CHAR_LENGTH(REPLACE(@tmp_f, ',', '')) + " - "IF(CHAR_LENGTH(@tmp_f), 1, 0)), NULL), " - "CASE WHEN (@tmp_before:=SUBSTRING_INDEX(@tmp_f, ',', @tmp_pos - 1)) = '' " - "THEN NULL ELSE @tmp_before END, " - "CASE WHEN (@tmp_after:=SUBSTRING_INDEX(@tmp_f, ',', - (@tmp_len - @tmp_pos))) = '' " - "THEN NULL ELSE @tmp_after END), " - "@tmp_f)" + "IF(" + "@tmp_pos:=FIND_IN_SET(%s, @tmp_f:=%s), " + "CONCAT_WS(" + "',', " + "LEAST(" + "@tmp_len:=(" + "CHAR_LENGTH(@tmp_f) - " + "CHAR_LENGTH(REPLACE(@tmp_f, ',', '')) + " + "IF(CHAR_LENGTH(@tmp_f), 1, 0)" + "), " + "NULL" + "), " + "CASE WHEN " + "(@tmp_before:=SUBSTRING_INDEX(@tmp_f, ',', @tmp_pos - 1)) " + "= '' " + "THEN NULL " + "ELSE @tmp_before " + "END, " + "CASE WHEN " + "(@tmp_after:=" + "SUBSTRING_INDEX(@tmp_f, ',', - (@tmp_len - @tmp_pos))) " + "= '' " + "THEN NULL " + "ELSE @tmp_after " + "END" + "), " + "@tmp_f" + ")" ) + # fmt: on def as_sql( self, diff --git a/src/django_mysql/models/transforms.py b/src/django_mysql/models/transforms.py index 5f5ae2c8..8e38b283 100644 --- a/src/django_mysql/models/transforms.py +++ b/src/django_mysql/models/transforms.py @@ -13,10 +13,15 @@ class SetLength(Transform): output_field = IntegerField() # No str.count equivalent in MySQL :( + # fmt: off expr = ( - "(CHAR_LENGTH(%s) - CHAR_LENGTH(REPLACE(%s, ',', '')) + " - "IF(CHAR_LENGTH(%s), 1, 0))" + "(" + "CHAR_LENGTH(%s) - " + "CHAR_LENGTH(REPLACE(%s, ',', '')) + " + "IF(CHAR_LENGTH(%s), 1, 0)" + ")" ) + # fmt: on def as_sql( self, compiler: SQLCompiler, connection: BaseDatabaseWrapper