Skip to content

Commit

Permalink
Replace is_table with as_sql_table_node (#1502)
Browse files Browse the repository at this point in the history
This removes `SqlQueryPlanNode.is_table` and replaces it with a typed
accessor for parallelism with `as_select_node`. This is helpful in
handling some use cases for CTEs.
  • Loading branch information
plypaul authored Nov 10, 2024
1 parent dd090a2 commit 2eaf47a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 27 deletions.
4 changes: 2 additions & 2 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _render_from_section(self, from_source: SqlQueryPlanNode, from_source_alias:
from_render_result = self._render_node(from_source)

from_section_lines = []
if from_source.is_table:
if from_source.as_sql_table_node is not None:
from_section_lines.append(f"FROM {from_render_result.sql} {from_source_alias}")
else:
from_section_lines.append("FROM (")
Expand Down Expand Up @@ -228,7 +228,7 @@ def _render_joins_section(self, join_descriptions: Sequence[SqlJoinDescription])
on_condition_rendered = self.EXPR_RENDERER.render_sql_expr(join_description.on_condition)
params = params.merge(on_condition_rendered.bind_parameter_set)

if join_description.right_source.is_table:
if join_description.right_source.as_sql_table_node is not None:
join_section_lines.append(join_description.join_type.value)
join_section_lines.append(
textwrap.indent(
Expand Down
52 changes: 27 additions & 25 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut

@property
@abstractmethod
def is_table(self) -> bool:
"""Returns whether this node resolves to a table (vs. a query)."""
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
"""If possible, return this as a select statement node."""
raise NotImplementedError

@property
@abstractmethod
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
"""If possible, return this as a select statement node."""
def as_sql_table_node(self) -> Optional[SqlTableNode]:
"""If possible, return this as SQL table node."""
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -208,14 +208,15 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_select_statement_node(self)

@property
def is_table(self) -> bool: # noqa: D102
return False

@property
def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102
return self

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
Expand Down Expand Up @@ -271,10 +272,6 @@ def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_table_node(self)

@property
def is_table(self) -> bool: # noqa: D102
return True

@property
def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102
return None
Expand All @@ -289,6 +286,11 @@ def nearest_select_columns(
return cte_node.nearest_select_columns(cte_source_mapping)
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return self


@dataclass(frozen=True, eq=False)
class SqlSelectQueryFromClauseNode(SqlQueryPlanNode):
Expand Down Expand Up @@ -318,10 +320,6 @@ def description(self) -> str: # noqa: D102
def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_query_from_clause_node(self)

@property
def is_table(self) -> bool: # noqa: D102
return False

@property
def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102
return None
Expand All @@ -332,14 +330,18 @@ def nearest_select_columns(
) -> Optional[Sequence[SqlSelectColumn]]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None


@dataclass(frozen=True, eq=False)
class SqlCreateTableAsNode(SqlQueryPlanNode):
"""An SQL node representing a CREATE TABLE AS statement.
Attributes:
sql_table: The SQL table to create.
parent_node: The parent query plan node.
"""

sql_table: SqlTable
Expand All @@ -361,12 +363,12 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut

@property
@override
def is_table(self) -> bool:
return False
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
Expand Down Expand Up @@ -415,15 +417,15 @@ def render_node(self) -> SqlQueryPlanNode: # noqa: D102
class SqlCteNode(SqlQueryPlanNode):
"""Represents a single common table expression."""

select_statement: SqlSelectStatementNode
select_statement: SqlQueryPlanNode
cte_alias: str

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(select_statement: SqlSelectStatementNode, cte_alias: str) -> SqlCteNode: # noqa: D102
def create(select_statement: SqlQueryPlanNode, cte_alias: str) -> SqlCteNode: # noqa: D102
return SqlCteNode(
parent_nodes=(select_statement,),
select_statement=select_statement,
Expand All @@ -436,12 +438,12 @@ def accept(self, visitor: SqlQueryPlanNodeVisitor[VisitorOutputT]) -> VisitorOut

@property
@override
def is_table(self) -> bool:
return False
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
Expand Down

0 comments on commit 2eaf47a

Please sign in to comment.