diff --git a/core/query_rewriter.py b/core/query_rewriter.py index b24b1db..49b1cf8 100644 --- a/core/query_rewriter.py +++ b/core/query_rewriter.py @@ -138,7 +138,7 @@ def match_node(query_node: Any, rule_node: Any, rule: dict, memo: dict) -> bool: # handle case when query_node = {'all_columns': {}} and rule_node = {"value": "V001"} # we want "V001" to match "all_columns" # - if QueryRewriter.is_var(rule_node['value']) and not QueryRewriter.is_list(query_node) and 'all_columns' in query_node.keys(): + if QueryRewriter.is_var(rule_node['value']) and QueryRewriter.is_dict(query_node) and 'all_columns' in query_node.keys(): memo[rule_node['value']] = list(query_node.keys())[0] return True diff --git a/data/rules.py b/data/rules.py index 9fd2435..aa633e9 100644 --- a/data/rules.py +++ b/data/rules.py @@ -392,6 +392,17 @@ # 'mapping': "{\"x\": \"V1\", \"y\": \"V2\"}", 'database': 'mysql' }, + + { + 'id': 103, + 'key': 'stackoverflow_1', + 'name': 'Stackoverflow 1', + 'pattern': 'SELECT DISTINCT <> FROM <> WHERE <>', + 'constraints': '', + 'rewrite': 'SELECT <> FROM <> WHERE <> GROUP BY <>', + 'actions': '', + 'database': 'postgresql' + }, ] # fetch one rule by key (json attributes are in json) diff --git a/tests/test_query_rewriter.py b/tests/test_query_rewriter.py index cce37f3..896aba1 100644 --- a/tests/test_query_rewriter.py +++ b/tests/test_query_rewriter.py @@ -1052,6 +1052,33 @@ def test_rewrite_rule_query_rule_wetune_90(): assert format(parse(q1)) == format(parse(_q1)) +def test_rewrite_stackoverflow_1(): + q0 = ''' + SELECT DISTINCT my_table.foo, your_table.boo + FROM my_table, your_table + WHERE my_table.num = 1 OR your_table.num = 2 + ''' + q1 = ''' + SELECT + my_table.foo, + your_table.boo + FROM + my_table, + your_table + WHERE + my_table.num = 1 + OR your_table.num = 2 + GROUP BY + my_table.foo, + your_table.boo + ''' + rule_keys = ['stackoverflow_1', 'remove_self_join'] + + rules = [get_rule(k) for k in rule_keys] + _q1, _rewrite_path = QueryRewriter.rewrite(q0, rules) + assert format(parse(q1)) == format(parse(_q1)) + + # TODO - TBI # def test_rewrite_postgresql():