class FilterPullupShuttle(RelationalShuttle):
    """
    Shuttle implementation that propagates filter conditions as far up the
    relational tree as they remain valid, re-applying them above joins.

    The original filters are never removed from where they were found: each
    condition that can be rewritten in terms of a join's output columns is
    additionally applied on top of that join. Since every row reaching that
    point already satisfies the condition, the extra filter is redundant by
    construction and only exists so that later passes can see the condition
    at the level of the join.
    """

    def __init__(self, session: PyDoughSession):
        # `session` is accepted so the shuttle is constructed the same way as
        # the other conversion passes; the current logic does not consult it.
        #
        # The conditions known to hold for every row produced by the most
        # recently visited subtree, phrased in terms of that subtree's output
        # columns. Every visit method overwrites this before returning.
        self.filters: set[RelationalExpression] = set()

    def reset(self):
        self.filters.clear()

    def visit_join(self, join: Join) -> RelationalNode:
        # Transform each input independently, starting from an empty set so
        # that the conditions collected for one side cannot leak into the
        # other. The collected conditions are qualified with the alias of the
        # side they came from so they line up with the expressions stored in
        # `join.columns`.
        self.filters = set()
        join.inputs[0] = join.inputs[0].accept_shuttle(self)
        left_filters: set[RelationalExpression] = {
            add_input_name(cond, join.default_input_aliases[0])
            for cond in self.filters
        }

        self.filters = set()
        join.inputs[1] = join.inputs[1].accept_shuttle(self)
        right_filters: set[RelationalExpression] = {
            add_input_name(cond, join.default_input_aliases[1])
            for cond in self.filters
        }

        # Start from a fresh set rather than whatever state the shuttle held
        # on entry: only conditions established by this join's own inputs are
        # meaningful in terms of this join's output columns.
        out_filters: set[RelationalExpression] = set()

        # Every output row of any join type handled here is derived from a
        # left-hand row, so conditions on the left input always still hold.
        self.pull_conditions(left_filters, join, out_filters)

        # Conditions on the right input are only attempted for INNER and
        # SEMI joins; other join types (e.g. LEFT, which can emit rows with
        # no right-hand match) are excluded.
        if join.join_type in (JoinType.INNER, JoinType.SEMI):
            self.pull_conditions(right_filters, join, out_filters)

        # The filter node must be *returned*: `build_filter` wraps the join
        # rather than modifying it, so dropping its result would leave the
        # tree untouched.
        result: RelationalNode = build_filter(join, out_filters)
        self.filters = out_filters
        return result

    def pull_conditions(
        self,
        conditions: set[RelationalExpression],
        node: RelationalNode,
        current_filters: set[RelationalExpression],
    ) -> None:
        """
        Attempts to pull the given conditions above the given node, modifying
        `current_filters` as needed.

        Args:
            `conditions`: The conditions to attempt to pull up, phrased in
            terms of the expressions that `node.columns` is built from.
            `node`: The relational node at which the condition currently
            resides.
            `current_filters`: The set of filters currently in effect, modified
            in place to add each condition that is successfully pulled up
            (rewritten in terms of the output columns of `node`).
        """
        # Imported locally so this class does not depend on a change to the
        # module-level import list.
        from pydough.relational.rel_util import contains_window

        # Map every expression that `node` exposes as a column to a reference
        # to that output column.
        revmap: dict[RelationalExpression, RelationalExpression] = {}
        for name, expr in node.columns.items():
            revmap[expr] = ColumnReference(name, expr.data_type)

        for condition in conditions:
            # A window function evaluated above a join would be computed over
            # a different set of rows than where it was written, so such
            # conditions are never propagated.
            if contains_window(condition):
                continue
            new_cond: RelationalExpression | None = bubble_expression(condition, revmap)
            if new_cond is not None:
                current_filters.add(new_cond)

    def generic_pullup(self, node: RelationalNode) -> RelationalNode:
        """
        Transforms the inputs of `node`, then replaces `self.filters` with the
        subset of the input's conditions that can be expressed using the
        output columns of the transformed node.

        Args:
            `node`: The node whose inputs should be transformed.

        Returns:
            The node returned by `generic_visit_inputs`.
        """
        result: RelationalNode = self.generic_visit_inputs(node)
        child_filters: set[RelationalExpression] = self.filters
        self.filters = set()
        # Source is what the child established; destination is the new state.
        self.pull_conditions(child_filters, result, self.filters)
        return result

    def visit_filter(self, filter: Filter) -> RelationalNode:
        result = self.generic_pullup(filter)
        # In addition to whatever came from below, the filter's own
        # conjuncts hold for every row it emits.
        self.pull_conditions(get_conjunctions(filter.condition), result, self.filters)
        return result

    def visit_project(self, project: Project) -> RelationalNode:
        return self.generic_pullup(project)

    def visit_aggregate(self, aggregate: Aggregate) -> RelationalNode:
        # An aggregate without grouping keys still emits a row when its input
        # is empty, so a condition that holds for every input row (for
        # example a constant false) says nothing about the output row.
        # NOTE(review): relies on `Aggregate.keys` being the grouping-key
        # mapping - confirm against the Aggregate node definition.
        if len(aggregate.keys) == 0:
            result: RelationalNode = self.generic_visit_inputs(aggregate)
            self.filters = set()
            return result
        return self.generic_pullup(aggregate)

    def visit_limit(self, limit: Limit) -> RelationalNode:
        # Limits cannot have filters pulled above them, but their inputs can
        # still be transformed. The state is cleared afterwards so that the
        # input's conditions are not mistaken for conditions on the limit.
        result: RelationalNode = self.generic_visit_inputs(limit)
        self.filters = set()
        return result

    def visit_scan(self, scan: Scan) -> RelationalNode:
        # Scans cannot have filters pulled above them.
        self.filters = set()
        return scan

    def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> RelationalNode:
        # Empty singletons cannot have filters pulled above them.
        self.filters = set()
        return empty_singleton

    def visit_generated_table(self, generated_table: GeneratedTable) -> RelationalNode:
        # Generated tables cannot have filters pulled above them.
        self.filters = set()
        return generated_table
