From 0a75f1cea65eaab96f09fe5bce342597e907e680 Mon Sep 17 00:00:00 2001 From: Tony Locke Date: Sat, 17 Aug 2024 14:48:46 +0100 Subject: [PATCH] Simplify identifier() by always quoting. Previously the identifier was only quoted when it had to be. --- README.md | 2 +- src/pg8000/converters.py | 23 ++++------------------- test/test_converters.py | 8 ++++---- 3 files changed, 9 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 05c7fb8..c367e6e 100644 --- a/README.md +++ b/README.md @@ -632,7 +632,7 @@ are both separate identifiers. So to escape them you'd do: ... f"WHERE lanname = 'sql'" ... ) >>> print(query) -SELECT lanname FROM pg_catalog.pg_language WHERE lanname = 'sql' +SELECT lanname FROM "pg_catalog"."pg_language" WHERE lanname = 'sql' >>> >>> con.run(query) [['sql']] diff --git a/src/pg8000/converters.py b/src/pg8000/converters.py index 84593ff..b234236 100644 --- a/src/pg8000/converters.py +++ b/src/pg8000/converters.py @@ -772,10 +772,6 @@ def make_params(py_types, values): return tuple([make_param(py_types, v) for v in values]) -def _quote_letter(c): - return c.isupper() if c.isalpha() else True - - def identifier(sql): if not isinstance(sql, str): raise InterfaceError("identifier must be a str") @@ -783,22 +779,11 @@ def identifier(sql): if len(sql) == 0: raise InterfaceError("identifier must be > 0 characters in length") - quote = _quote_letter(sql[0]) + if "\u0000" in sql: + raise InterfaceError("identifier cannot contain the code zero character") - for c in sql[1:]: - if _quote_letter(c) and c not in "0123456789_$": - if c == "\u0000": - raise InterfaceError( - "identifier cannot contain the code zero character" - ) - quote = True - break - - if quote: - sql = sql.replace('"', '""') - return f'"{sql}"' - else: - return sql + sql = sql.replace('"', '""') + return f'"{sql}"' def literal(value): diff --git a/test/test_converters.py b/test/test_converters.py index 4781aec..3504765 100644 --- a/test/test_converters.py +++ b/test/test_converters.py @@ 
-341,14 +341,14 @@ def test_identifier_quoted_null(): @pytest.mark.parametrize( "value,expected", [ - ("top_secret", "top_secret"), + ("top_secret", '"top_secret"'), (" Table", '" Table"'), ("A Table", '"A Table"'), ('A " Table', '"A "" Table"'), - ("table$", "table$"), + ("table$", '"table$"'), ("Table$", '"Table$"'), - ("tableఐ", "tableఐ"), # Unicode character 0C10 which is uncased - ("table", "table"), + ("tableఐ", '"tableఐ"'), # Unicode character 0C10 which is uncased + ("table", '"table"'), ("tAble", '"tAble"'), ], )