Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eq=False to node classes #1500

Merged
merged 1 commit into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions metricflow-semantics/metricflow_semantics/dag/mf_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,14 @@ def visit_node(self, node: DagNode) -> VisitorOutputT: # noqa: D102
DagNodeT = TypeVar("DagNodeT", bound="DagNode")


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a comment here explaining essentially what you explained in the PR description? Will likely be helpful to future devs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

class DagNode(MetricFlowPrettyFormattable, Generic[DagNodeT], ABC):
"""A node in a DAG. These should be immutable."""
"""A node in a DAG. These should be immutable.

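Since there should only be a single instance of a node with a given ID, `eq` is set to `False` so that equality
checks fall back to object identity instead of comparing the fields. Comparing the fields can be a slow process
since the `parent_nodes` field is recursive. Subclasses must also be declared with `eq=False`; otherwise
`@dataclass` generates a field-comparing `__eq__` for them.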
"""

parent_nodes: Tuple[DagNodeT, ...]

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/dataflow_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
NodeSelfT = TypeVar("NodeSelfT", bound="DataflowPlanNode")


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class DataflowPlanNode(DagNode["DataflowPlanNode"], Visitable, ABC):
"""A node in the graph representation of the dataflow.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/add_generated_uuid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class AddGeneratedUuidColumnNode(DataflowPlanNode):
"""Adds a UUID column."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/aggregate_measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class AggregateMeasuresNode(DataflowPlanNode):
"""A node that aggregates the measures by the associated group by elements.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/combine_aggregated_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class CombineAggregatedOutputsNode(DataflowPlanNode):
"""Combines metrics from different nodes into a single output."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/compute_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ComputeMetricsNode(DataflowPlanNode):
"""A node that computes metrics from input measures. Dimensions / entities are passed through.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/constrain_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.nodes.aggregate_measures import DataflowPlanNode


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ConstrainTimeRangeNode(DataflowPlanNode):
"""Constrains the time range of the input data set.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/filter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class FilterElementsNode(DataflowPlanNode):
"""Only passes the listed elements.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_conversion_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinConversionEventsNode(DataflowPlanNode):
"""Builds a data set containing successful conversion events.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_over_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinOverTimeRangeNode(DataflowPlanNode):
"""A node that allows for cumulative metric computation by doing a self join across a cumulative date range.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_to_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __post_init__(self) -> None: # noqa: D105
raise RuntimeError("`join_on_entity` is required unless using CROSS JOIN.")


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinOnEntitiesNode(DataflowPlanNode):
"""A node that joins data from other nodes via the entities in the inputs.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_to_custom_granularity.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinToCustomGranularityNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset to convert time dimension to a custom granularity.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/join_to_time_spine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class JoinToTimeSpineNode(DataflowPlanNode, ABC):
"""Join parent dataset to time spine dataset.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/metric_time_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class MetricTimeDimensionTransformNode(DataflowPlanNode):
"""A node transforms the input data set so that it contains the metric time dimension and relevant measures.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class MinMaxNode(DataflowPlanNode):
"""Calculate the min and max of a single instance data set."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/order_by_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class OrderByLimitNode(DataflowPlanNode):
"""A node that re-orders the input data with a limit.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/read_sql_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from metricflow.dataset.sql_dataset import SqlDataSet


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class ReadSqlSourceNode(DataflowPlanNode):
"""A source node where data from an SQL table or SQL query is read and output.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/semi_additive_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SemiAdditiveJoinNode(DataflowPlanNode):
"""A node that performs a row filter by aggregating a given non-additive dimension.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/where_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from metricflow.dataflow.dataflow_plan import DataflowPlanNode, DataflowPlanNodeVisitor


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WhereConstraintNode(DataflowPlanNode):
"""Remove rows using a WHERE clause.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/window_reaggregation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WindowReaggregationNode(DataflowPlanNode):
"""A node that re-aggregates metrics using window functions.

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/write_to_data_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WriteToResultDataTableNode(DataflowPlanNode):
"""A node where incoming data gets written to a data_table."""

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataflow/nodes/write_to_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class WriteToResultTableNode(DataflowPlanNode):
"""A node where incoming data gets written to a table.

Expand Down
39 changes: 20 additions & 19 deletions metricflow/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing_extensions import override


@dataclass(frozen=True, order=True)
@dataclass(frozen=True, eq=False)
class SqlExpressionNode(DagNode["SqlExpressionNode"], Visitable, ABC):
"""An SQL expression like my_table.my_column, CONCAT(a, b) or 1 + 1 that evaluates to a value."""

Expand Down Expand Up @@ -230,7 +230,7 @@ def visit_generate_uuid_expr(self, node: SqlGenerateUuidExpression) -> VisitorOu
pass


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlStringExpression(SqlExpressionNode):
"""An SQL expression in a string format, so it lacks information about the structure.

Expand Down Expand Up @@ -314,7 +314,7 @@ def as_string_expression(self) -> Optional[SqlStringExpression]:
return self


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlStringLiteralExpression(SqlExpressionNode):
"""A string literal like 'foo'. It shouldn't include delimiters as it should be added during rendering."""

Expand Down Expand Up @@ -375,7 +375,7 @@ class SqlColumnReference:
column_name: str


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlColumnReferenceExpression(SqlExpressionNode):
"""An expression that evaluates to the value of a column in one of the sources in the select query.

Expand Down Expand Up @@ -475,7 +475,7 @@ def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumn
return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name))


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlColumnAliasReferenceExpression(SqlExpressionNode):
"""An expression that evaluates to the alias of a column, but is not qualified with a table alias.

Expand Down Expand Up @@ -544,7 +544,7 @@ class SqlComparison(Enum): # noqa: D101
EQUALS = "="


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlComparisonExpression(SqlExpressionNode):
"""A comparison using >, <, <=, >=, =.

Expand Down Expand Up @@ -698,6 +698,7 @@ def from_aggregation_type(aggregation_type: AggregationType) -> SqlFunction:
assert_values_exhausted(aggregation_type)


@dataclass(frozen=True, eq=False)
class SqlFunctionExpression(SqlExpressionNode):
"""Denotes a function expression in SQL."""

Expand All @@ -723,7 +724,7 @@ def build_expression_from_aggregation_type(
return SqlAggregateFunctionExpression.from_aggregation_type(aggregation_type, sql_column_expression)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlAggregateFunctionExpression(SqlFunctionExpression):
"""An aggregate function expression like SUM(1).

Expand Down Expand Up @@ -857,7 +858,7 @@ def from_aggregation_parameters(agg_params: MeasureAggregationParameters) -> Sql
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlPercentileExpression(SqlFunctionExpression):
"""A percentile aggregation expression.

Expand Down Expand Up @@ -984,7 +985,7 @@ def suffix(self) -> str:
return " ".join(result)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlWindowFunctionExpression(SqlFunctionExpression):
"""A window function expression like SUM(foo) OVER bar.

Expand Down Expand Up @@ -1101,7 +1102,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlNullExpression(SqlExpressionNode):
"""Represents NULL."""

Expand Down Expand Up @@ -1151,7 +1152,7 @@ class SqlLogicalOperator(Enum):
OR = "OR"


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlLogicalExpression(SqlExpressionNode):
"""A logical expression like "a AND b AND c"."""

Expand Down Expand Up @@ -1203,7 +1204,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.operator == other.operator and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlIsNullExpression(SqlExpressionNode):
"""An IS NULL expression like "foo IS NULL"."""

Expand Down Expand Up @@ -1248,7 +1249,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlSubtractTimeIntervalExpression(SqlExpressionNode):
"""Represents an interval subtraction from a given timestamp.

Expand Down Expand Up @@ -1313,7 +1314,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.count == other.count and self.granularity == other.granularity and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlCastToTimestampExpression(SqlExpressionNode):
"""Cast to the timestamp type like CAST('2020-01-01' AS TIMESTAMP)."""

Expand Down Expand Up @@ -1360,7 +1361,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlDateTruncExpression(SqlExpressionNode):
"""Apply a date trunc to a column like DATE_TRUNC('month', ds)."""

Expand Down Expand Up @@ -1411,7 +1412,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.time_granularity == other.time_granularity and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlExtractExpression(SqlExpressionNode):
"""Extract a date part from a time expression.

Expand Down Expand Up @@ -1470,7 +1471,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self.date_part == other.date_part and self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlRatioComputationExpression(SqlExpressionNode):
"""Node for expressing Ratio metrics to allow for appropriate casting to float/double in each engine.

Expand Down Expand Up @@ -1535,7 +1536,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlBetweenExpression(SqlExpressionNode):
"""A BETWEEN clause like `column BETWEEN val1 AND val2`.

Expand Down Expand Up @@ -1600,7 +1601,7 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
return self._parents_match(other)


@dataclass(frozen=True)
@dataclass(frozen=True, eq=False)
class SqlGenerateUuidExpression(SqlExpressionNode):
"""Renders a SQL to generate a random UUID, which is non-deterministic."""

Expand Down
Loading
Loading