Skip to content

Commit

Permalink
Looks like it's working, but the generated SQL still needs to be reviewed.
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Oct 25, 2024
1 parent 6f6d1d6 commit 1c0069a
Show file tree
Hide file tree
Showing 3 changed files with 310 additions and 249 deletions.
127 changes: 81 additions & 46 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,39 @@ def _optimize_plan(self, plan: DataflowPlan, optimizations: FrozenSet[DataflowPl

return plan

def _build_pre_aggregation_measure_node(
    self,
    source_node_recipe: SourceNodeRecipe,
    filter_spec_set: WhereFilterSpecSet,
    required_linkable_specs: LinkableSpecSet,
    filter_to_linkable_specs: InstanceSpecSet,  # TODO: should this be queried_linkable_specs?
) -> DataflowPlanNode:
    """Build the pre-aggregation subplan for a measure source node.

    Starting from the recipe's source node, this:
      1. joins to the recipe's join targets to make required linkable specs available,
      2. joins to a custom-granularity time spine for each required custom grain,
      3. applies the WHERE filter specs (if any), and
      4. prunes the output to ``filter_to_linkable_specs`` so aggregation only sees
         the columns it needs.

    Args:
        source_node_recipe: Recipe describing the source node and the joins it requires.
        filter_spec_set: WHERE filters to apply after joins.
        required_linkable_specs: Linkable specs needed downstream; used to detect custom grains.
        filter_to_linkable_specs: The instance specs to retain in the final filter step.

    Returns:
        The root node of the pre-aggregation subplan.
    """
    output_node: DataflowPlanNode = source_node_recipe.source_node
    if source_node_recipe.join_targets:
        output_node = JoinOnEntitiesNode.create(left_node=output_node, join_targets=source_node_recipe.join_targets)

    # Join to custom-granularity time spines for any required custom grains not already present.
    for time_dimension_spec in required_linkable_specs.time_dimension_specs:
        if (
            time_dimension_spec.time_granularity.is_custom_granularity
            # We might already have the custom join if this is the second aggregation layer for a conversion metric.
            and time_dimension_spec not in source_node_recipe.all_linkable_specs_required_for_source_nodes.as_tuple
        ):
            output_node = JoinToCustomGranularityNode.create(
                parent_node=output_node, time_dimension_spec=time_dimension_spec
            )

    # TODO: handle non-additive dimensions & offset logic for conversion metrics (copy from normal measure logic).

    if filter_spec_set.all_filter_specs:
        output_node = WhereConstraintNode.create(
            parent_node=output_node,
            where_specs=filter_spec_set.all_filter_specs,
        )

    # NOTE(review): elements are filtered only once here, at the end before aggregation, rather than
    # between each step — confirm this doesn't carry unneeded columns through the joins above.
    return FilterElementsNode.create(parent_node=output_node, include_specs=filter_to_linkable_specs)

def _build_aggregated_conversion_node(
self,
base_measure_spec: MetricInputMeasureSpec,
Expand Down Expand Up @@ -262,29 +295,31 @@ def _build_aggregated_conversion_node(
queried_linkable_specs=queried_linkable_specs,
filter_specs=base_measure_spec.filter_spec_set.all_filter_specs,
)
base_measure_recipe = self._find_source_node_recipe(
base_source_node_recipe = self._find_source_node_recipe(
FindSourceNodeRecipeParameterSet(
measure_spec_properties=self._build_measure_spec_properties([base_measure_spec.measure_spec]),
predicate_pushdown_state=time_range_only_pushdown_state,
linkable_spec_set=base_required_linkable_specs,
)
)
logger.debug(LazyFormat(lambda: f"Recipe for base measure aggregation:\n{mf_pformat(base_measure_recipe)}"))
conversion_measure_recipe = self._find_source_node_recipe(
logger.debug(LazyFormat(lambda: f"Recipe for base measure aggregation:\n{mf_pformat(base_source_node_recipe)}"))
conversion_source_node_recipe = self._find_source_node_recipe(
FindSourceNodeRecipeParameterSet(
measure_spec_properties=self._build_measure_spec_properties([conversion_measure_spec.measure_spec]),
predicate_pushdown_state=disabled_pushdown_state,
linkable_spec_set=LinkableSpecSet(),
)
)
logger.debug(
LazyFormat(lambda: f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_measure_recipe)}")
LazyFormat(
lambda: f"Recipe for conversion measure aggregation:\n{mf_pformat(conversion_source_node_recipe)}"
)
)
if base_measure_recipe is None:
if base_source_node_recipe is None:
raise UnableToSatisfyQueryError(
f"Unable to join all items in request. Measure: {base_measure_spec.measure_spec}; Specs to join: {base_required_linkable_specs}"
)
if conversion_measure_recipe is None:
if conversion_source_node_recipe is None:
raise UnableToSatisfyQueryError(
f"Unable to build dataflow plan for conversion measure: {conversion_measure_spec.measure_spec}"
)
Expand All @@ -299,7 +334,7 @@ def _build_aggregated_conversion_node(
# Build unaggregated conversions source node
# Generate UUID column for conversion source to uniquely identify each row
unaggregated_conversion_measure_node = AddGeneratedUuidColumnNode.create(
parent_node=conversion_measure_recipe.source_node
parent_node=conversion_source_node_recipe.source_node
)

# Get the agg time dimension for each measure used for matching conversion time windows
Expand All @@ -315,7 +350,7 @@ def _build_aggregated_conversion_node(
# Filter the source nodes with only the required specs needed for the calculation
constant_property_specs = []
required_local_specs = [base_measure_spec.measure_spec, entity_spec, base_time_dimension_spec] + list(
base_measure_recipe.required_local_linkable_specs.as_tuple
base_source_node_recipe.required_local_linkable_specs.as_tuple
)
for constant_property in constant_properties or []:
base_property_spec = self._semantic_model_lookup.get_element_spec_for_name(constant_property.base_property)
Expand All @@ -328,20 +363,12 @@ def _build_aggregated_conversion_node(
)

# Build the unaggregated base measure node for computing conversions
unaggregated_base_measure_node = base_measure_recipe.source_node
if base_measure_recipe.join_targets:
unaggregated_base_measure_node = JoinOnEntitiesNode.create(
left_node=unaggregated_base_measure_node, join_targets=base_measure_recipe.join_targets
)
if len(base_measure_spec.filter_spec_set.all_filter_specs) > 0:
unaggregated_base_measure_node = WhereConstraintNode.create(
parent_node=unaggregated_base_measure_node,
where_specs=base_measure_spec.filter_spec_set.all_filter_specs,
)
filtered_unaggregated_base_node = FilterElementsNode.create(
parent_node=unaggregated_base_measure_node,
include_specs=group_specs_by_type(required_local_specs)
.merge(base_measure_recipe.all_linkable_specs_required_for_source_nodes.as_instance_spec_set)
base_node = self._build_pre_aggregation_measure_node(
source_node_recipe=base_source_node_recipe,
filter_spec_set=base_measure_spec.filter_spec_set,
required_linkable_specs=base_required_linkable_specs,
filter_to_linkable_specs=group_specs_by_type(required_local_specs)
.merge(base_required_linkable_specs.as_instance_spec_set)
.dedupe(),
)

Expand All @@ -350,7 +377,7 @@ def _build_aggregated_conversion_node(
# be still be constrained, where we adjust the time range to the window size similar to cumulative, but
# adjusted in the opposite direction.
join_conversion_node = JoinConversionEventsNode.create(
base_node=filtered_unaggregated_base_node,
base_node=base_node,
base_time_dimension_spec=base_time_dimension_spec,
conversion_node=unaggregated_conversion_measure_node,
conversion_measure_spec=conversion_measure_spec.measure_spec,
Expand All @@ -360,13 +387,13 @@ def _build_aggregated_conversion_node(
window=window,
constant_properties=constant_property_specs,
)

print("queried_linkable_specs: ", queried_linkable_specs)
# Aggregate the conversion events with the JoinConversionEventsNode as the source node.
recipe_with_join_conversion_source_node = SourceNodeRecipe(
source_node=join_conversion_node,
required_local_linkable_specs=queried_linkable_specs,
join_linkable_instances_recipes=(),
all_linkable_specs_required_for_source_nodes=queried_linkable_specs.replace_custom_granularity_with_base_granularity(),
all_linkable_specs_required_for_source_nodes=queried_linkable_specs,
)
# TODO: Refine conversion metric configuration to fit into the standard dataflow plan building model
# In this case we override the measure recipe, which currently results in us bypassing predicate pushdown
Expand All @@ -376,7 +403,7 @@ def _build_aggregated_conversion_node(
metric_input_measure_spec=conversion_measure_spec,
queried_linkable_specs=queried_linkable_specs,
predicate_pushdown_state=disabled_pushdown_state,
measure_recipe=recipe_with_join_conversion_source_node,
source_node_recipe=recipe_with_join_conversion_source_node,
)

# Combine the aggregated opportunities and conversion data sets
Expand Down Expand Up @@ -1322,6 +1349,7 @@ def _build_input_measure_spec(

before_aggregation_time_spine_join_description = None
# If querying an offset metric, join to time spine.
# TODO: are we handling offset conversion metrics properly?
if child_metric_offset_window is not None or child_metric_offset_to_grain is not None:
before_aggregation_time_spine_join_description = JoinToTimeSpineDescription(
join_type=SqlJoinType.INNER,
Expand Down Expand Up @@ -1426,7 +1454,7 @@ def build_aggregated_measure(
metric_input_measure_spec: MetricInputMeasureSpec,
queried_linkable_specs: LinkableSpecSet,
predicate_pushdown_state: PredicatePushdownState,
measure_recipe: Optional[SourceNodeRecipe] = None,
source_node_recipe: Optional[SourceNodeRecipe] = None,
) -> DataflowPlanNode:
"""Returns a node where the measures are aggregated by the linkable specs and constrained appropriately.
Expand All @@ -1448,7 +1476,7 @@ def build_aggregated_measure(
metric_input_measure_spec=metric_input_measure_spec,
queried_linkable_specs=queried_linkable_specs,
predicate_pushdown_state=predicate_pushdown_state,
measure_recipe=measure_recipe,
source_node_recipe=source_node_recipe,
)

def __get_required_and_extraneous_linkable_specs(
Expand Down Expand Up @@ -1491,7 +1519,7 @@ def _build_aggregated_measure_from_measure_source_node(
metric_input_measure_spec: MetricInputMeasureSpec,
queried_linkable_specs: LinkableSpecSet,
predicate_pushdown_state: PredicatePushdownState,
measure_recipe: Optional[SourceNodeRecipe] = None,
source_node_recipe: Optional[SourceNodeRecipe] = None,
) -> DataflowPlanNode:
measure_spec = metric_input_measure_spec.measure_spec
cumulative = metric_input_measure_spec.cumulative_description is not None
Expand Down Expand Up @@ -1543,7 +1571,7 @@ def _build_aggregated_measure_from_measure_source_node(
metric_input_measure_spec.before_aggregation_time_spine_join_description
)

if measure_recipe is None:
if source_node_recipe is None:
logger.debug(
LazyFormat(
lambda: "Looking for a recipe to get:"
Expand All @@ -1565,7 +1593,7 @@ def _build_aggregated_measure_from_measure_source_node(
)

find_recipe_start_time = time.time()
measure_recipe = self._find_source_node_recipe(
source_node_recipe = self._find_source_node_recipe(
FindSourceNodeRecipeParameterSet(
measure_spec_properties=measure_properties,
predicate_pushdown_state=measure_pushdown_state,
Expand All @@ -1579,9 +1607,9 @@ def _build_aggregated_measure_from_measure_source_node(
)
)

logger.debug(LazyFormat(lambda: f"Using recipe:\n{indent(mf_pformat(measure_recipe))}"))
logger.debug(LazyFormat(lambda: f"Using recipe:\n{indent(mf_pformat(source_node_recipe))}"))

if measure_recipe is None:
if source_node_recipe is None:
raise UnableToSatisfyQueryError(
f"Unable to join all items in request. Measure: {measure_spec.element_name}; Specs to join: {required_linkable_specs}"
)
Expand All @@ -1595,7 +1623,7 @@ def _build_aggregated_measure_from_measure_source_node(
time_range_node: Optional[JoinOverTimeRangeNode] = None
if cumulative and queried_agg_time_dimension_specs:
time_range_node = JoinOverTimeRangeNode.create(
parent_node=measure_recipe.source_node,
parent_node=source_node_recipe.source_node,
queried_agg_time_dimension_specs=tuple(queried_agg_time_dimension_specs),
window=cumulative_window,
grain_to_date=cumulative_grain_to_date,
Expand All @@ -1622,7 +1650,7 @@ def _build_aggregated_measure_from_measure_source_node(
# This also uses the original time range constraint due to the application of the time window intervals
# in join rendering
join_to_time_spine_node = JoinToTimeSpineNode.create(
parent_node=time_range_node or measure_recipe.source_node,
parent_node=time_range_node or source_node_recipe.source_node,
requested_agg_time_dimension_specs=queried_agg_time_dimension_specs,
use_custom_agg_time_dimension=not queried_linkable_specs.contains_metric_time,
time_range_constraint=predicate_pushdown_state.time_range_constraint,
Expand All @@ -1631,15 +1659,16 @@ def _build_aggregated_measure_from_measure_source_node(
join_type=before_aggregation_time_spine_join_description.join_type,
)

# TODO: do we need to filter before joins AND after joins?
# Only get the required measure and the local linkable instances so that aggregations work correctly.
filtered_measure_source_node = FilterElementsNode.create(
parent_node=join_to_time_spine_node or time_range_node or measure_recipe.source_node,
parent_node=join_to_time_spine_node or time_range_node or source_node_recipe.source_node,
include_specs=InstanceSpecSet(measure_specs=(measure_spec,)).merge(
measure_recipe.required_local_linkable_specs.as_instance_spec_set,
source_node_recipe.required_local_linkable_specs.as_instance_spec_set,
),
)

join_targets = measure_recipe.join_targets
join_targets = source_node_recipe.join_targets
unaggregated_measure_node: DataflowPlanNode
if len(join_targets) > 0:
filtered_measures_with_joined_elements = JoinOnEntitiesNode.create(
Expand All @@ -1648,7 +1677,9 @@ def _build_aggregated_measure_from_measure_source_node(
)

specs_to_keep_after_join = InstanceSpecSet(measure_specs=(measure_spec,)).merge(
InstanceSpecSet.create_from_specs(measure_recipe.all_linkable_specs_required_for_source_nodes.as_tuple),
InstanceSpecSet.create_from_specs(
source_node_recipe.all_linkable_specs_required_for_source_nodes.as_tuple
),
)

after_join_filtered_node = FilterElementsNode.create(
Expand All @@ -1659,7 +1690,11 @@ def _build_aggregated_measure_from_measure_source_node(
unaggregated_measure_node = filtered_measure_source_node

for time_dimension_spec in required_linkable_specs.time_dimension_specs:
if time_dimension_spec.time_granularity.is_custom_granularity:
if (
time_dimension_spec.time_granularity.is_custom_granularity
# We might already have the custom join if this is the second aggregation layer for a conversion metric.
and not time_dimension_spec in source_node_recipe.all_linkable_specs_required_for_source_nodes.as_tuple
):
unaggregated_measure_node = JoinToCustomGranularityNode.create(
parent_node=unaggregated_measure_node, time_dimension_spec=time_dimension_spec
)
Expand Down Expand Up @@ -1693,12 +1728,12 @@ def _build_aggregated_measure_from_measure_source_node(
non_additive_dimension_grain = self._semantic_model_lookup.get_defined_time_granularity(
TimeDimensionReference(non_additive_dimension_spec.name)
)
queried_time_dimension_spec: Optional[
TimeDimensionSpec
] = self._find_non_additive_dimension_in_linkable_specs(
agg_time_dimension=agg_time_dimension,
linkable_specs=queried_linkable_specs.as_tuple,
non_additive_dimension_spec=non_additive_dimension_spec,
queried_time_dimension_spec: Optional[TimeDimensionSpec] = (
self._find_non_additive_dimension_in_linkable_specs(
agg_time_dimension=agg_time_dimension,
linkable_specs=queried_linkable_specs.as_tuple,
non_additive_dimension_spec=non_additive_dimension_spec,
)
)
time_dimension_spec = TimeDimensionSpec(
# The NonAdditiveDimensionSpec name property is a plain element name
Expand Down
Loading

0 comments on commit 1c0069a

Please sign in to comment.