Enable different granularities for metrics with time offset #426

Closed · wants to merge 7 commits
58 changes: 34 additions & 24 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
@@ -217,6 +217,7 @@ def _build_metrics_output_node(
)
aggregated_measures_node = self.build_aggregated_measures(
metric_input_measure_specs=metric_input_measure_specs,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=combined_where,
time_range_constraint=time_range_constraint,
@@ -231,16 +232,7 @@ def _build_metrics_output_node(
aggregated_measures_node=aggregated_measures_node,
)

if metric_spec.offset_window or metric_spec.offset_to_grain:
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=compute_metrics_node,
time_range_constraint=time_range_constraint,
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
)
output_nodes.append(join_to_time_spine_node)
else:
output_nodes.append(compute_metrics_node)
output_nodes.append(compute_metrics_node)

assert len(output_nodes) > 0, "ComputeMetricsNode was not properly constructed"
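
With this hunk, `_build_metrics_output_node` stops wrapping the computed metric in a `JoinToTimeSpineNode`; the offset join is now applied further down, while the aggregated measures are built. A minimal standalone sketch of the reordering (stand-in `Node` class, not MetricFlow's node types):

```python
# Sketch only: where the time-spine join sits before vs. after this PR,
# using a stand-in Node type rather than MetricFlow's node classes.
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Node:
    name: str
    parents: Tuple["Node", ...] = ()


def plan(offset: bool, before_this_pr: bool) -> Node:
    source = Node("MeasureSource")
    if before_this_pr:
        # Old ordering: the join wrapped the already-computed metric, so the
        # spine could only match the metric's output granularity.
        aggregated = Node("AggregateMeasures", (source,))
        metric = Node("ComputeMetrics", (aggregated,))
        return Node("JoinToTimeSpine", (metric,)) if offset else metric
    # New ordering: join to the spine before filtering and aggregation, so
    # the result can be re-aggregated at the queried granularity.
    joined = Node("JoinToTimeSpine", (source,)) if offset else source
    filtered = Node("FilterElements", (joined,))
    aggregated = Node("AggregateMeasures", (filtered,))
    return Node("ComputeMetrics", (aggregated,))
```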

@@ -601,6 +593,7 @@ def build_computed_metrics_node(
def build_aggregated_measures(
self,
metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...],
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[SpecWhereClauseConstraint] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
@@ -659,6 +652,7 @@ def build_aggregated_measures(
output_nodes.append(
self._build_aggregated_measures_from_measure_source_node(
metric_input_measure_specs=input_specs,
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=node_where_constraint,
time_range_constraint=time_range_constraint,
@@ -686,6 +680,7 @@ def _build_aggregated_measures_from_measure_source_node(
def _build_aggregated_measures_from_measure_source_node(
self,
metric_input_measure_specs: Tuple[MetricInputMeasureSpec, ...],
metric_spec: MetricSpec,
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[SpecWhereClauseConstraint] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
@@ -754,9 +749,35 @@ def _build_aggregated_measures_from_measure_source_node(
f"Recipe not found for measure specs: {measure_specs} and linkable specs: {required_linkable_specs}"
)

metric_time_dimension_spec: Optional[TimeDimensionSpec] = None
for linkable_spec in queried_linkable_specs.time_dimension_specs:
if linkable_spec.element_name == self._metric_time_dimension_reference.element_name:
metric_time_dimension_spec = linkable_spec
break

time_range_node: Optional[JoinOverTimeRangeNode[SqlDataSetT]] = None
if cumulative and metric_time_dimension_spec:
time_range_node = JoinOverTimeRangeNode(
parent_node=measure_recipe.measure_node,
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
time_range_constraint=time_range_constraint,
)

join_to_time_spine_node: Optional[JoinToTimeSpineNode] = None
if metric_spec.offset_window or metric_spec.offset_to_grain:
assert metric_time_dimension_spec, "Joining to time spine requires querying with a time dimension."
join_to_time_spine_node = JoinToTimeSpineNode(
parent_node=time_range_node or measure_recipe.measure_node,
time_dimension_spec=metric_time_dimension_spec,
time_range_constraint=time_range_constraint,
offset_window=metric_spec.offset_window,
offset_to_grain=metric_spec.offset_to_grain,
)

# Only get the required measure and the local linkable instances so that aggregations work correctly.
filtered_measure_source_node = FilterElementsNode[SqlDataSetT](
parent_node=measure_recipe.measure_node,
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.measure_node,
include_specs=InstanceSpecSet.merge(
(
InstanceSpecSet(measure_specs=measure_specs),
@@ -765,18 +786,7 @@ def _build_aggregated_measures_from_measure_source_node(
),
)

time_range_node: Optional[JoinOverTimeRangeNode[SqlDataSetT]] = None
if cumulative:
time_range_node = JoinOverTimeRangeNode(
parent_node=filtered_measure_source_node,
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
time_range_constraint=time_range_constraint,
)

filtered_measure_or_time_range_node = time_range_node or filtered_measure_source_node
join_targets = []

for join_recipe in measure_recipe.join_linkable_instances_recipes:
# Figure out what elements to filter from the joined node.

@@ -823,7 +833,7 @@ def _build_aggregated_measures_from_measure_source_node(
unaggregated_measure_node: BaseOutput[SqlDataSetT]
if len(join_targets) > 0:
filtered_measures_with_joined_elements = JoinToBaseOutputNode[SqlDataSetT](
left_node=filtered_measure_or_time_range_node,
left_node=filtered_measure_source_node,
join_targets=join_targets,
)

@@ -839,7 +849,7 @@ def _build_aggregated_measures_from_measure_source_node(
)
unaggregated_measure_node = after_join_filtered_node
else:
unaggregated_measure_node = filtered_measure_or_time_range_node
unaggregated_measure_node = filtered_measure_source_node

cumulative_metric_constrained_node: Optional[ConstrainTimeRangeNode] = None
if (
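The net effect in `_build_aggregated_measures_from_measure_source_node`: a `JoinOverTimeRangeNode` is built only for cumulative metrics, a `JoinToTimeSpineNode` only for offset metrics, and `FilterElementsNode` hangs off whichever was built last via an `or`-chain over the optionals. A tiny self-contained sketch of that selection logic:

```python
# Standalone sketch of the parent-selection logic above: the or-chain picks
# the deepest node that was actually constructed for this query.
from typing import Optional


def pick_filter_parent(
    measure_node: str,
    time_range_node: Optional[str],          # built only for cumulative metrics
    join_to_time_spine_node: Optional[str],  # built only for offset metrics
) -> str:
    # Mirrors: join_to_time_spine_node or time_range_node or measure_recipe.measure_node
    return join_to_time_spine_node or time_range_node or measure_node


assert pick_filter_parent("measure", None, None) == "measure"
assert pick_filter_parent("measure", "cumulative_join", None) == "cumulative_join"
assert pick_filter_parent("measure", "cumulative_join", "offset_join") == "offset_join"
```

Note the ordering also means that when both are present, the offset join is applied on top of the cumulative join (`parent_node=time_range_node or measure_recipe.measure_node`).
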
13 changes: 11 additions & 2 deletions metricflow/dataflow/dataflow_plan.py
@@ -718,6 +718,7 @@ class JoinToTimeSpineNode(Generic[SourceDataSetT], BaseOutput[SourceDataSetT], A
def __init__(
self,
parent_node: BaseOutput[SourceDataSetT],
time_dimension_spec: TimeDimensionSpec,
time_range_constraint: Optional[TimeRangeConstraint] = None,
offset_window: Optional[MetricTimeWindow] = None,
offset_to_grain: Optional[TimeGranularity] = None,
@@ -726,6 +727,7 @@ def __init__(

Args:
parent_node: Node that returns desired dataset to join to time spine.
time_dimension_spec: Time dimension requested in query. Used to determine time spine granularity.
time_range_constraint: Time range to constrain the time spine to.
offset_window: Time window to offset the parent dataset by when joining to time spine.
offset_to_grain: Granularity period to offset the parent dataset to when joining to time spine.
@@ -736,6 +738,7 @@ def __init__(
offset_window and offset_to_grain
), "Can't set both offset_window and offset_to_grain when joining to time spine. Choose one or the other."
self._parent_node = parent_node
self._time_dimension_spec = time_dimension_spec
self._offset_window = offset_window
self._offset_to_grain = offset_to_grain
self._time_range_constraint = time_range_constraint
@@ -746,6 +749,11 @@ def __init__(
def id_prefix(cls) -> str: # noqa: D
return DATAFLOW_NODE_JOIN_TO_TIME_SPINE_ID_PREFIX

@property
def time_dimension_spec(self) -> TimeDimensionSpec: # noqa: D
"""Time dimension spec to use when creating time spine table."""
return self._time_dimension_spec

@property
def time_range_constraint(self) -> Optional[TimeRangeConstraint]: # noqa: D
"""Time range constraint to apply when querying time spine table."""
@@ -795,9 +803,10 @@ def with_new_parents(  # noqa: D
assert len(new_parent_nodes) == 1
return JoinToTimeSpineNode[SourceDataSetT](
parent_node=new_parent_nodes[0],
time_dimension_spec=self.time_dimension_spec,
time_range_constraint=self.time_range_constraint,
offset_window=self._offset_window,
offset_to_grain=self._offset_to_grain,
offset_window=self.offset_window,
offset_to_grain=self.offset_to_grain,
)


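With `time_dimension_spec` now required, a call site looks roughly like this. The constructor signature is taken from this diff; the import paths for `TimeDimensionSpec` and `TimeGranularity` are assumptions and may differ by version:

```python
# Hedged usage sketch; only JoinToTimeSpineNode's module is confirmed by this
# file's path (metricflow/dataflow/dataflow_plan.py).
from metricflow.dataflow.dataflow_plan import BaseOutput, JoinToTimeSpineNode
from metricflow.specs import TimeDimensionSpec  # assumed import path
from metricflow.time.time_granularity import TimeGranularity  # assumed import path


def offset_to_month(parent_node: BaseOutput) -> JoinToTimeSpineNode:
    """Join parent_node to the time spine, offset to the start of the month."""
    metric_time_month = TimeDimensionSpec(
        element_name="metric_time",  # the queried metric_time dimension
        identifier_links=(),
        time_granularity=TimeGranularity.MONTH,  # spine is truncated to this grain
    )
    return JoinToTimeSpineNode(
        parent_node=parent_node,
        time_dimension_spec=metric_time_month,
        # Per the assertion in __init__, set offset_window OR offset_to_grain, not both.
        offset_window=None,
        offset_to_grain=TimeGranularity.MONTH,
    )
```
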
124 changes: 65 additions & 59 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -2,7 +2,7 @@

import logging
from collections import OrderedDict
from typing import Generic, List, Optional, Sequence, TypeVar, Union
from typing import Generic, List, Optional, Sequence, TypeVar, Union, Tuple

from metricflow.aggregation_properties import AggregationState, AggregationType
from metricflow.column_assoc import ColumnAssociation, SingleColumnCorrelationKey
@@ -197,34 +197,15 @@ def _make_time_spine_data_set(

# If the requested granularity is the same as the granularity of the spine, do a direct select.
if metric_time_dimension_instance.spec.time_granularity == time_spine_source.time_column_granularity:
return SqlDataSet(
Review comment (Contributor Author): Removed some redundant code here
instance_set=time_spine_instance_set,
sql_select_node=SqlSelectStatementNode(
description=description,
# This creates select expressions for all columns referenced in the instance set.
select_columns=(
SqlSelectColumn(
expr=SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_table_alias,
column_name=time_spine_source.time_column_name,
),
),
column_alias=metric_time_dimension_column_name,
select_columns = (
SqlSelectColumn(
expr=SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_table_alias,
column_name=time_spine_source.time_column_name,
),
),
from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
from_source_alias=time_spine_table_alias,
joins_descs=(),
group_bys=(),
where=_make_time_range_comparison_expr(
table_alias=time_spine_table_alias,
column_alias=time_spine_source.time_column_name,
time_range_constraint=time_range_constraint,
)
if time_range_constraint
else None,
order_bys=(),
column_alias=metric_time_dimension_column_name,
),
)
# If the granularity is different, apply a DATE_TRUNC() and aggregate.
@@ -243,26 +224,26 @@
column_alias=metric_time_dimension_column_name,
),
)
return SqlDataSet(
instance_set=time_spine_instance_set,
sql_select_node=SqlSelectStatementNode(
description=description,
# This creates select expressions for all columns referenced in the instance set.
select_columns=select_columns,
from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
from_source_alias=time_spine_table_alias,
joins_descs=(),
group_bys=select_columns,
where=_make_time_range_comparison_expr(
table_alias=time_spine_table_alias,
column_alias=time_spine_source.time_column_name,
time_range_constraint=time_range_constraint,
)
if time_range_constraint
else None,
order_bys=(),
),
)
return SqlDataSet(
instance_set=time_spine_instance_set,
sql_select_node=SqlSelectStatementNode(
description=description,
# This creates select expressions for all columns referenced in the instance set.
select_columns=select_columns,
from_source=SqlTableFromClauseNode(sql_table=time_spine_source.spine_table),
from_source_alias=time_spine_table_alias,
joins_descs=(),
group_bys=select_columns,
where=_make_time_range_comparison_expr(
table_alias=time_spine_table_alias,
column_alias=time_spine_source.time_column_name,
time_range_constraint=time_range_constraint,
)
if time_range_constraint
else None,
order_bys=(),
),
)
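
This is the refactor flagged in the review comment above: both granularity branches now only decide `select_columns`, and the `SqlDataSet` construction is written once. A minimal sketch of the shared-return pattern, with plain strings standing in for the SQL expression nodes:

```python
# Sketch of the branch logic in _make_time_spine_data_set: pick the time
# column expression per granularity, then build the SELECT once.
def time_spine_select_expr(requested_grain: str, spine_grain: str, column: str) -> str:
    if requested_grain == spine_grain:
        # Requested grain matches the spine: select the column directly.
        return column
    # Coarser grain: DATE_TRUNC (the real code also GROUP BYs to deduplicate).
    return f"DATE_TRUNC('{requested_grain}', {column})"


assert time_spine_select_expr("day", "day", "ds") == "ds"
assert time_spine_select_expr("month", "day", "ds") == "DATE_TRUNC('month', ds)"
```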

def visit_source_node(self, node: ReadSqlSourceNode[SqlDataSetT]) -> SqlDataSet:
"""Generate the SQL to read from the source."""
@@ -286,13 +267,9 @@ def visit_join_over_time_range_node(self, node: JoinOverTimeRangeNode[SqlDataSet
metric_time_dimension_spec = instance.spec
break

# If the metric time dimension isn't present in the parent node it's because it wasn't requested
# and therefore we don't need the time range join because we can just let the metric sum over all time
if metric_time_dimension_spec is None:
return input_data_set

time_spine_data_set_alias = self._next_unique_table_alias()

assert metric_time_dimension_spec
metric_time_dimension_column_name = self.column_association_resolver.resolve_time_dimension_spec(
metric_time_dimension_spec
).column_name
@@ -1334,7 +1311,7 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
parent_alias=parent_alias,
)

# Use metric_time instance from time spine, all instances EXCEPT metric_time from parent data set.
# Use all instances EXCEPT metric_time from parent data set.
non_metric_time_parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
@@ -1347,16 +1324,45 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode[SourceDataSetT
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
)
table_alias_to_instance_set = OrderedDict(
{time_spine_alias: time_spine_dataset.instance_set, parent_alias: non_metric_time_parent_instance_set}
)

# Add requested granularity to time spine specs & columns.
time_dimension_instances: Tuple[TimeDimensionInstance, ...] = tuple()
time_spine_select_columns: Tuple[SqlSelectColumn, ...] = tuple()
for original_time_dim_instance in time_spine_dataset.instance_set.time_dimension_instances:
if original_time_dim_instance.spec.element_name == DataSet.metric_time_dimension_reference().element_name:
new_time_dim_spec = TimeDimensionSpec(
element_name=original_time_dim_instance.spec.element_name,
identifier_links=original_time_dim_instance.spec.identifier_links,
time_granularity=node.time_dimension_spec.time_granularity,
)
new_time_dim_instance = TimeDimensionInstance(
defined_from=original_time_dim_instance.defined_from,
associated_columns=new_time_dim_spec.column_associations(self._column_association_resolver),
spec=new_time_dim_spec,
)
time_dimension_instances += (new_time_dim_instance,)
time_spine_select_columns += (
SqlSelectColumn(
expr=SqlDateTruncExpression(
time_granularity=node.time_dimension_spec.time_granularity,
arg=SqlColumnReferenceExpression(
SqlColumnReference(
table_alias=time_spine_alias, column_name=new_time_dim_spec.element_name
)
),
),
column_alias=new_time_dim_instance.associated_column.column_name,
),
)
time_spine_instance_set = InstanceSet(time_dimension_instances=time_dimension_instances)

return SqlDataSet(
instance_set=InstanceSet.merge(list(table_alias_to_instance_set.values())),
instance_set=InstanceSet.merge([time_spine_instance_set, non_metric_time_parent_instance_set]),
sql_select_node=SqlSelectStatementNode(
description=node.description,
select_columns=create_select_columns_for_instance_sets(
self._column_association_resolver, table_alias_to_instance_set
select_columns=time_spine_select_columns
+ create_select_columns_for_instance_sets(
self._column_association_resolver, OrderedDict({parent_alias: non_metric_time_parent_instance_set})
),
from_source=time_spine_dataset.sql_select_node,
from_source_alias=time_spine_alias,
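Taken together, `visit_join_to_time_spine_node` now sources `metric_time` from the spine, truncated to the queried granularity, and every other instance from the parent data set. An illustrative sketch of the SQL shape this aims for (table, column, and measure names are made up; the offset ON-clause details live elsewhere in the visitor):

```python
# Illustrative only: roughly the SELECT shape the visitor emits after this
# change. MetricFlow builds this via SqlSelectStatementNode, not strings.
def sketch_join_to_time_spine_sql(granularity: str) -> str:
    return "\n".join(
        [
            "SELECT",
            # metric_time comes from the time spine at the queried grain
            f"  DATE_TRUNC('{granularity}', ts.ds) AS metric_time",
            "  , parent.bookings AS bookings  -- all non-metric_time parent columns",
            "FROM time_spine ts",
            "INNER JOIN parent_dataset parent",
            "  ON ts.ds = parent.ds  -- offset applied here when requested",
        ]
    )


print(sketch_join_to_time_spine_sql("month"))
```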