Skip to content

Commit

Permalink
Add window_function attribute to TimeDimensionSpec (#1576)
Browse files Browse the repository at this point in the history
And related cleanup. This will allow us to track which specs have had a
window function applied between DataflowPlan nodes, which is needed for
the dataflow plan used for custom offset windows.
  • Loading branch information
courtneyholcomb authored Dec 21, 2024
1 parent 48e85f5 commit b016853
Show file tree
Hide file tree
Showing 17 changed files with 69 additions and 61 deletions.
6 changes: 1 addition & 5 deletions metricflow-semantics/metricflow_semantics/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,7 @@ def with_entity_prefix(
) -> TimeDimensionInstance:
"""Returns a new instance with the entity prefix added to the entity links."""
transformed_spec = self.spec.with_entity_prefix(entity_prefix)
return TimeDimensionInstance(
associated_columns=(column_association_resolver.resolve_spec(transformed_spec),),
defined_from=self.defined_from,
spec=transformed_spec,
)
return self.with_new_spec(transformed_spec, column_association_resolver)

def with_new_defined_from(self, defined_from: Sequence[SemanticModelElementReference]) -> TimeDimensionInstance:
"""Returns a new instance with the defined_from field replaced."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from metricflow_semantics.model.semantic_manifest_lookup import SemanticManifestLookup
from metricflow_semantics.naming.linkable_spec_name import DUNDER
from metricflow_semantics.specs.column_assoc import (
ColumnAssociation,
Expand Down Expand Up @@ -28,8 +27,8 @@ class DunderColumnAssociationResolver(ColumnAssociationResolver):
listing__country
"""

def __init__(self, semantic_manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
self._visitor_helper = DunderColumnAssociationResolverVisitor(semantic_manifest_lookup)
def __init__(self) -> None: # noqa: D107
self._visitor_helper = DunderColumnAssociationResolverVisitor()

def resolve_spec(self, spec: InstanceSpec) -> ColumnAssociation: # noqa: D102
return spec.accept(self._visitor_helper)
Expand All @@ -38,9 +37,6 @@ def resolve_spec(self, spec: InstanceSpec) -> ColumnAssociation: # noqa: D102
class DunderColumnAssociationResolverVisitor(InstanceSpecVisitor[ColumnAssociation]):
"""Visitor helper class for DefaultColumnAssociationResolver2."""

def __init__(self, semantic_manifest_lookup: SemanticManifestLookup) -> None: # noqa: D107
self._semantic_manifest_lookup = semantic_manifest_lookup

def visit_metric_spec(self, metric_spec: MetricSpec) -> ColumnAssociation: # noqa: D102
return ColumnAssociation(metric_spec.element_name if metric_spec.alias is None else metric_spec.alias)

Expand All @@ -58,6 +54,11 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> C
if time_dimension_spec.aggregation_state
else ""
)
+ (
f"{DUNDER}{time_dimension_spec.window_function.value.lower()}"
if time_dimension_spec.window_function
else ""
)
)

def visit_entity_spec(self, entity_spec: EntitySpec) -> ColumnAssociation: # noqa: D102
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow_semantics.naming.linkable_spec_name import StructuredLinkableSpecName
from metricflow_semantics.specs.dimension_spec import DimensionSpec
from metricflow_semantics.specs.instance_spec import InstanceSpecVisitor
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.visitor import VisitorOutputT

Expand Down Expand Up @@ -91,6 +92,8 @@ class TimeDimensionSpec(DimensionSpec): # noqa: D101
# Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec.
aggregation_state: Optional[AggregationState] = None

window_function: Optional[SqlWindowFunction] = None

@property
def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102
assert len(self.entity_links) > 0, f"Spec does not have any entity links: {self}"
Expand All @@ -99,6 +102,8 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102
entity_links=self.entity_links[1:],
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

@property
Expand All @@ -108,6 +113,8 @@ def without_entity_links(self) -> TimeDimensionSpec: # noqa: D102
time_granularity=self.time_granularity,
date_part=self.date_part,
entity_links=(),
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

@property
Expand Down Expand Up @@ -153,6 +160,7 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension
time_granularity=time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
Expand All @@ -162,6 +170,7 @@ def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
time_granularity=ExpandedTimeGranularity.from_time_granularity(self.time_granularity.base_granularity),
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

def with_grain_and_date_part( # noqa: D102
Expand All @@ -173,6 +182,7 @@ def with_grain_and_date_part( # noqa: D102
time_granularity=time_granularity,
date_part=date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDimensionSpec: # noqa: D102
Expand All @@ -182,6 +192,17 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=aggregation_state,
window_function=self.window_function,
)

def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=window_function,
)