def pull_filters(node: RelationalNode, session: PyDoughSession) -> RelationalNode:
    """
    Runs the filter pull-up pass over a relational tree.

    Args:
        `node`: The root of the relational (sub)tree to transform.
        `session`: The PyDough session, handed to the shuttle that performs
        the traversal.

    Returns:
        Whatever `node.accept_shuttle` returns for a `FilterPullupShuttle`,
        i.e. the transformed version of `node` and all of its descendants.
    """
    shuttle: FilterPullupShuttle = FilterPullupShuttle(session)
    transformed: RelationalNode = node.accept_shuttle(shuttle)
    return transformed
+ """ + inferred_filters: set[RelationalExpression] = set() + + # Cannot infer any extra filters for ANTI joins, or for LEFT joins + # pushing into the right-hand side. + if join.join_type == JoinType.ANTI or ( + join.join_type == JoinType.LEFT and input_idx == 1 + ): + return inferred_filters + + transposer: ExpressionTranspositionShuttle = ExpressionTranspositionShuttle( + join, True + ) + transposed_conds = { + cond.accept_shuttle(transposer) for cond in original_filters + } + + # Extract all equality conditions from the join condition, then build + # up equality sets via a union find structure. + lhs_keys, rhs_keys = extract_equijoin_keys(join, transposed_conds) + equality_sets: dict[RelationalExpression, RelationalExpression] = {} + for lhs_key in lhs_keys: + equality_sets[lhs_key] = lhs_key + for rhs_key in rhs_keys: + equality_sets[rhs_key] = rhs_key + + def find(expr: RelationalExpression) -> RelationalExpression: + # Finds the root representative of the equality set containing + # `expr`, applying path compression along the way. + parent = equality_sets.get(expr, expr) + if parent != expr: + parent = find(parent) + equality_sets[expr] = parent + return parent + + def union(expr1: RelationalExpression, expr2: RelationalExpression) -> None: + # Unites the equality sets containing `expr1` and `expr2`. + root1 = find(expr1) + root2 = find(expr2) + if root1 != root2: + equality_sets[root1] = root2 + + # The equality sets are built by uniting all of the equality operations + # in the join condition. 
+ for expr in get_conjunctions(join.condition): + if isinstance(expr, CallExpression) and expr.op == pydop.EQU: + union(expr.inputs[0], expr.inputs[1]) + + rev_equality_sets: dict[RelationalExpression, set[RelationalExpression]] = {} + for key, value in equality_sets.items(): + rev_equality_sets.setdefault(value, set()).add(key) + + finder: ColumnReferenceFinder = ColumnReferenceFinder() + for eq_set in rev_equality_sets.values(): + if len(eq_set) > 1: + for a, b in itertools.combinations(eq_set, 2): + finder.reset() + new_cond: RelationalExpression = CallExpression( + pydop.EQU, + BooleanType(), + [a, b], + ) + new_cond.accept(finder) + col_refs: set[ColumnReference] = finder.get_column_references() + if {c.input_name for c in col_refs} == { + join.default_input_aliases[input_idx] + }: + inferred_filters.add(new_cond) + + # Iterate through all the keys from the specified input side. If any + # are in the same equality set as one another, add such a condition. + # Keep track of which keys from the other side map to keys from the + # desired side. 
+ keys: list[ColumnReference] = lhs_keys + rhs_keys + key_remapping: dict[ColumnReference, set[ColumnReference]] = {} + for i in range(len(keys)): + for j in range(i + 1, len(keys)): + if keys[i] != keys[j] and find(keys[i]) == find(keys[j]): + new_cond = CallExpression( + pydop.EQU, + BooleanType(), + [add_input_name(keys[i], None), add_input_name(keys[j], None)], + ) + if ( + keys[i].input_name == join.default_input_aliases[input_idx] + and keys[j].input_name == join.default_input_aliases[input_idx] + ): + inferred_filters.add(new_cond) + elif ( + keys[i].input_name != join.default_input_aliases[input_idx] + and keys[j].input_name == join.default_input_aliases[input_idx] + ): + key_remapping.setdefault(keys[i], set()).add(keys[j]) + elif ( + keys[i].input_name == join.default_input_aliases[input_idx] + and keys[j].input_name != join.default_input_aliases[input_idx] + ): + key_remapping.setdefault(keys[j], set()).add(keys[i]) + + # Additionally, if there filters that apply to a different side of the + # join, try to transform them into filters on this side via the same + # substitution. 
+ if len(key_remapping) > 0 and len(transposed_conds) > 0: + self.add_transitive_filters( + join, input_idx, transposed_conds, key_remapping, inferred_filters + ) + + return inferred_filters + + def add_transitive_filters( + self, + join: Join, + input_idx: int, + original_filters: set[RelationalExpression], + key_remapping: dict[ColumnReference, set[ColumnReference]], + filter_set: set[RelationalExpression], + ): + """ + TODO + """ + allowed_columns: set[ColumnReference] = set(key_remapping) + current_columns: set[ColumnReference] = set() + for name, expr in join.inputs[input_idx].columns.items(): + current_columns.add( + ColumnReference( + name, expr.data_type, join.default_input_aliases[input_idx] + ) + ) + allowed_columns.update(current_columns) + reference_finder: ColumnReferenceFinder = ColumnReferenceFinder() + key_substitution: dict[RelationalExpression, RelationalExpression] = {} + for key, other_keys in key_remapping.items(): + key_substitution[key] = min(other_keys, key=repr) + for cond in original_filters: + reference_finder.reset() + cond.accept(reference_finder) + col_refs: set[ColumnReference] = reference_finder.get_column_references() + if (col_refs <= allowed_columns) and not (col_refs <= current_columns): + new_cond: RelationalExpression = apply_substitution( + cond, key_substitution, {} + ) + filter_set.add(add_input_name(new_cond, None)) + def visit_join(self, join: Join) -> RelationalNode: # Identify the set of all column names that correspond to a reference # to a column from one side of the join. @@ -225,6 +394,7 @@ def visit_join(self, join: Join) -> RelationalNode: # reference columns from that input. 
pushable_filters: set[RelationalExpression] remaining_filters: set[RelationalExpression] = self.filters + original_filters: set[RelationalExpression] = self.filters transposer: ExpressionTranspositionShuttle = ExpressionTranspositionShuttle( join, False ) @@ -242,6 +412,32 @@ def visit_join(self, join: Join) -> RelationalNode: remaining_filters, lambda expr: only_references_columns(expr, input_cols[idx]), ) + + pushable_filters = { + expr.accept_shuttle(transposer) for expr in pushable_filters + } + + # Find any extra filters that can be deduced from the join + # condition, e.g. if `t0.a = t1.b` and `t0.a = t1.c` are in the join + # condition, then we can infer an extra filter `t1.b = t1.c`. Add + # these filters to the pushable filters. + if join.join_type == JoinType.INNER: + pushable_filters.update( + self.infer_extra_join_filters( + join, + idx, + original_filters, + ) + ) + + # Simplify all of the pushable filters before pushing them down, in + # case any always-true conditions were added, then remove them so + # we do not incorrectly think filters are being added. + pushable_filters = { + expr.accept_shuttle(self.simplifier) for expr in pushable_filters + } + pushable_filters.discard(LiteralExpression(True, BooleanType())) + # Ensure that if any filter is pushed into an input, the # corresponding join cardinality is updated to reflect that a filter # has been applied. @@ -250,9 +446,7 @@ def visit_join(self, join: Join) -> RelationalNode: cardinality = join.cardinality.add_filter() else: reverse_cardinality = reverse_cardinality.add_filter() - pushable_filters = { - expr.accept_shuttle(transposer) for expr in pushable_filters - } + # Transform the child input with the filters that can be # pushed down. 
self.filters = pushable_filters diff --git a/pydough/conversion/relational_converter.py b/pydough/conversion/relational_converter.py index 35c98e75e..b834c2237 100644 --- a/pydough/conversion/relational_converter.py +++ b/pydough/conversion/relational_converter.py @@ -57,6 +57,7 @@ from .agg_removal import remove_redundant_aggs from .agg_split import split_partial_aggregates from .column_bubbler import bubble_column_names +from .filter_pullup import pull_filters from .filter_pushdown import push_filters from .hybrid_connection import ConnectionType, HybridConnection from .hybrid_expressions import ( @@ -1603,7 +1604,12 @@ def optimize_relational_tree( # exist to compute a scalar projection and then link it with the data. root = confirm_root(pullup_projections(root)) - # Push filters down as far as possible + # Pull filters above joins before pushing them down as far as possible + print() + print(root.to_tree_string()) + root = confirm_root(pull_filters(root, session)) + print() + print(root.to_tree_string()) root = confirm_root(push_filters(root, session)) # Merge adjacent projections, unless it would result in excessive duplicate diff --git a/pydough/relational/rel_util.py b/pydough/relational/rel_util.py index 220dc6b71..3c49e6e6a 100644 --- a/pydough/relational/rel_util.py +++ b/pydough/relational/rel_util.py @@ -7,6 +7,7 @@ "add_expr_uses", "add_input_name", "apply_substitution", + "bubble_expression", "bubble_uniqueness", "build_filter", "contains_window", @@ -429,12 +430,15 @@ def add_expr_uses( def extract_equijoin_keys( join: Join, + extra_filters: set[RelationalExpression] | None = None, ) -> tuple[list[ColumnReference], list[ColumnReference]]: """ Extracts the equi-join keys from a join condition with two inputs. Args: `join`: the Join node whose condition is being parsed. + `extra_filters`: an optional set of extra filter expressions to + consider as part of the join condition, or applied after it. 
def bubble_expression(
    expr: RelationalExpression,
    revmap: dict[RelationalExpression, RelationalExpression],
) -> RelationalExpression | None:
    """
    Rewrites an expression by replacing each sub-expression found in `revmap`
    with the expression it maps to.

    Args:
        `expr`: The expression to rewrite.
        `revmap`: The replacements to apply. An expression that is a key of
        this mapping is replaced wholesale by the mapped value.

    Returns:
        The rewritten expression, or None if some part of `expr` is neither a
        key of `revmap`, a literal, nor a function / window call whose
        operands can all be rewritten.
    """
    # A direct hit replaces the whole expression without looking inside it.
    if expr in revmap:
        return revmap[expr]

    def bubble_all(
        operands: list[RelationalExpression],
    ) -> list[RelationalExpression] | None:
        # Rewrites every operand, giving up as soon as one cannot be rewritten.
        rewritten: list[RelationalExpression] = []
        for operand in operands:
            bubbled: RelationalExpression | None = bubble_expression(operand, revmap)
            if bubbled is None:
                return None
            rewritten.append(bubbled)
        return rewritten

    if isinstance(expr, CallExpression):
        call_args = bubble_all(expr.inputs)
        if call_args is None:
            return None
        return CallExpression(expr.op, expr.data_type, call_args)

    if isinstance(expr, WindowCallExpression):
        window_args = bubble_all(expr.inputs)
        if window_args is None:
            return None
        partition_args = bubble_all(expr.partition_inputs)
        if partition_args is None:
            return None
        # Ordering entries keep their direction and null placement; only the
        # expression being ordered on is rewritten.
        ordering: list[ExpressionSortInfo] = []
        for sort_info in expr.order_inputs:
            sort_expr: RelationalExpression | None = bubble_expression(
                sort_info.expr, revmap
            )
            if sort_expr is None:
                return None
            ordering.append(
                ExpressionSortInfo(sort_expr, sort_info.ascending, sort_info.nulls_first)
            )
        return WindowCallExpression(
            expr.op,
            expr.data_type,
            window_args,
            partition_args,
            ordering,
            expr.kwargs,
        )

    if isinstance(expr, LiteralExpression):
        # Literals do not depend on any column, so they are kept as they are.
        return expr

    # Anything else cannot be rewritten.
    return None
cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'ID': t0.ID, 'ID2': t0.ID2, '`as`': t0.`as`, '`select`': t0.`select`, '`two words`': t1.`two words`}) JOIN(condition=t0.ID2 == t1.id, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'ID': t0.ID, 'ID2': t0.ID2, '`as`': t1.`as`, '`select`': t1.`select`}) - SCAN(table=keywords."CAST", columns={'ID': ID, 'ID2': ID2}) + FILTER(condition=ID == 1:numeric, columns={'ID': ID, 'ID2': ID2}) + SCAN(table=keywords."CAST", columns={'ID': ID, 'ID2': ID2}) FILTER(condition=`0 = 0 and '` == '2 "0 = 0 and \'" field name':string, columns={'`as`': `as`, '`select`': `select`, 'id': id}) SCAN(table=keywords."lowercase_detail", columns={"`0 = 0 and '`": `0 = 0 and '`, '`as`': `as`, '`select`': `select`, 'id': id}) FILTER(condition=id == 1:numeric, columns={'`two words`': `two words`, 'id': id}) diff --git a/tests/test_plan_refsols/many_net_filter_10.txt b/tests/test_plan_refsols/many_net_filter_10.txt index 8eb35e3e3..2a0f57cdb 100644 --- a/tests/test_plan_refsols/many_net_filter_10.txt +++ b/tests/test_plan_refsols/many_net_filter_10.txt @@ -1,10 +1,11 @@ ROOT(columns=[('n', n_rows)], orderings=[]) AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) JOIN(condition=t0.n_nationkey == t1.c_nationkey & t1.c_custkey == t0.s_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={}) - JOIN(condition=t0.n_regionkey == t1.n_regionkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_ACCESS, columns={'n_nationkey': t1.n_nationkey, 's_suppkey': t0.s_suppkey}) - JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_ACCESS, reverse_cardinality=PLURAL_ACCESS, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) + JOIN(condition=t0.n_regionkey == t1.n_regionkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'n_nationkey': t1.n_nationkey, 's_suppkey': 
t0.s_suppkey}) + JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_ACCESS, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) - SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + FILTER(condition=n_regionkey == 2:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) FILTER(condition=n_regionkey == 2:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) diff --git a/tests/test_plan_refsols/many_net_filter_11.txt b/tests/test_plan_refsols/many_net_filter_11.txt index 6122ff62a..b003169ac 100644 --- a/tests/test_plan_refsols/many_net_filter_11.txt +++ b/tests/test_plan_refsols/many_net_filter_11.txt @@ -5,9 +5,9 @@ ROOT(columns=[('n', n_rows)], orderings=[]) JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) FILTER(condition=NOT(ISIN(s_nationkey, [0, 3, 6, 9, 12, 15, 18, 21, 24]:array[unknown])), columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) - FILTER(condition=n_regionkey < 3:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + FILTER(condition=n_regionkey < 3:numeric & n_regionkey > 0:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) - FILTER(condition=n_regionkey > 
0:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + FILTER(condition=n_regionkey > 0:numeric & NOT(ISIN(n_nationkey, [1, 4, 7, 10, 13, 16, 19, 22]:array[unknown])), columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) FILTER(condition=NOT(ISIN(c_nationkey, [1, 4, 7, 10, 13, 16, 19, 22]:array[unknown])), columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) diff --git a/tests/test_plan_refsols/many_net_filter_3.txt b/tests/test_plan_refsols/many_net_filter_3.txt index 0a679d928..5c93ef290 100644 --- a/tests/test_plan_refsols/many_net_filter_3.txt +++ b/tests/test_plan_refsols/many_net_filter_3.txt @@ -1,6 +1,7 @@ ROOT(columns=[('n', n_rows)], orderings=[]) AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) JOIN(condition=t0.s_nationkey == t1.c_nationkey & t1.c_custkey == t0.s_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={}) - SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) + FILTER(condition=s_nationkey == 3:numeric, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) FILTER(condition=c_nationkey == 3:numeric, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) diff --git a/tests/test_plan_refsols/many_net_filter_5.txt b/tests/test_plan_refsols/many_net_filter_5.txt index 5c3449d2a..0347d1c40 100644 --- a/tests/test_plan_refsols/many_net_filter_5.txt +++ b/tests/test_plan_refsols/many_net_filter_5.txt @@ -2,8 +2,9 @@ ROOT(columns=[('n', n_rows)], orderings=[]) AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) JOIN(condition=t0.n_nationkey == t1.c_nationkey & 
t1.c_custkey == t0.s_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={}) JOIN(condition=t0.n_regionkey == t1.n_regionkey, type=INNER, cardinality=PLURAL_ACCESS, reverse_cardinality=SINGULAR_FILTER, columns={'n_nationkey': t1.n_nationkey, 's_suppkey': t0.s_suppkey}) - JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_ACCESS, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) - SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) + JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_FILTER, reverse_cardinality=PLURAL_FILTER, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) + FILTER(condition=s_nationkey == 5:numeric, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) + SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) FILTER(condition=n_nationkey == 5:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) diff --git a/tests/test_plan_refsols/many_net_filter_7.txt b/tests/test_plan_refsols/many_net_filter_7.txt index fa1027a5b..48853c216 100644 --- a/tests/test_plan_refsols/many_net_filter_7.txt +++ b/tests/test_plan_refsols/many_net_filter_7.txt @@ -1,10 +1,11 @@ ROOT(columns=[('n', n_rows)], orderings=[]) AGGREGATE(keys={}, aggregations={'n_rows': COUNT()}) JOIN(condition=t0.n_nationkey == t1.c_nationkey & t1.c_custkey == t0.s_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={}) - JOIN(condition=t0.n_regionkey == t1.n_regionkey, type=INNER, cardinality=PLURAL_ACCESS, reverse_cardinality=SINGULAR_ACCESS, columns={'n_nationkey': t1.n_nationkey, 's_suppkey': t0.s_suppkey}) 
+ JOIN(condition=t0.n_regionkey == t1.n_regionkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_ACCESS, columns={'n_nationkey': t1.n_nationkey, 's_suppkey': t0.s_suppkey}) JOIN(condition=t0.s_nationkey == t1.n_nationkey, type=INNER, cardinality=SINGULAR_ACCESS, reverse_cardinality=PLURAL_ACCESS, columns={'n_regionkey': t1.n_regionkey, 's_suppkey': t0.s_suppkey}) SCAN(table=tpch.SUPPLIER, columns={'s_nationkey': s_nationkey, 's_suppkey': s_suppkey}) SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) - SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + FILTER(condition=n_nationkey == 7:numeric, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) + SCAN(table=tpch.NATION, columns={'n_nationkey': n_nationkey, 'n_regionkey': n_regionkey}) FILTER(condition=c_nationkey == 7:numeric, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) SCAN(table=tpch.CUSTOMER, columns={'c_custkey': c_custkey, 'c_nationkey': c_nationkey}) diff --git a/tests/test_plan_refsols/top_lineitems_info_2.txt b/tests/test_plan_refsols/top_lineitems_info_2.txt index aff71c57c..d6bed81db 100644 --- a/tests/test_plan_refsols/top_lineitems_info_2.txt +++ b/tests/test_plan_refsols/top_lineitems_info_2.txt @@ -1,7 +1,7 @@ ROOT(columns=[('order_key', l_orderkey), ('line_number', l_linenumber), ('part_size', p_size), ('supplier_nation', n_nationkey)], orderings=[(l_orderkey):asc_first, (l_linenumber):asc_first], limit=7:numeric) JOIN(condition=t0.ps_partkey == t1.l_partkey & t0.supplier_key_11 == t1.l_suppkey & t1.l_partkey == t0.p_partkey & t1.l_suppkey == t0.ps_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'l_linenumber': t1.l_linenumber, 'l_orderkey': t1.l_orderkey, 'n_nationkey': t0.n_nationkey, 'p_size': t0.p_size}) - JOIN(condition=t0.s_suppkey == t1.ps_suppkey, type=INNER, cardinality=PLURAL_ACCESS, 
reverse_cardinality=SINGULAR_ACCESS, columns={'n_nationkey': t0.n_nationkey, 'p_partkey': t0.p_partkey, 'p_size': t0.p_size, 'ps_partkey': t1.ps_partkey, 'ps_suppkey': t0.ps_suppkey, 'supplier_key_11': t1.ps_suppkey}) - JOIN(condition=t0.n_nationkey == t1.s_nationkey, type=INNER, cardinality=PLURAL_ACCESS, reverse_cardinality=SINGULAR_ACCESS, columns={'n_nationkey': t0.n_nationkey, 'p_partkey': t0.p_partkey, 'p_size': t0.p_size, 'ps_suppkey': t0.ps_suppkey, 's_suppkey': t1.s_suppkey}) + JOIN(condition=t0.p_partkey == t1.ps_partkey & t0.ps_suppkey == t1.ps_suppkey & t0.s_suppkey == t1.ps_suppkey & t1.ps_partkey == t0.p_partkey & t1.ps_suppkey == t0.ps_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'n_nationkey': t0.n_nationkey, 'p_partkey': t0.p_partkey, 'p_size': t0.p_size, 'ps_partkey': t1.ps_partkey, 'ps_suppkey': t0.ps_suppkey, 'supplier_key_11': t1.ps_suppkey}) + JOIN(condition=t0.n_nationkey == t1.s_nationkey & t0.ps_suppkey == t1.s_suppkey & t1.s_suppkey == t0.ps_suppkey, type=INNER, cardinality=PLURAL_FILTER, reverse_cardinality=SINGULAR_FILTER, columns={'n_nationkey': t0.n_nationkey, 'p_partkey': t0.p_partkey, 'p_size': t0.p_size, 'ps_suppkey': t0.ps_suppkey, 's_suppkey': t1.s_suppkey}) JOIN(condition=True:bool, type=INNER, cardinality=PLURAL_ACCESS, reverse_cardinality=SINGULAR_ACCESS, columns={'n_nationkey': t1.n_nationkey, 'p_partkey': t0.p_partkey, 'p_size': t0.p_size, 'ps_suppkey': t0.ps_suppkey}) JOIN(condition=t0.p_partkey == t1.ps_partkey, type=INNER, cardinality=PLURAL_ACCESS, reverse_cardinality=SINGULAR_ACCESS, columns={'p_partkey': t0.p_partkey, 'p_size': t0.p_size, 'ps_suppkey': t1.ps_suppkey}) SCAN(table=tpch.PART, columns={'p_partkey': p_partkey, 'p_size': p_size}) diff --git a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_ansi.sql b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_ansi.sql index 67208d505..75f4f0c9c 100644 --- 
a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_ansi.sql +++ b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_ansi.sql @@ -10,3 +10,5 @@ JOIN keywords."lowercase_detail" AS "lowercase_detail" AND "lowercase_detail"."0 = 0 and '" = '2 "0 = 0 and ''" field name' JOIN keywords."lowercase_detail" AS lowercase_detail_2 ON "CAST".id = lowercase_detail_2.id AND lowercase_detail_2.id = 1 +WHERE + "CAST".id = 1 diff --git a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_mysql.sql b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_mysql.sql index 7dbd9574a..18640ce45 100644 --- a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_mysql.sql +++ b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_mysql.sql @@ -10,3 +10,5 @@ JOIN keywords.`lowercase_detail` AS `lowercase_detail` AND `lowercase_detail`.`0 = 0 and '` = '2 "0 = 0 and ''" field name' JOIN keywords.`lowercase_detail` AS lowercase_detail_2 ON `CAST`.id = lowercase_detail_2.id AND lowercase_detail_2.id = 1 +WHERE + `CAST`.id = 1 diff --git a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_postgres.sql b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_postgres.sql index 67208d505..75f4f0c9c 100644 --- a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_postgres.sql +++ b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_postgres.sql @@ -10,3 +10,5 @@ JOIN keywords."lowercase_detail" AS "lowercase_detail" AND "lowercase_detail"."0 = 0 and '" = '2 "0 = 0 and ''" field name' JOIN keywords."lowercase_detail" AS lowercase_detail_2 ON "CAST".id = lowercase_detail_2.id AND lowercase_detail_2.id = 1 +WHERE + "CAST".id = 1 diff --git a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_snowflake.sql b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_snowflake.sql index 70b5136ea..5eafd373c 100644 --- a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_snowflake.sql +++ 
b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_snowflake.sql @@ -10,3 +10,5 @@ JOIN keywords."lowercase_detail" AS "lowercase_detail" AND "lowercase_detail"."0 = 0 and '" = '2 "0 = 0 and \'" field name' JOIN keywords."lowercase_detail" AS lowercase_detail_2 ON "CAST".id = lowercase_detail_2.id AND lowercase_detail_2.id = 1 +WHERE + "CAST".id = 1 diff --git a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_sqlite.sql b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_sqlite.sql index 0837c18bf..9bd259ebe 100644 --- a/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_sqlite.sql +++ b/tests/test_sql_refsols/keywords_cast_alias_and_missing_alias_sqlite.sql @@ -10,3 +10,5 @@ JOIN keywords."lowercase_detail" AS "lowercase_detail" AND "lowercase_detail"."0 = 0 and '" = '2 "0 = 0 and ''" field name' JOIN keywords."lowercase_detail" AS lowercase_detail_2 ON "cast".id = lowercase_detail_2.id AND lowercase_detail_2.id = 1 +WHERE + "cast".id = 1