Skip to content

Commit

Permalink
Fix bug with missed columns in the column pruner (#1679)
Browse files Browse the repository at this point in the history
This PR fixes a bug where the column pruner does not correctly keep
track of columns that are used in `SqlStringExpression` and
`SqlColumnAliasReferenceExpression`. Those expressions are different
from other expressions as they do not have a table alias, and there is
existing case handling for those expressions. The bug is that the
handler did not propagate the knowledge that columns from those
expressions are required to the available CTEs.

This bug currently does not affect queries due to the way that we wrap
access to CTEs in a `SELECT`, but it is fixed here because it affects later work.
  • Loading branch information
plypaul authored Feb 23, 2025
1 parent 0d8411f commit 24204e7
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,13 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> None:
join_desc.right_source for join_desc in node.join_descs
):
self._current_required_column_alias_mapping.add_aliases(node_to_retain_columns, column_aliases_to_retain)
sql_table_node = node_to_retain_columns.as_sql_table_node
if sql_table_node is not None and sql_table_node.sql_table.schema_name is None:
self._map_required_column_aliases_in_potential_cte(
cte_alias_mapping=cte_alias_mapping,
table_name=sql_table_node.sql_table.table_name,
column_aliases=column_aliases_to_retain,
)

# Visit recursively.
self._visit_parents(node)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_name: test_column_reference_expression
test_filename: test_cte_column_pruner.py
docstring:
Test a column reference expression that does not specify a table alias.
expectation_description:
`cte_source_0__col_01` should be retained in the CTE.
---
optimizer:
SqlColumnPrunerOptimizer

sql_before_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
, test_table_alias.col_0 AS cte_source_0__col_1
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias

sql_after_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
test_name: test_string_expression
test_filename: test_cte_column_pruner.py
docstring:
Test a string expression that references a column in the cte.
expectation_description:
`cte_source_0__col_01` should be retained in the CTE.
---
optimizer:
SqlColumnPrunerOptimizer

sql_before_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
, test_table_alias.col_0 AS cte_source_0__col_1
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias

sql_after_optimizing:
-- Top-level SELECT
WITH cte_source_0 AS (
-- CTE source 0
SELECT
test_table_alias.col_0 AS cte_source_0__col_0
FROM test_schema.test_table test_table_alias
)

