Skip to content

Commit

Permalink
Add support for querying with metric alias (#1573)
Browse files Browse the repository at this point in the history
## Summary

This PR introduces support for querying metrics under an alias. This is
done by adding an `AliasSpecsNode` after the order-by step in
`DataflowPlanBuilder`. To make the alias information reach that point, I
also had to modify some of the query interface classes to allow aliases
on metrics.

You can review commit by commit.

---------

Co-authored-by: Courtney Holcomb <[email protected]>
Co-authored-by: Will Deng <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2025
1 parent f6f63d7 commit a4cc831
Show file tree
Hide file tree
Showing 204 changed files with 2,535 additions and 681 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20241213-110407.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Allow setting aliases for queried metrics
time: 2024-12-13T11:04:07.020346+01:00
custom:
Author: serramatutu
Issue: "1573"
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.patterns.entity_link_pattern import (
EntityLinkPattern,
EntityLinkPatternParameterSet,
ParameterSetField,
SpecPatternParameterSet,
)
from metricflow_semantics.specs.spec_set import InstanceSpecSet, InstanceSpecSetTransform, group_spec_by_type

Expand Down Expand Up @@ -67,7 +67,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
# No dunder, e.g. "ds"
if len(input_str_parts) == 1:
return EntityLinkPattern(
parameter_set=EntityLinkPatternParameterSet.from_parameters(
parameter_set=SpecPatternParameterSet.from_parameters(
element_name=input_str_parts[0],
entity_links=(),
time_granularity_name=None,
Expand All @@ -88,7 +88,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
# e.g. "ds__month"
if len(input_str_parts) == 2:
return EntityLinkPattern(
parameter_set=EntityLinkPatternParameterSet.from_parameters(
parameter_set=SpecPatternParameterSet.from_parameters(
element_name=input_str_parts[0],
entity_links=(),
time_granularity_name=time_granularity_name,
Expand All @@ -98,7 +98,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
)
# e.g. "messages__ds__month"
return EntityLinkPattern(
parameter_set=EntityLinkPatternParameterSet.from_parameters(
parameter_set=SpecPatternParameterSet.from_parameters(
element_name=input_str_parts[-2],
entity_links=tuple(EntityReference(entity_name) for entity_name in input_str_parts[:-2]),
time_granularity_name=time_granularity_name,
Expand All @@ -109,7 +109,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes

# e.g. "messages__ds"
return EntityLinkPattern(
parameter_set=EntityLinkPatternParameterSet.from_parameters(
parameter_set=SpecPatternParameterSet.from_parameters(
element_name=suffix,
entity_links=tuple(EntityReference(entity_name) for entity_name in input_str_parts[:-1]),
time_granularity_name=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from typing import Optional

from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.naming.naming_scheme import QueryItemNamingScheme
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.patterns.entity_link_pattern import ParameterSetField, SpecPatternParameterSet
from metricflow_semantics.specs.patterns.metric_pattern import MetricSpecPattern
from metricflow_semantics.specs.spec_set import group_spec_by_type

Expand All @@ -30,7 +30,11 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
input_str = input_str.lower()
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise RuntimeError(f"{repr(input_str)} does not follow this scheme.")
return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str))
return MetricSpecPattern(
parameter_set=SpecPatternParameterSet.from_parameters(
fields_to_compare=(ParameterSetField.ELEMENT_NAME,), element_name=input_str
)
)

@override
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.patterns.entity_link_pattern import (
EntityLinkPattern,
EntityLinkPatternParameterSet,
ParameterSetField,
SpecPatternParameterSet,
)
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.patterns.typed_patterns import DimensionPattern, TimeDimensionPattern
Expand Down Expand Up @@ -62,7 +62,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes

for dimension_call_parameter_set in call_parameter_sets.dimension_call_parameter_sets:
return DimensionPattern(
EntityLinkPatternParameterSet.from_parameters(
SpecPatternParameterSet.from_parameters(
element_name=dimension_call_parameter_set.dimension_reference.element_name,
entity_links=dimension_call_parameter_set.entity_path,
fields_to_compare=(
Expand All @@ -84,7 +84,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
fields_to_compare.append(ParameterSetField.TIME_GRANULARITY)

return TimeDimensionPattern(
EntityLinkPatternParameterSet.from_parameters(
SpecPatternParameterSet.from_parameters(
element_name=time_dimension_call_parameter_set.time_dimension_reference.element_name,
entity_links=time_dimension_call_parameter_set.entity_path,
time_granularity_name=time_dimension_call_parameter_set.time_granularity_name,
Expand All @@ -95,7 +95,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes

for entity_call_parameter_set in call_parameter_sets.entity_call_parameter_sets:
return EntityLinkPattern(
EntityLinkPatternParameterSet.from_parameters(
SpecPatternParameterSet.from_parameters(
element_name=entity_call_parameter_set.entity_reference.element_name,
entity_links=entity_call_parameter_set.entity_path,
fields_to_compare=(
Expand All @@ -107,7 +107,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes

for metric_call_parameter_set in call_parameter_sets.metric_call_parameter_sets:
return EntityLinkPattern(
EntityLinkPatternParameterSet.from_parameters(
SpecPatternParameterSet.from_parameters(
element_name=metric_call_parameter_set.metric_reference.element_name,
entity_links=tuple(
EntityReference(element_name=group_by_ref.element_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def name(self) -> str:
"""The name of the metric."""
raise NotImplementedError

@property
def alias(self) -> Optional[str]:
"""The alias of the metric; the annotated return type allows None.

This definition only raises NotImplementedError.
NOTE(review): presumably concrete implementations override it -- confirm.
"""
raise NotImplementedError

def query_resolver_input( # noqa: D102
self, semantic_manifest_lookup: SemanticManifestLookup
) -> ResolverInputForMetric:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence, Tuple

from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.query.group_by_item.resolution_path import MetricFlowQueryResolutionPath
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryIssueType,
MetricFlowQueryResolutionIssue,
)
from metricflow_semantics.query.resolver_inputs.base_resolver_inputs import MetricFlowQueryResolverInput


@dataclass(frozen=True)
class DuplicateMetricAliasIssue(MetricFlowQueryResolutionIssue):
    """Query resolution issue reported when a query contains duplicate metric aliases."""

    # References to the metrics involved in the alias collision.
    duplicate_metric_references: Tuple[MetricReference, ...]

    @staticmethod
    def from_parameters(
        duplicate_metric_references: Sequence[MetricReference],
        query_resolution_path: MetricFlowQueryResolutionPath,
    ) -> DuplicateMetricAliasIssue:
        """Create an ERROR-typed issue with no parent issues.

        Args:
            duplicate_metric_references: Metric references to report; stored as a tuple.
            query_resolution_path: Resolution path to attach to the issue.

        Returns:
            The newly constructed issue.
        """
        references = tuple(duplicate_metric_references)
        return DuplicateMetricAliasIssue(
            issue_type=MetricFlowQueryIssueType.ERROR,
            parent_issues=(),
            query_resolution_path=query_resolution_path,
            duplicate_metric_references=references,
        )

    @override
    def ui_description(self, associated_input: MetricFlowQueryResolverInput) -> str:
        """Return the message for this issue, listing the element names of the stored metric references.

        `associated_input` is not used when building the message.
        """
        metric_names = [reference.element_name for reference in self.duplicate_metric_references]
        return f"Query contains duplicate metric aliases: {metric_names}"

    @override
    def with_path_prefix(self, path_prefix: MetricFlowQueryResolutionPath) -> DuplicateMetricAliasIssue:
        """Return a copy of this issue with `path_prefix` applied to its path and to every parent issue."""
        prefixed_parent_issues = tuple(
            parent_issue.with_path_prefix(path_prefix) for parent_issue in self.parent_issues
        )
        prefixed_resolution_path = self.query_resolution_path.with_path_prefix(path_prefix)
        return DuplicateMetricAliasIssue(
            issue_type=self.issue_type,
            parent_issues=prefixed_parent_issues,
            query_resolution_path=prefixed_resolution_path,
            duplicate_metric_references=self.duplicate_metric_references,
        )
40 changes: 35 additions & 5 deletions metricflow-semantics/metricflow_semantics/query/query_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import itertools
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Sequence, Set, Tuple
from typing import Dict, List, Optional, Sequence, Set, Tuple

from dbt_semantic_interfaces.references import MeasureReference, MetricReference, SemanticModelReference

Expand Down Expand Up @@ -33,6 +34,7 @@
from metricflow_semantics.query.issues.issues_base import (
MetricFlowQueryResolutionIssueSet,
)
from metricflow_semantics.query.issues.parsing.duplicate_metric_alias import DuplicateMetricAliasIssue
from metricflow_semantics.query.issues.parsing.invalid_limit import InvalidLimitIssue
from metricflow_semantics.query.issues.parsing.invalid_metric import InvalidMetricIssue
from metricflow_semantics.query.issues.parsing.invalid_min_max_only import InvalidMinMaxOnlyIssue
Expand Down Expand Up @@ -173,11 +175,20 @@ def _resolve_metric_inputs(
)
metric_specs: List[MetricSpec] = []
input_to_issue_set_mapping_items: List[InputToIssueSetMappingItem] = []
alias_to_metrics: Dict[str, List[Tuple[ResolverInputForMetric, MetricReference]]] = defaultdict(list)

# Find the metric that matches the metric pattern from the input.
for metric_input in metric_inputs:
matching_specs = metric_input.spec_pattern.match(available_metric_specs)
if len(matching_specs) != 1:
if len(matching_specs) == 1:
matching_spec = matching_specs[0]
alias = metric_input.spec_pattern.parameter_set.alias
if alias:
matching_spec = matching_spec.with_alias(alias)
metric_specs.append(matching_spec)
resolved_name = matching_spec.alias or matching_spec.qualified_name
alias_to_metrics[resolved_name].append((metric_input, matching_spec.reference))
else:
suggestion_generator = QueryItemSuggestionGenerator(
input_naming_scheme=MetricNamingScheme(),
input_str=str(metric_input.input_obj),
Expand All @@ -195,8 +206,23 @@ def _resolve_metric_inputs(
),
)
)
else:
metric_specs.extend(matching_specs)

# Find any duplicate aliases
for alias, metrics in alias_to_metrics.items():
if len(metrics) > 1:
metric_inputs = [m[0] for m in metrics]
metric_references = [m[1] for m in metrics]
input_to_issue_set_mapping_items.append(
InputToIssueSetMappingItem(
resolver_input=metric_inputs[0],
issue_set=MetricFlowQueryResolutionIssueSet.from_issue(
DuplicateMetricAliasIssue.from_parameters(
duplicate_metric_references=metric_references,
query_resolution_path=query_resolution_path,
)
),
)
)

return ResolveMetricsResult(
metric_specs=tuple(metric_specs),
Expand Down Expand Up @@ -371,7 +397,11 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
query_resolution_path = MetricFlowQueryResolutionPath.from_path_item(
QueryGroupByItemResolutionNode.create(
parent_nodes=(),
metrics_in_query=tuple(metric_input.spec_pattern.metric_reference for metric_input in metric_inputs),
metrics_in_query=tuple(
MetricReference(metric_input.spec_pattern.parameter_set.element_name)
for metric_input in metric_inputs
if metric_input.spec_pattern.parameter_set.element_name # for type checker
),
where_filter_intersection=query_level_filter_input.where_filter_intersection,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def accept(self, visitor: InstanceSpecVisitor[VisitorOutputT]) -> VisitorOutputT
"""See Visitable."""
raise NotImplementedError()

def without_filter_specs(self) -> InstanceSpec:
"""Return the instance spec without any filtering (for comparison purposes).

This implementation returns `self` unchanged.
"""
return self


class InstanceSpecVisitor(Generic[VisitorOutputT], ABC):
"""Visitor for the InstanceSpec classes."""
Expand Down
18 changes: 18 additions & 0 deletions metricflow-semantics/metricflow_semantics/specs/metric_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,21 @@ def has_time_offset(self) -> bool: # noqa: D102
def without_offset(self) -> MetricSpec:
    """Return a copy of this metric spec with any time offsets removed.

    The element name, filter spec set and alias are carried over. The offset
    arguments are not passed, so they take the constructor's defaults.
    """
    return MetricSpec(
        element_name=self.element_name,
        filter_spec_set=self.filter_spec_set,
        alias=self.alias,
    )

def with_alias(self, alias: Optional[str]) -> MetricSpec:
    """Return a copy of this metric spec that carries the given alias.

    Args:
        alias: The alias to set on the copy. The value is passed through
            as-is, so `None` yields a copy whose alias is `None`.

    Returns:
        A new MetricSpec with the same element name, filter spec set and
        offsets as this one, and `alias` as its alias.
    """
    return MetricSpec(
        element_name=self.element_name,
        alias=alias,
        filter_spec_set=self.filter_spec_set,
        offset_window=self.offset_window,
        offset_to_grain=self.offset_to_grain,
    )

def without_filter_specs(self) -> MetricSpec:
    """Return a copy of this metric spec built without its filter spec set.

    The element name, alias and offsets are carried over. `filter_spec_set`
    is not passed, so it takes the constructor's default.
    """
    return MetricSpec(
        element_name=self.element_name,
        offset_window=self.offset_window,
        offset_to_grain=self.offset_to_grain,
        alias=self.alias,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


class ParameterSetField(Enum):
"""The fields of the EntityLinkPatternParameterSet class used for matching in the EntityLinkPattern.
"""The fields of the SpecPatternParameterSet class used for matching in the EntityLinkPattern.
Considering moving this to be a part of the specs module / classes.
"""
Expand All @@ -30,6 +30,7 @@ class ParameterSetField(Enum):
TIME_GRANULARITY = "time_granularity_name"
DATE_PART = "date_part"
METRIC_SUBQUERY_ENTITY_LINKS = "metric_subquery_entity_links"
ALIAS = "alias"

def __lt__(self, other: Any) -> bool: # type: ignore[misc]
"""Allow for ordering so that a sequence of these can be consistently represented for test snapshots."""
Expand All @@ -39,7 +40,7 @@ def __lt__(self, other: Any) -> bool: # type: ignore[misc]


@dataclass(frozen=True)
class EntityLinkPatternParameterSet:
class SpecPatternParameterSet:
"""See EntityPathPattern for more details."""

# Specify the field values to compare. None can't be used to signal "don't compare" because sometimes a pattern
Expand All @@ -54,6 +55,7 @@ class EntityLinkPatternParameterSet:
time_granularity_name: Optional[str] = None
date_part: Optional[DatePart] = None
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None
alias: Optional[str] = None

@staticmethod
def from_parameters( # noqa: D102
Expand All @@ -63,14 +65,16 @@ def from_parameters( # noqa: D102
time_granularity_name: Optional[str] = None,
date_part: Optional[DatePart] = None,
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None,
) -> EntityLinkPatternParameterSet:
return EntityLinkPatternParameterSet(
alias: Optional[str] = None,
) -> SpecPatternParameterSet:
return SpecPatternParameterSet(
fields_to_compare=tuple(sorted(fields_to_compare)),
element_name=element_name,
entity_links=tuple(entity_links) if entity_links is not None else None,
time_granularity_name=time_granularity_name,
date_part=date_part,
metric_subquery_entity_links=metric_subquery_entity_links,
alias=alias,
)

def __post_init__(self) -> None:
Expand All @@ -91,7 +95,7 @@ class EntityLinkPattern(SpecPattern):
The entity links that are specified is used as a suffix match.
"""

parameter_set: EntityLinkPatternParameterSet
parameter_set: SpecPatternParameterSet

def _match_entity_links(self, candidate_specs: Sequence[LinkableInstanceSpec]) -> Sequence[LinkableInstanceSpec]:
assert self.parameter_set.entity_links is not None
Expand Down Expand Up @@ -129,7 +133,7 @@ def _match_time_granularities(
@override
def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[LinkableInstanceSpec]:
filtered_candidate_specs = group_specs_by_type(candidate_specs).linkable_specs
# Checks that EntityLinkPatternParameterSetField is valid wrt to the parameter set.
# Checks that SpecPatternParameterSetField is valid with respect to the parameter set.

# Entity links could be a partial match, so it's handled separately.
if ParameterSetField.ENTITY_LINKS in self.parameter_set.fields_to_compare:
Expand Down
Loading

0 comments on commit a4cc831

Please sign in to comment.