def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ()) -> TimeDimensionSpecComparisonKey:
Expand Down Expand Up @@ -243,6 +264,7 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_classes() -> None: # noqa: D103
time_granularity=ExpandedTimeGranularity(name='day', base_granularity=DAY),
date_part=None,
aggregation_state=None,
window_function=None,
)
"""
).rstrip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def cyclic_join_semantic_manifest_lookup( # noqa: D103
def column_association_resolver( # noqa: D103
simple_semantic_manifest_lookup: SemanticManifestLookup,
) -> ColumnAssociationResolver:
return DunderColumnAssociationResolver(simple_semantic_manifest_lookup)
return DunderColumnAssociationResolver()


@pytest.fixture(scope="session")
Expand Down
4 changes: 3 additions & 1 deletion metricflow/dataflow/nodes/filter_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.mf_logging.pretty_print import mf_pformat
from metricflow_semantics.specs.dunder_column_association_resolver import DunderColumnAssociationResolver
from metricflow_semantics.specs.spec_set import InstanceSpecSet
from metricflow_semantics.visitor import VisitorOutputT

Expand Down Expand Up @@ -57,7 +58,8 @@ def description(self) -> str: # noqa: D102
if self.replace_description:
return self.replace_description

return f"Pass Only Elements: {mf_pformat([x.qualified_name for x in self.include_specs.all_specs])}"
column_resolver = DunderColumnAssociationResolver()
return f"Pass Only Elements: {mf_pformat([column_resolver.resolve_spec(spec).column_name for spec in self.include_specs.all_specs])}"

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
Expand Down
18 changes: 9 additions & 9 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,7 @@ def __init__(
SequentialIdGenerator.reset(MetricFlowEngine._ID_ENUMERATION_START_VALUE_FOR_INITIALIZER)
self._semantic_manifest_lookup = semantic_manifest_lookup
self._sql_client = sql_client
self._column_association_resolver = column_association_resolver or (
DunderColumnAssociationResolver(semantic_manifest_lookup)
)
self._column_association_resolver = column_association_resolver or (DunderColumnAssociationResolver())
self._time_source = time_source
self._time_spine_sources = TimeSpineSource.build_standard_time_spine_sources(
semantic_manifest_lookup.semantic_manifest
Expand Down Expand Up @@ -463,12 +461,14 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me
raise InvalidQueryException("Group by items can't be specified with a saved query.")
query_spec = self._query_parser.parse_and_validate_saved_query(
saved_query_parameter=SavedQueryParameter(mf_query_request.saved_query_name),
where_filters=[
PydanticWhereFilter(where_sql_template=where_constraint)
for where_constraint in mf_query_request.where_constraints
]
if mf_query_request.where_constraints is not None
else None,
where_filters=(
[
PydanticWhereFilter(where_sql_template=where_constraint)
for where_constraint in mf_query_request.where_constraints
]
if mf_query_request.where_constraints is not None
else None
),
limit=mf_query_request.limit,
time_constraint_start=mf_query_request.time_constraint_start,
time_constraint_end=mf_query_request.time_constraint_end,
Expand Down
12 changes: 4 additions & 8 deletions metricflow/validation/data_warehouse_model_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,16 @@ class QueryRenderingTools:
def __init__(self, manifest: SemanticManifest) -> None: # noqa: D107
self.semantic_manifest_lookup = SemanticManifestLookup(semantic_manifest=manifest)
self.source_node_builder = SourceNodeBuilder(
column_association_resolver=DunderColumnAssociationResolver(self.semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=self.semantic_manifest_lookup,
)
self.converter = SemanticModelToDataSetConverter(
column_association_resolver=DunderColumnAssociationResolver(
semantic_manifest_lookup=self.semantic_manifest_lookup
)
)
self.converter = SemanticModelToDataSetConverter(column_association_resolver=DunderColumnAssociationResolver())
self.plan_converter = DataflowToSqlQueryPlanConverter(
column_association_resolver=DunderColumnAssociationResolver(self.semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=self.semantic_manifest_lookup,
)
self.node_resolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(self.semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=self.semantic_manifest_lookup,
)

Expand Down
6 changes: 2 additions & 4 deletions scripts/ci_tests/metricflow_package_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def _create_data_sets(
semantic_models: Sequence[SemanticModel] = semantic_manifest_lookup.semantic_manifest.semantic_models
semantic_models = sorted(semantic_models, key=lambda x: x.name)

converter = SemanticModelToDataSetConverter(
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup)
)
converter = SemanticModelToDataSetConverter(column_association_resolver=DunderColumnAssociationResolver())

for semantic_model in semantic_models:
data_sets[semantic_model.name] = converter.create_sql_source_data_set(semantic_model)
Expand Down Expand Up @@ -138,7 +136,7 @@ def log_dataflow_plan() -> None: # noqa: D103
semantic_manifest = _semantic_manifest()
semantic_manifest_lookup = SemanticManifestLookup(semantic_manifest)
data_set_mapping = _create_data_sets(semantic_manifest_lookup)
column_association_resolver = DunderColumnAssociationResolver(semantic_manifest_lookup)
column_association_resolver = DunderColumnAssociationResolver()

source_node_builder = SourceNodeBuilder(column_association_resolver, semantic_manifest_lookup)
source_node_set = source_node_builder.create_from_data_sets(list(data_set_mapping.values()))
Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/dataflow/builder/test_node_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_no_parent_node_data_set(
) -> None:
"""Tests getting the data set from a single node."""
resolver: DataflowPlanNodeOutputDataSetResolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(simple_semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)

Expand Down Expand Up @@ -96,7 +96,7 @@ def test_joined_node_data_set(
) -> None:
"""Tests getting the data set from a dataflow plan with a join."""
resolver: DataflowPlanNodeOutputDataSetResolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(simple_semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)

Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/dataflow/builder/test_node_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def make_multihop_node_evaluator(
) -> NodeEvaluatorForLinkableInstances:
"""Return a node evaluator using the nodes in multihop_semantic_model_name_to_nodes."""
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup_with_multihop_links),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=semantic_manifest_lookup_with_multihop_links,
)

Expand Down Expand Up @@ -510,7 +510,7 @@ def test_node_evaluator_with_scd_target(
) -> None:
"""Tests the case where the joined node is an SCD with a validity window filter."""
node_data_set_resolver: DataflowPlanNodeOutputDataSetResolver = DataflowPlanNodeOutputDataSetResolver(
column_association_resolver=DunderColumnAssociationResolver(scd_semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=scd_semantic_manifest_lookup,
)

Expand Down
6 changes: 2 additions & 4 deletions tests_metricflow/examples/test_node_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,11 @@ def test_view_sql_generated_at_a_node(
SemanticModelReference(semantic_model_name="bookings_source")
)
assert bookings_semantic_model
column_association_resolver = DunderColumnAssociationResolver(
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)
column_association_resolver = DunderColumnAssociationResolver()
to_data_set_converter = SemanticModelToDataSetConverter(column_association_resolver)

to_sql_plan_converter = DataflowToSqlQueryPlanConverter(
column_association_resolver=DunderColumnAssociationResolver(simple_semantic_manifest_lookup),
column_association_resolver=DunderColumnAssociationResolver(),
semantic_manifest_lookup=simple_semantic_manifest_lookup,
)
sql_renderer: SqlQueryPlanRenderer = sql_client.sql_query_plan_renderer
Expand Down
6 changes: 2 additions & 4 deletions tests_metricflow/fixtures/manifest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def from_parameters( # noqa: D102
semantic_manifest_lookup = SemanticManifestLookup(semantic_manifest)
data_set_mapping = MetricFlowEngineTestFixture._create_data_sets(semantic_manifest_lookup)
read_node_mapping = MetricFlowEngineTestFixture._data_set_to_read_nodes(data_set_mapping)
column_association_resolver = DunderColumnAssociationResolver(semantic_manifest_lookup)
column_association_resolver = DunderColumnAssociationResolver()
source_node_builder = SourceNodeBuilder(column_association_resolver, semantic_manifest_lookup)
source_node_set = source_node_builder.create_from_data_sets(list(data_set_mapping.values()))
node_output_resolver = DataflowPlanNodeOutputDataSetResolver(
Expand Down Expand Up @@ -247,9 +247,7 @@ def _create_data_sets(
semantic_models: Sequence[SemanticModel] = semantic_manifest_lookup.semantic_manifest.semantic_models
semantic_models = sorted(semantic_models, key=lambda x: x.name)

converter = SemanticModelToDataSetConverter(
column_association_resolver=DunderColumnAssociationResolver(semantic_manifest_lookup)
)
converter = SemanticModelToDataSetConverter(column_association_resolver=DunderColumnAssociationResolver())

for semantic_model in semantic_models:
data_sets[semantic_model.name] = converter.create_sql_source_data_set(semantic_model)
Expand Down
4 changes: 1 addition & 3 deletions tests_metricflow/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ def it_helpers( # noqa: D103
mf_engine=MetricFlowEngine(
semantic_manifest_lookup=simple_semantic_manifest_lookup,
sql_client=sql_client,
column_association_resolver=DunderColumnAssociationResolver(
semantic_manifest_lookup=simple_semantic_manifest_lookup
),
column_association_resolver=DunderColumnAssociationResolver(),
time_source=ConfigurableTimeSource(as_datetime("2020-01-01")),
),
mf_system_schema=mf_test_configuration.mf_system_schema,
Expand Down
4 changes: 1 addition & 3 deletions tests_metricflow/integration/test_rendered_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@ def test_id_enumeration( # noqa: D103
mf_engine = MetricFlowEngine(
semantic_manifest_lookup=simple_semantic_manifest_lookup,
sql_client=sql_client,
column_association_resolver=DunderColumnAssociationResolver(
semantic_manifest_lookup=simple_semantic_manifest_lookup
),
column_association_resolver=DunderColumnAssociationResolver(),
time_source=ConfigurableTimeSource(as_datetime("2020-01-01")),
consistent_id_enumeration=True,
)
Expand Down
Loading

0 comments on commit b016853

Please sign in to comment.