Skip to content

Commit

Permalink
appending to time_dimension_call_parameter_sets
Browse files Browse the repository at this point in the history
  • Loading branch information
DevonFulcher committed Sep 20, 2023
1 parent bcc6d8e commit e3ef46c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@

from typing_extensions import override

from dbt_semantic_interfaces.call_parameter_sets import TimeDimensionCallParameterSet
from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import (
ParameterSetFactory,
)
from dbt_semantic_interfaces.protocols.query_interface import (
QueryInterfaceDimension,
QueryInterfaceDimensionFactory,
Expand All @@ -27,38 +23,32 @@ def __init__( # noqa
self,
name: str,
entity_path: Sequence[str],
time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet],
):
self.name = name
self.entity_path = entity_path
self._time_dimension_call_parameter_sets = time_dimension_call_parameter_sets
self.time_granularity: Optional[TimeGranularity] = None

def grain(self, time_granularity: str) -> QueryInterfaceDimension:
"""The time granularity."""
self.time_granularity = TimeGranularity(time_granularity)
self._time_dimension_call_parameter_sets.append(
ParameterSetFactory.create_time_dimension(self.name, time_granularity, self.entity_path)
)
return self


class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
"""Creates a WhereFilterDimension.
Each call to `create` adds a WhereFilterDimension to created.
Each call to `create` adds a WhereFilterDimension to `created`.
"""

@override
def _implements_protocol(self) -> QueryInterfaceDimensionFactory:
return self

def __init__(self, time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet]): # noqa
def __init__(self): # noqa
self.created: List[WhereFilterDimension] = []
self._time_dimension_call_parameter_sets = time_dimension_call_parameter_sets

def create(self, dimension_name: str, entity_path: Sequence[str] = ()) -> WhereFilterDimension:
"""Gets called by Jinja when rendering {{ Dimension(...) }}."""
dimension = WhereFilterDimension(dimension_name, entity_path, self._time_dimension_call_parameter_sets)
dimension = WhereFilterDimension(dimension_name, entity_path)
self.created.append(dimension)
return dimension
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class WhereFilterParser:
def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSets:
"""Return the result of extracting the semantic objects referenced in the where SQL template string."""
time_dimension_factory = WhereFilterTimeDimensionFactory()
dimension_factory = WhereFilterDimensionFactory(time_dimension_factory.time_dimension_call_parameter_sets)
dimension_factory = WhereFilterDimensionFactory()
entity_factory = WhereFilterEntityFactory()

try:
Expand All @@ -44,9 +44,18 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet

dimension_parameter_sets = []
for dimension in dimension_factory.created:
if not dimension.time_granularity:
param_set = ParameterSetFactory.create_dimension(dimension.name, dimension.entity_path)
dimension_parameter_sets.append(param_set)
if dimension.time_granularity:
time_dimension_factory.time_dimension_call_parameter_sets.append(
ParameterSetFactory.create_time_dimension(
dimension.name,
dimension.time_granularity,
dimension.entity_path,
)
)
else:
dimension_parameter_sets.append(
ParameterSetFactory.create_dimension(dimension.name, dimension.entity_path)
)

return FilterCallParameterSets(
dimension_call_parameter_sets=tuple(dimension_parameter_sets),
Expand Down

0 comments on commit e3ef46c

Please sign in to comment.