SELECT
cte_source_0__col_0 AS top_level__col_0
FROM cte_source_0 cte_source_0_alias
109 changes: 109 additions & 0 deletions tests_metricflow/sql/optimizer/test_cte_column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
SqlColumnReferenceExpression,
SqlComparison,
SqlComparisonExpression,
SqlStringExpression,
)
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
Expand Down Expand Up @@ -464,3 +465,111 @@ def test_common_cte_aliases_in_nested_query(
"""
),
)


def test_string_expression(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_pruner: SqlColumnPrunerOptimizer,
    sql_plan_renderer: DefaultSqlPlanRenderer,
) -> None:
    """Test a string expression that references a column in the cte."""
    # NOTE: the docstring above and `expectation_description` below also show up in the header of the stored
    # snapshot for this test, so both are kept verbatim.
    #
    # Nodes are built in the same order as a single nested expression would evaluate them: top-level pieces
    # first, then the CTE, then the enclosing statement.

    # Top-level SELECT: one column whose expression is a raw SQL string naming the CTE column without any
    # table alias; `used_columns` declares which column that string depends on.
    top_level_columns = (
        SqlSelectColumn(
            expr=SqlStringExpression.create(sql_expr="cte_source_0__col_0", used_columns=("cte_source_0__col_0",)),
            column_alias="top_level__col_0",
        ),
    )
    # `schema_name=None` with the CTE's alias as the table name makes the FROM clause read from the CTE.
    top_level_source = SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0"))

    # The CTE exposes the same underlying column under two aliases; only the first one is referenced by the
    # top-level SELECT.
    cte_columns = tuple(
        SqlSelectColumn(
            expr=SqlColumnReferenceExpression.create(
                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
            ),
            column_alias=cte_column_alias,
        )
        for cte_column_alias in ("cte_source_0__col_0", "cte_source_0__col_1")
    )
    cte_table = SqlTableNode.create(sql_table=SqlTable(schema_name="test_schema", table_name="test_table"))
    cte_select = SqlSelectStatementNode.create(
        description="CTE source 0",
        select_columns=cte_columns,
        from_source=cte_table,
        from_source_alias="test_table_alias",
    )
    cte_node = SqlCteNode.create(cte_alias="cte_source_0", select_statement=cte_select)

    select_statement = SqlSelectStatementNode.create(
        description="Top-level SELECT",
        select_columns=top_level_columns,
        from_source=top_level_source,
        from_source_alias="cte_source_0_alias",
        cte_sources=(cte_node,),
    )

    # Runs the column pruner on the statement and compares the rendered before/after SQL with the snapshot.
    assert_optimizer_result_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        optimizer=column_pruner,
        sql_plan_renderer=sql_plan_renderer,
        select_statement=select_statement,
        expectation_description="`cte_source_0__col_01` should be retained in the CTE.",
    )


def test_column_reference_expression(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    column_pruner: SqlColumnPrunerOptimizer,
    sql_plan_renderer: DefaultSqlPlanRenderer,
) -> None:
    """Test a column reference expression that does not specify a table alias."""
    # NOTE: the docstring above and `expectation_description` below also show up in the header of the stored
    # snapshot for this test, so both are kept verbatim.

    # Imported locally so that this test does not depend on edits to the module-level import block.
    # NOTE(review): assumes `SqlColumnAliasReferenceExpression` lives next to `SqlStringExpression` in
    # `metricflow_semantics.sql.sql_exprs` and exposes `create(column_alias=...)` - confirm.
    from metricflow_semantics.sql.sql_exprs import SqlColumnAliasReferenceExpression

    select_statement = SqlSelectStatementNode.create(
        description="Top-level SELECT",
        select_columns=(
            SqlSelectColumn(
                # Reference the CTE column purely by its alias (no table alias). Using a `SqlStringExpression`
                # here would make this test a duplicate of `test_string_expression` and leave the alias-less
                # column reference path unexercised.
                expr=SqlColumnAliasReferenceExpression.create(column_alias="cte_source_0__col_0"),
                column_alias="top_level__col_0",
            ),
        ),
        # `schema_name=None` with the CTE's alias as the table name makes the FROM clause read from the CTE.
        from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")),
        from_source_alias="cte_source_0_alias",
        cte_sources=(
            SqlCteNode.create(
                cte_alias="cte_source_0",
                select_statement=SqlSelectStatementNode.create(
                    description="CTE source 0",
                    # The same underlying column is exposed under two aliases; only the first one is
                    # referenced by the top-level SELECT.
                    select_columns=(
                        SqlSelectColumn(
                            expr=SqlColumnReferenceExpression.create(
                                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
                            ),
                            column_alias="cte_source_0__col_0",
                        ),
                        SqlSelectColumn(
                            expr=SqlColumnReferenceExpression.create(
                                col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0")
                            ),
                            column_alias="cte_source_0__col_1",
                        ),
                    ),
                    from_source=SqlTableNode.create(
                        sql_table=SqlTable(schema_name="test_schema", table_name="test_table")
                    ),
                    from_source_alias="test_table_alias",
                ),
            ),
        ),
    )

    # Runs the column pruner on the statement and compares the rendered before/after SQL with the snapshot.
    assert_optimizer_result_snapshot_equal(
        request=request,
        mf_test_configuration=mf_test_configuration,
        optimizer=column_pruner,
        sql_plan_renderer=sql_plan_renderer,
        select_statement=select_statement,
        expectation_description="`cte_source_0__col_01` should be retained in the CTE.",
    )

0 comments on commit 24204e7

Please sign in to comment.