From 4d9ec3bb26eb85a837cb4fcae289231e5372de93 Mon Sep 17 00:00:00 2001 From: Tony Kuo <123580782+tonykploomber@users.noreply.github.com> Date: Mon, 6 Mar 2023 23:06:13 -0500 Subject: [PATCH] Fix - Issue better detect the user uses sqlalchemy variable expansion (#215) * Add: fail test case * Add: better check for variable usage --- src/sql/command.py | 4 ++-- src/tests/test_command.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/sql/command.py b/src/sql/command.py index cf78a4193..a4acb6745 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -92,8 +92,8 @@ def _var_expand(self, sql, user_ns, magic): sql = Template(sql).render(user_ns) parsed_sql = magic.shell.var_expand(sql, depth=2) - has_SQLAlchemy_var_expand = ":" in sql - # parsed_sql != sql: detect if using IPython fashion - {a} or $a + has_SQLAlchemy_var_expand = any((':' + ns_var_key in sql + for ns_var_key in user_ns.keys())) # has_SQLAlchemy_var_expand: detect if using Sqlalchemy fashion - :a msg = ( diff --git a/src/tests/test_command.py b/src/tests/test_command.py index b91d3d84b..bf869a4ed 100644 --- a/src/tests/test_command.py +++ b/src/tests/test_command.py @@ -1,4 +1,5 @@ from pathlib import Path +import warnings import pytest from sqlalchemy import create_engine @@ -209,6 +210,17 @@ def test_variable_substitution_legacy_warning_message_colon(ip, sql_magic, capsy """, ) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ip.user_global_ns["limit_number"] = 1 + ip.run_cell_magic( + "sql", + "", + """ + SELECT * FROM author WHERE last_name = 'Something with : inside' + """, + ) + def test_variable_substitution_legacy_dollar_prefix_cell_magic(ip, sql_magic): ip.user_global_ns["username"] = "some-user" @@ -246,7 +258,6 @@ def test_variable_substitution_double_curly_cell_magic(ip, sql_magic): cell="GRANT CONNECT ON DATABASE postgres TO {{username}};", ) - print("cmd.parsed['sql']", cmd.parsed["sql"]) assert cmd.parsed["sql"] == "\nGRANT CONNECT ON DATABASE postgres TO some-user;" @@ -260,5 +271,4 @@ def test_variable_substitution_double_curly_line_magic(ip, sql_magic): cell="", ) - # print ("cmd.parsed['sql']", cmd.parsed["sql"]) assert cmd.parsed["sql"] == "SELECT first_name FROM author LIMIT 5;"