Skip to content

Commit dcf17cf

Browse files
Use helper function to find matching instances in dataset (#1538)
Cleanup: reduces some repeated code in the dataflow-to-SQL logic.
1 parent aa6ec15 commit dcf17cf

File tree

3 files changed: +33 -37 lines changed

metricflow-semantics/metricflow_semantics/instances.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,9 @@ class MdoInstance(ABC, Generic[SpecT]):
4949
@property
5050
def associated_column(self) -> ColumnAssociation:
5151
"""Helper for getting the associated column until support for multiple associated columns is added."""
52-
assert len(self.associated_columns) == 1
52+
assert (
53+
len(self.associated_columns) == 1
54+
), f"Expected exactly one column for {self.__class__.__name__}, but got {self.associated_columns}"
5355
return self.associated_columns[0]
5456

5557
def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT:

metricflow/dataset/sql_dataset.py

+25-16
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import List, Optional, Sequence
3+
from typing import List, Optional, Sequence, Tuple
44

55
from dbt_semantic_interfaces.references import SemanticModelReference
66
from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
7-
from metricflow_semantics.instances import EntityInstance, InstanceSet
7+
from metricflow_semantics.instances import EntityInstance, InstanceSet, TimeDimensionInstance
88
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
99
from metricflow_semantics.specs.column_assoc import ColumnAssociation
1010
from metricflow_semantics.specs.dimension_spec import DimensionSpec
@@ -122,30 +122,39 @@ def column_association_for_dimension(
122122

123123
return column_associations_to_return[0]
124124

125-
def column_association_for_time_dimension(
126-
self,
127-
time_dimension_spec: TimeDimensionSpec,
128-
) -> ColumnAssociation:
129-
"""Given the name of the time dimension, return the set of columns associated with it in the data set."""
125+
def instances_for_time_dimensions(
126+
self, time_dimension_specs: Sequence[TimeDimensionSpec]
127+
) -> Tuple[TimeDimensionInstance, ...]:
128+
"""Return the instances associated with these specs in the data set."""
129+
time_dimension_specs_set = set(time_dimension_specs)
130130
matching_instances = 0
131-
column_associations_to_return = None
131+
instances_to_return: Tuple[TimeDimensionInstance, ...] = ()
132132
for time_dimension_instance in self.instance_set.time_dimension_instances:
133-
if time_dimension_instance.spec == time_dimension_spec:
134-
column_associations_to_return = time_dimension_instance.associated_columns
133+
if time_dimension_instance.spec in time_dimension_specs_set:
134+
instances_to_return += (time_dimension_instance,)
135135
matching_instances += 1
136136

137-
if matching_instances > 1:
137+
if matching_instances != len(time_dimension_specs_set):
138138
raise RuntimeError(
139-
f"More than one time dimension instance with spec {time_dimension_spec} in "
140-
f"instance set: {self.instance_set}"
139+
f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs_set}\n"
140+
f"Instances: {instances_to_return}"
141141
)
142142

143-
if not column_associations_to_return:
143+
return instances_to_return
144+
145+
def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> TimeDimensionInstance:
146+
"""Given the name of the time dimension, return the instance associated with it in the data set."""
147+
instances = self.instances_for_time_dimensions((time_dimension_spec,))
148+
if not len(instances) == 1:
144149
raise RuntimeError(
145-
f"No time dimension instances with spec {time_dimension_spec} in instance set: {self.instance_set}"
150+
f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_spec}\n"
151+
f"Instances: {instances}"
146152
)
153+
return instances[0]
147154

148-
return column_associations_to_return[0]
155+
def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation:
156+
"""Given the name of the time dimension, return the set of columns associated with it in the data set."""
157+
return self.instance_for_time_dimension(time_dimension_spec).associated_column
149158

150159
@property
151160
@override

metricflow/plan_conversion/dataflow_to_sql.py

+5-20
Original file line number | Diff line number | Diff line change
@@ -1472,16 +1472,8 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
14721472
self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set})
14731473
)
14741474

1475-
# Select matching instance from time spine data set (using base grain - custom grain will be joined in a later node).
1476-
original_time_spine_dim_instance: Optional[TimeDimensionInstance] = None
1477-
for time_dimension_instance in time_spine_dataset.instance_set.time_dimension_instances:
1478-
if time_dimension_instance.spec == agg_time_dimension_instance_for_join.spec:
1479-
original_time_spine_dim_instance = time_dimension_instance
1480-
break
1481-
assert original_time_spine_dim_instance, (
1482-
"Couldn't find requested agg_time_dimension_instance_for_join in time spine data set, which "
1483-
f"indicates it may have been configured incorrectly. Expected: {agg_time_dimension_instance_for_join.spec};"
1484-
f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}"
1475+
original_time_spine_dim_instance = time_spine_dataset.instance_for_time_dimension(
1476+
agg_time_dimension_instance_for_join.spec
14851477
)
14861478
time_spine_column_select_expr: Union[
14871479
SqlColumnReferenceExpression, SqlDateTruncExpression
@@ -1592,17 +1584,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
15921584

15931585
# New dataset will be joined to parent dataset without a subquery, so use the same FROM alias as the parent node.
15941586
parent_alias = parent_data_set.checked_sql_select_node.from_source_alias
1595-
parent_time_dimension_instance: Optional[TimeDimensionInstance] = None
1596-
for instance in parent_data_set.instance_set.time_dimension_instances:
1597-
if instance.spec == node.time_dimension_spec.with_base_grain():
1598-
parent_time_dimension_instance = instance
1599-
break
1600-
parent_column: Optional[SqlSelectColumn] = None
1601-
assert parent_time_dimension_instance, (
1602-
"JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. "
1603-
f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain()}; "
1604-
f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}"
1587+
parent_time_dimension_instance = parent_data_set.instance_for_time_dimension(
1588+
node.time_dimension_spec.with_base_grain()
16051589
)
1590+
parent_column: Optional[SqlSelectColumn] = None
16061591
for select_column in parent_data_set.checked_sql_select_node.select_columns:
16071592
if select_column.column_alias == parent_time_dimension_instance.associated_column.column_name:
16081593
parent_column = select_column

0 commit comments

Comments
 (0)