Skip to content

Commit

Permalink
Move metric alias details from MetricSpecPattern to `ResolverInputF…
Browse files Browse the repository at this point in the history
…orMetric` (#1597)
  • Loading branch information
courtneyholcomb authored Jan 16, 2025
1 parent 2f1bb1c commit 78caba6
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from typing import Optional

from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.naming.naming_scheme import QueryItemNamingScheme
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.patterns.entity_link_pattern import ParameterSetField, SpecPatternParameterSet
from metricflow_semantics.specs.patterns.metric_pattern import MetricSpecPattern
from metricflow_semantics.specs.spec_set import group_spec_by_type

Expand All @@ -30,11 +30,7 @@ def spec_pattern(self, input_str: str, semantic_manifest_lookup: SemanticManifes
input_str = input_str.lower()
if not self.input_str_follows_scheme(input_str, semantic_manifest_lookup=semantic_manifest_lookup):
raise RuntimeError(f"{repr(input_str)} does not follow this scheme.")
return MetricSpecPattern(
parameter_set=SpecPatternParameterSet.from_parameters(
fields_to_compare=(ParameterSetField.ELEMENT_NAME,), element_name=input_str
)
)
return MetricSpecPattern(metric_reference=MetricReference(element_name=input_str))

@override
def input_str_follows_scheme(self, input_str: str, semantic_manifest_lookup: SemanticManifestLookup) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _resolve_metric_inputs(
matching_specs = metric_input.spec_pattern.match(available_metric_specs)
if len(matching_specs) == 1:
matching_spec = matching_specs[0]
alias = metric_input.spec_pattern.parameter_set.alias
alias = metric_input.alias
if alias:
matching_spec = matching_spec.with_alias(alias)
metric_specs.append(matching_spec)
Expand Down Expand Up @@ -397,11 +397,7 @@ def _resolve_query(self, resolver_input_for_query: ResolverInputForQuery) -> Met
query_resolution_path = MetricFlowQueryResolutionPath.from_path_item(
QueryGroupByItemResolutionNode.create(
parent_nodes=(),
metrics_in_query=tuple(
MetricReference(metric_input.spec_pattern.parameter_set.element_name)
for metric_input in metric_inputs
if metric_input.spec_pattern.parameter_set.element_name # for type checker
),
metrics_in_query=tuple(metric_input.spec_pattern.metric_reference for metric_input in metric_inputs),
where_filter_intersection=query_level_filter_input.where_filter_intersection,
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
The naming of these classes is a little odd as they have the "For.." suffix. But using the "*ResolverInput" leads to
some confusing names like "ResolverInputForQuery" -> "QueryResolverInput". Improved naming for these classes is TBD.
"""

from __future__ import annotations

from dataclasses import dataclass
Expand Down Expand Up @@ -44,6 +45,7 @@ class ResolverInputForMetric(MetricFlowQueryResolverInput):
input_obj: Union[MetricQueryParameter, str]
naming_scheme: MetricNamingScheme
spec_pattern: MetricSpecPattern
alias: Optional[str] = None

@property
@override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class ParameterSetField(Enum):
TIME_GRANULARITY = "time_granularity_name"
DATE_PART = "date_part"
METRIC_SUBQUERY_ENTITY_LINKS = "metric_subquery_entity_links"
ALIAS = "alias"

def __lt__(self, other: Any) -> bool: # type: ignore[misc]
"""Allow for ordering so that a sequence of these can be consistently represented for test snapshots."""
Expand All @@ -55,7 +54,6 @@ class SpecPatternParameterSet:
time_granularity_name: Optional[str] = None
date_part: Optional[DatePart] = None
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None
alias: Optional[str] = None

@staticmethod
def from_parameters( # noqa: D102
Expand All @@ -65,7 +63,6 @@ def from_parameters( # noqa: D102
time_granularity_name: Optional[str] = None,
date_part: Optional[DatePart] = None,
metric_subquery_entity_links: Optional[Tuple[EntityReference, ...]] = None,
alias: Optional[str] = None,
) -> SpecPatternParameterSet:
return SpecPatternParameterSet(
fields_to_compare=tuple(sorted(fields_to_compare)),
Expand All @@ -74,7 +71,6 @@ def from_parameters( # noqa: D102
time_granularity_name=time_granularity_name,
date_part=date_part,
metric_subquery_entity_links=metric_subquery_entity_links,
alias=alias,
)

def __post_init__(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Sequence
from typing import Sequence

from dbt_semantic_interfaces.references import MetricReference
from typing_extensions import override

from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.metric_spec import MetricSpec
from metricflow_semantics.specs.patterns.entity_link_pattern import SpecPatternParameterSet
from metricflow_semantics.specs.patterns.spec_pattern import SpecPattern
from metricflow_semantics.specs.spec_set import group_specs_by_type

Expand All @@ -16,18 +16,11 @@
class MetricSpecPattern(SpecPattern):
"""Matches MetricSpecs that have the given metric_reference."""

parameter_set: SpecPatternParameterSet
metric_reference: MetricReference

@override
def match(self, candidate_specs: Sequence[InstanceSpec]) -> Sequence[MetricSpec]:
filtered_candidate_specs = group_specs_by_type(candidate_specs).metric_specs
keys_to_check = set(field_to_compare.value for field_to_compare in self.parameter_set.fields_to_compare)

matching_specs: List[MetricSpec] = []
parameter_set_values = tuple(getattr(self.parameter_set, key_to_check) for key_to_check in keys_to_check)
for spec in filtered_candidate_specs:
spec_values = tuple(getattr(spec, key_to_check, None) for key_to_check in keys_to_check)
if spec_values == parameter_set_values:
matching_specs.append(spec)

return matching_specs
spec_set = group_specs_by_type(candidate_specs)
return tuple(
metric_name for metric_name in spec_set.metric_specs if metric_name.reference == self.metric_reference
)
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
ParameterSetField,
SpecPatternParameterSet,
)
from metricflow_semantics.specs.patterns.metric_pattern import MetricSpecPattern


@dataclass(frozen=True)
Expand Down Expand Up @@ -141,13 +140,8 @@ def query_resolver_input( # noqa: D102
return ResolverInputForMetric(
input_obj=self,
naming_scheme=naming_scheme,
spec_pattern=MetricSpecPattern(
SpecPatternParameterSet.from_parameters(
fields_to_compare=(ParameterSetField.ELEMENT_NAME,),
element_name=self.name.lower(),
alias=self.alias,
)
),
spec_pattern=naming_scheme.spec_pattern(self.name, semantic_manifest_lookup=semantic_manifest_lookup),
alias=self.alias,
)


Expand Down

0 comments on commit 78caba6

Please sign in to comment.