diff --git a/.changes/unreleased/Breaking Changes-20241105-180727.yaml b/.changes/unreleased/Breaking Changes-20241105-180727.yaml new file mode 100644 index 00000000..b6f58a59 --- /dev/null +++ b/.changes/unreleased/Breaking Changes-20241105-180727.yaml @@ -0,0 +1,6 @@ +kind: Breaking Changes +body: Update PydanticWhereFilter.call_parameter_sets and PydanticWhereFilterIntersection.filter_expression_parameter_sets from property to a method +time: 2024-11-05T18:07:27.325103-05:00 +custom: + Author: WilliamDee + Issue: None diff --git a/.changes/unreleased/Under the Hood-20241023-180425.yaml b/.changes/unreleased/Under the Hood-20241023-180425.yaml new file mode 100644 index 00000000..25b933b8 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241023-180425.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Added validation warnings for invalid granularity names in where filters of saved queries. +time: 2024-10-23T18:04:25.235887-07:00 +custom: + Author: theyostalservice + Issue: "360" diff --git a/dbt_semantic_interfaces/implementations/filters/where_filter.py b/dbt_semantic_interfaces/implementations/filters/where_filter.py index 96a74581..f49e237f 100644 --- a/dbt_semantic_interfaces/implementations/filters/where_filter.py +++ b/dbt_semantic_interfaces/implementations/filters/where_filter.py @@ -2,7 +2,7 @@ import textwrap import traceback -from typing import Callable, Generator, List, Tuple +from typing import Callable, Generator, List, Sequence, Tuple from typing_extensions import Self @@ -49,9 +49,10 @@ def _from_yaml_value( else: raise ValueError(f"Expected input to be of type string, but got type {type(input)} with value: {input}") - @property - def call_parameter_sets(self) -> FilterCallParameterSets: # noqa: D - return WhereFilterParser.parse_call_parameter_sets(self.where_sql_template) + def call_parameter_sets(self, custom_granularity_names: Sequence[str]) -> FilterCallParameterSets: # noqa: D + return WhereFilterParser.parse_call_parameter_sets( + where_sql_template=self.where_sql_template, custom_granularity_names=custom_granularity_names + ) class PydanticWhereFilterIntersection(HashableBaseModel): @@ -115,14 +116,20 @@ def _convert_legacy_and_yaml_input(cls, input: PydanticParseableValueType) -> Se f"or dict but got {type(input)} with value {input}" ) - @property - def filter_expression_parameter_sets(self) -> List[Tuple[str, FilterCallParameterSets]]: + def filter_expression_parameter_sets( + self, custom_granularity_names: Sequence[str] + ) -> List[Tuple[str, FilterCallParameterSets]]: """Gets the call parameter sets for each filter expression.""" filter_parameter_sets: List[Tuple[str, FilterCallParameterSets]] = [] invalid_filter_expressions: List[Tuple[str, Exception]] = [] for where_filter in self.where_filters: try: - filter_parameter_sets.append((where_filter.where_sql_template, where_filter.call_parameter_sets)) + filter_parameter_sets.append( + ( + where_filter.where_sql_template, + where_filter.call_parameter_sets(custom_granularity_names=custom_granularity_names), + ) + ) except Exception as e: invalid_filter_expressions.append((where_filter.where_sql_template, e)) diff --git a/dbt_semantic_interfaces/naming/dundered.py b/dbt_semantic_interfaces/naming/dundered.py index 1095d3a3..d8c03e47 100644 --- a/dbt_semantic_interfaces/naming/dundered.py +++ b/dbt_semantic_interfaces/naming/dundered.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Sequence, Tuple from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.references import EntityReference @@ -19,20 +19,14 @@ class StructuredDunderedName: entity_links: ["listing"] element_name: "ds" granularity: TimeGranularity.WEEK - - The time granularity is part of legacy query syntax and there are plans to migrate away from this format. As such, - this will not be updated to allow for custom granularity values. This implies that any query paths that push named - parameters through this class will not support a custom grain reference of the form `metric_time__martian_year`, - and users wishing to use their martian year grain will have to explicitly reference it via a separate parameter - instead of gluing it onto the end of the name. """ entity_links: Tuple[EntityReference, ...] element_name: str - time_granularity: Optional[TimeGranularity] = None + time_granularity: Optional[str] = None @staticmethod - def parse_name(name: str) -> StructuredDunderedName: + def parse_name(name: str, custom_granularity_names: Sequence[str] = ()) -> StructuredDunderedName: """Construct from a string like 'listing__ds__month'.""" name_parts = name.split(DUNDER) @@ -40,11 +34,17 @@ def parse_name(name: str) -> StructuredDunderedName: if len(name_parts) == 1: return StructuredDunderedName((), name_parts[0]) - associated_granularity = None - granularity: TimeGranularity + associated_granularity: Optional[str] = None for granularity in TimeGranularity: if name_parts[-1] == granularity.value: - associated_granularity = granularity + associated_granularity = granularity.value + break + + if associated_granularity is None: + for custom_grain in custom_granularity_names: + if name_parts[-1] == custom_grain: + associated_granularity = custom_grain + break # Has a time granularity if associated_granularity: @@ -69,7 +69,7 @@ def dundered_name(self) -> str: """Return the full name form. e.g. ds or listing__ds__month.""" items = [entity_reference.element_name for entity_reference in self.entity_links] + [self.element_name] if self.time_granularity: - items.append(self.time_granularity.value) + items.append(self.time_granularity) return DUNDER.join(items) @property @@ -82,7 +82,7 @@ def dundered_name_without_granularity(self) -> str: @property def dundered_name_without_entity(self) -> str: """Return the name without the entity. e.g. listing__ds__month -> ds__month.""" - return DUNDER.join((self.element_name,) + ((self.time_granularity.value,) if self.time_granularity else ())) + return DUNDER.join((self.element_name,) + ((self.time_granularity,) if self.time_granularity else ())) @property def entity_prefix(self) -> Optional[str]: @@ -91,52 +91,3 @@ def entity_prefix(self) -> Optional[str]: return DUNDER.join(tuple(entity_reference.element_name for entity_reference in self.entity_links)) return None - - -class DunderedNameFormatter: - """Helps to parse names into StructuredDunderedName and vice versa.""" - - @staticmethod - def parse_name(name: str) -> StructuredDunderedName: - """Construct from a string like 'listing__ds__month'.""" - name_parts = name.split(DUNDER) - - # No dunder, e.g. "ds" - if len(name_parts) == 1: - return StructuredDunderedName((), name_parts[0]) - - associated_granularity = None - granularity: TimeGranularity - for granularity in TimeGranularity: - if name_parts[-1] == granularity.value: - associated_granularity = granularity - - # Has a time granularity - if associated_granularity: - # e.g. "ds__month" - if len(name_parts) == 2: - return StructuredDunderedName((), name_parts[0], associated_granularity) - # e.g. "messages__ds__month" - return StructuredDunderedName( - entity_links=tuple(EntityReference(element_name=entity_name) for entity_name in name_parts[:-2]), - element_name=name_parts[-2], - time_granularity=associated_granularity, - ) - # e.g. "messages__ds" - else: - return StructuredDunderedName( - entity_links=tuple(EntityReference(element_name=entity_name) for entity_name in name_parts[:-1]), - element_name=name_parts[-1], - ) - - @staticmethod - def create_structured_name( # noqa: D - element_name: str, - entity_links: Tuple[EntityReference, ...] = (), - time_granularity: Optional[TimeGranularity] = None, - ) -> StructuredDunderedName: - return StructuredDunderedName( - entity_links=entity_links, - element_name=element_name, - time_granularity=time_granularity, - ) diff --git a/dbt_semantic_interfaces/parsing/text_input/ti_description.py b/dbt_semantic_interfaces/parsing/text_input/ti_description.py index 62fff662..cdb619aa 100644 --- a/dbt_semantic_interfaces/parsing/text_input/ti_description.py +++ b/dbt_semantic_interfaces/parsing/text_input/ti_description.py @@ -56,13 +56,14 @@ def __post_init__(self) -> None: # noqa: D105 else: assert_values_exhausted(item_type) - structured_item_name = StructuredDunderedName.parse_name(self.item_name) - # Check that metrics do not have an entity prefix or entity path. if item_type is QueryItemType.METRIC: if len(self.entity_path) > 0: raise InvalidQuerySyntax("The entity path should not be specified for a metric.") - if len(structured_item_name.entity_links) > 0: + if ( + len(StructuredDunderedName.parse_name(name=self.item_name, custom_granularity_names=()).entity_links) + > 0 + ): raise InvalidQuerySyntax("The name of the metric should not have entity links.") # Check that dimensions / time dimensions have a valid date part. elif item_type is QueryItemType.DIMENSION or item_type is QueryItemType.TIME_DIMENSION: diff --git a/dbt_semantic_interfaces/parsing/text_input/ti_processor.py b/dbt_semantic_interfaces/parsing/text_input/ti_processor.py index 6823a6a4..cac7c122 100644 --- a/dbt_semantic_interfaces/parsing/text_input/ti_processor.py +++ b/dbt_semantic_interfaces/parsing/text_input/ti_processor.py @@ -10,9 +10,6 @@ from typing_extensions import override from dbt_semantic_interfaces.errors import InvalidQuerySyntax -from dbt_semantic_interfaces.parsing.text_input.description_renderer import ( - QueryItemDescriptionRenderer, -) from dbt_semantic_interfaces.parsing.text_input.rendering_helper import ( ObjectBuilderJinjaRenderHelper, ) @@ -77,34 +74,6 @@ def collect_descriptions_from_template( ) return description_collector.collected_descriptions() - def render_template( - self, - jinja_template: str, - renderer: QueryItemDescriptionRenderer, - valid_method_mapping: ValidMethodMapping, - ) -> str: - """Renders the Jinja template using the specified renderer. - - Args: - jinja_template: A Jinja template string like `{{ Dimension('listing__country') }} = 'US'`. - renderer: The renderer to use for rendering each item. - valid_method_mapping: Mapping from the builder object to the valid methods. See - `ConfiguredValidMethodMapping`. - - Returns: - The rendered Jinja template. - - Raises: - QueryItemJinjaException: See definition. - InvalidBuilderMethodException: See definition. - """ - render_processor = _RendererProcessor(renderer) - return self._process_template( - jinja_template=jinja_template, - valid_method_mapping=valid_method_mapping, - description_processor=render_processor, - ) - def _process_template( self, jinja_template: str, @@ -161,18 +130,3 @@ def process_description(self, item_description: ObjectBuilderItemDescription) -> self._items.append(item_description) return "" - - -class _RendererProcessor(ObjectBuilderItemDescriptionProcessor): - """Processor that renders the descriptions in a Jinja template using the given renderer. - - This is just a pass-through, but it allows `QueryItemDescriptionRenderer` to be a facade that has more appropriate - method names. - """ - - def __init__(self, renderer: QueryItemDescriptionRenderer) -> None: # noqa: D107 - self._renderer = renderer - - @override - def process_description(self, item_description: ObjectBuilderItemDescription) -> str: - return self._renderer.render_description(item_description) diff --git a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py index c59dc016..565ddaad 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py +++ b/dbt_semantic_interfaces/parsing/where_filter/parameter_set_factory.py @@ -7,7 +7,7 @@ ParseWhereFilterException, TimeDimensionCallParameterSet, ) -from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter +from dbt_semantic_interfaces.naming.dundered import StructuredDunderedName from dbt_semantic_interfaces.naming.keywords import is_metric_time_name from dbt_semantic_interfaces.references import ( DimensionReference, @@ -46,6 +46,7 @@ def _exception_message_for_incorrect_format(element_name: str) -> str: @staticmethod def create_time_dimension( time_dimension_name: str, + custom_granularity_names: Sequence[str], time_granularity_name: Optional[str] = None, entity_path: Sequence[str] = (), date_part_name: Optional[str] = None, @@ -65,14 +66,14 @@ def create_time_dimension( for parsing where filters. When we solve the problems with our current where filter spec this will persist as a backwards compatibility model, but nothing more. """ - group_by_item_name = DunderedNameFormatter.parse_name(time_dimension_name) + group_by_item_name = StructuredDunderedName.parse_name( + name=time_dimension_name, custom_granularity_names=custom_granularity_names + ) if len(group_by_item_name.entity_links) != 1 and not is_metric_time_name(group_by_item_name.element_name): raise ParseWhereFilterException( ParameterSetFactory._exception_message_for_incorrect_format(time_dimension_name) ) - grain_parsed_from_name = ( - group_by_item_name.time_granularity.value if group_by_item_name.time_granularity else None - ) + grain_parsed_from_name = group_by_item_name.time_granularity inputs_are_mismatched = ( grain_parsed_from_name is not None and time_granularity_name is not None @@ -101,7 +102,7 @@ def create_time_dimension( @staticmethod def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> DimensionCallParameterSet: """Gets called by Jinja when rendering {{ Dimension(...) }}.""" - group_by_item_name = DunderedNameFormatter.parse_name(dimension_name) + group_by_item_name = StructuredDunderedName.parse_name(name=dimension_name, custom_granularity_names=()) if len(group_by_item_name.entity_links) != 1 and not is_metric_time_name(group_by_item_name.element_name): raise ParseWhereFilterException(ParameterSetFactory._exception_message_for_incorrect_format(dimension_name)) @@ -116,7 +117,7 @@ def create_dimension(dimension_name: str, entity_path: Sequence[str] = ()) -> Di @staticmethod def create_entity(entity_name: str, entity_path: Sequence[str] = ()) -> EntityCallParameterSet: """Gets called by Jinja when rendering {{ Entity(...) }}.""" - structured_dundered_name = DunderedNameFormatter.parse_name(entity_name) + structured_dundered_name = StructuredDunderedName.parse_name(name=entity_name, custom_granularity_names=()) if structured_dundered_name.time_granularity is not None: raise ParseWhereFilterException( f"Name is in an incorrect format: {repr(entity_name)}. " f"It should not contain a time grain suffix." diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py index 7d070b84..f8ba9a51 100644 --- a/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py +++ b/dbt_semantic_interfaces/parsing/where_filter/where_filter_parser.py @@ -39,7 +39,9 @@ def parse_item_descriptions(where_sql_template: str) -> Sequence[ObjectBuilderIt raise ParseWhereFilterException(f"Error while parsing Jinja template:\n{where_sql_template}") from e @staticmethod - def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSets: + def parse_call_parameter_sets( + where_sql_template: str, custom_granularity_names: Sequence[str] + ) -> FilterCallParameterSets: """Return the result of extracting the semantic objects referenced in the where SQL template string.""" descriptions = WhereFilterParser.parse_item_descriptions(where_sql_template) @@ -63,6 +65,7 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet time_granularity_name=description.time_granularity_name, entity_path=description.entity_path, date_part_name=description.date_part_name, + custom_granularity_names=custom_granularity_names, ) ) else: @@ -79,6 +82,7 @@ def parse_call_parameter_sets(where_sql_template: str) -> FilterCallParameterSet time_granularity_name=description.time_granularity_name, entity_path=description.entity_path, date_part_name=description.date_part_name, + custom_granularity_names=custom_granularity_names, ) ) elif item_type is QueryItemType.ENTITY: diff --git a/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py b/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py deleted file mode 100644 index 693c8344..00000000 --- a/dbt_semantic_interfaces/parsing/where_filter/where_filter_time_dimension.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -from typing import List, Optional, Sequence - -from typing_extensions import override - -from dbt_semantic_interfaces.call_parameter_sets import TimeDimensionCallParameterSet -from dbt_semantic_interfaces.errors import InvalidQuerySyntax -from dbt_semantic_interfaces.parsing.where_filter.parameter_set_factory import ( - ParameterSetFactory, -) -from dbt_semantic_interfaces.protocols.protocol_hint import ProtocolHint -from dbt_semantic_interfaces.protocols.query_interface import ( - QueryInterfaceTimeDimension, - QueryInterfaceTimeDimensionFactory, -) - - -class TimeDimensionStub(ProtocolHint[QueryInterfaceTimeDimension]): - """A TimeDimension implementation that just satisfies the protocol. - - QueryInterfaceTimeDimension currently has no methods and the parameter set is created in the factory. - So, there is nothing to do here. - """ - - @override - def _implements_protocol(self) -> QueryInterfaceTimeDimension: - return self - - -class WhereFilterTimeDimensionFactory(ProtocolHint[QueryInterfaceTimeDimensionFactory]): - """Executes in the Jinja sandbox to produce parameter sets and append them to a list.""" - - @override - def _implements_protocol(self) -> QueryInterfaceTimeDimensionFactory: - return self - - def __init__(self) -> None: # noqa - self.time_dimension_call_parameter_sets: List[TimeDimensionCallParameterSet] = [] - - def create( - self, - time_dimension_name: str, - time_granularity_name: Optional[str] = None, - entity_path: Sequence[str] = (), - descending: Optional[bool] = None, - date_part_name: Optional[str] = None, - ) -> TimeDimensionStub: - """Gets called by Jinja when rendering {{ TimeDimension(...) }}.""" - if descending is not None: - raise InvalidQuerySyntax("descending is invalid in the where parameter and filter spec") - self.time_dimension_call_parameter_sets.append( - ParameterSetFactory.create_time_dimension( - time_dimension_name, time_granularity_name, entity_path, date_part_name - ) - ) - return TimeDimensionStub() diff --git a/dbt_semantic_interfaces/protocols/where_filter.py b/dbt_semantic_interfaces/protocols/where_filter.py index 7792e006..470e912e 100644 --- a/dbt_semantic_interfaces/protocols/where_filter.py +++ b/dbt_semantic_interfaces/protocols/where_filter.py @@ -13,9 +13,8 @@ def where_sql_template(self) -> str: """A template that describes how to render the SQL for a WHERE clause.""" pass - @property @abstractmethod - def call_parameter_sets(self) -> FilterCallParameterSets: + def call_parameter_sets(self, custom_granularity_names: Sequence[str]) -> FilterCallParameterSets: """Describe calls like 'dimension(...)' in the SQL template.""" pass @@ -41,9 +40,10 @@ def where_filters(self) -> Sequence[WhereFilter]: """The collection of WhereFilters to be applied to the input data set.""" pass - @property @abstractmethod - def filter_expression_parameter_sets(self) -> Sequence[Tuple[str, FilterCallParameterSets]]: + def filter_expression_parameter_sets( + self, custom_granularity_names: Sequence[str] + ) -> Sequence[Tuple[str, FilterCallParameterSets]]: """Mapping from distinct filter expressions to the call parameter sets associated with them. We use a tuple, rather than a Mapping, in case the call parameter sets may vary between diff --git a/dbt_semantic_interfaces/test_utils.py b/dbt_semantic_interfaces/test_utils.py index ced0cd21..fc9a01b2 100644 --- a/dbt_semantic_interfaces/test_utils.py +++ b/dbt_semantic_interfaces/test_utils.py @@ -25,6 +25,10 @@ ) from dbt_semantic_interfaces.parsing.objects import YamlConfigFile from dbt_semantic_interfaces.type_enums import MetricType, TimeGranularity +from dbt_semantic_interfaces.validations.validator_helpers import ( + SemanticManifestValidationResults, + ValidationIssue, +) logger = logging.getLogger(__name__) @@ -169,3 +173,64 @@ def semantic_model_with_guaranteed_meta( dimensions=dimensions, metadata=metadata, ) + + +def _assert_expected_validation_message( # noqa: D + issues: Sequence[ValidationIssue], + message_fragment: str, +) -> None: + found_match = any([issue.message.find(message_fragment) != -1 for issue in issues]) + # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. + assert { + "expected": message_fragment, + "actual_messages": [issue.message for issue in issues], + } and found_match + + +def check_expected_issues( # noqa: D + results: SemanticManifestValidationResults, + num_expected_errors: int = 0, + num_expected_warnings: int = 0, + expected_error_msgs: Sequence[str] = [], + expected_warning_msgs: Sequence[str] = [], +) -> None: + """Validates the number, type, and content of ValidationIssues. + + Currently assumes zero future_errors as there are no future_errors + implemented, but this function can be expanded to cover those if needed. + """ + assert len(results.warnings) == num_expected_warnings + assert len(results.errors) == num_expected_errors + assert len(results.future_errors) == 0, "validation function expects zero future_errors to be implemented." + + for expected_error_msg in expected_error_msgs: + _assert_expected_validation_message(issues=results.errors, message_fragment=expected_error_msg) + for expected_warning_msg in expected_warning_msgs: + _assert_expected_validation_message(issues=results.warnings, message_fragment=expected_warning_msg) + + +def check_only_one_error_with_message( # noqa: D + results: SemanticManifestValidationResults, target_message: str +) -> None: + check_expected_issues( + results=results, + num_expected_errors=1, + expected_error_msgs=[target_message], + ) + + +def check_only_one_warning_with_message( # noqa: D + results: SemanticManifestValidationResults, target_message: str +) -> None: + check_expected_issues( + results=results, + num_expected_warnings=1, + expected_warning_msgs=[target_message], + ) + + +def check_no_errors_or_warnings(results: SemanticManifestValidationResults) -> None: # noqa: D + # no num arguments required since all defaults are zero + check_expected_issues( + results=results, + ) diff --git a/dbt_semantic_interfaces/validations/metrics.py b/dbt_semantic_interfaces/validations/metrics.py index f0eb0dff..80278a51 100644 --- a/dbt_semantic_interfaces/validations/metrics.py +++ b/dbt_semantic_interfaces/validations/metrics.py @@ -1,7 +1,6 @@ import traceback -from typing import Dict, Generic, List, Optional, Sequence, Tuple +from typing import Dict, Generic, List, Optional, Sequence -from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets from dbt_semantic_interfaces.errors import ParsingException from dbt_semantic_interfaces.implementations.metric import ( PydanticMetric, @@ -35,10 +34,14 @@ ValidationError, ValidationIssue, ValidationWarning, - generate_exception_issue, validate_safely, ) +# Avoids breaking change from moving this class out of this file. +from dbt_semantic_interfaces.validations.where_filters import ( + WhereFiltersAreParseable, # noQa +) + class CumulativeMetricRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): """Checks that cumulative metrics are configured properly.""" @@ -244,177 +247,6 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati return issues -class WhereFiltersAreParseable(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): - """Validates that all Metric WhereFilters are parseable.""" - - @staticmethod - def _validate_time_granularity_names( - context: MetricContext, - filter_expression_parameter_sets: Sequence[Tuple[str, FilterCallParameterSets]], - custom_granularity_names: List[str], - ) -> Sequence[ValidationIssue]: - issues: List[ValidationIssue] = [] - - valid_granularity_names = [ - standard_granularity.value for standard_granularity in TimeGranularity - ] + custom_granularity_names - for _, parameter_set in filter_expression_parameter_sets: - for time_dim_call_parameter_set in parameter_set.time_dimension_call_parameter_sets: - if not time_dim_call_parameter_set.time_granularity_name: - continue - if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names: - issues.append( - ValidationWarning( - context=context, - message=f"Filter for metric `{context.metric.metric_name}` is not valid. " - f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " - f"Valid granularity options: {valid_granularity_names}", - ) - ) - return issues - - @staticmethod - @validate_safely( - whats_being_done="running model validation ensuring a metric's filter properties are configured properly" - ) - def _validate_metric(metric: Metric, custom_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D - issues: List[ValidationIssue] = [] - context = MetricContext( - file_context=FileContext.from_metadata(metadata=metric.metadata), - metric=MetricModelReference(metric_name=metric.name), - ) - - if metric.filter is not None: - try: - metric.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse filter of metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - if metric.type_params: - measure = metric.type_params.measure - if measure is not None and measure.filter is not None: - try: - measure.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse filter of measure input `{measure.name}` " - f"on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - numerator = metric.type_params.numerator - if numerator is not None and numerator.filter is not None: - try: - numerator.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse the numerator filter on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - denominator = metric.type_params.denominator - if denominator is not None and denominator.filter is not None: - try: - denominator.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse the denominator filter on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - for input_metric in metric.type_params.metrics or []: - if input_metric.filter is not None: - try: - input_metric.filter.filter_expression_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse filter for input metric `{input_metric.name}` " - f"on metric `{metric.name}`", - e=e, - context=context, - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - else: - issues += WhereFiltersAreParseable._validate_time_granularity_names( - context=context, - filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets, - custom_granularity_names=custom_granularity_names, - ) - - # TODO: Are saved query filters being validated? Task: SL-2932 - return issues - - @staticmethod - @validate_safely(whats_being_done="running manifest validation ensuring all metric where filters are parseable") - def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D - issues: List[ValidationIssue] = [] - custom_granularity_names = [ - granularity.name - for time_spine in semantic_manifest.project_configuration.time_spines - for granularity in time_spine.custom_granularities - ] - for metric in semantic_manifest.metrics or []: - issues += WhereFiltersAreParseable._validate_metric( - metric=metric, custom_granularity_names=custom_granularity_names - ) - return issues - - class ConversionMetricRule(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): """Checks that conversion metrics are configured properly.""" diff --git a/dbt_semantic_interfaces/validations/saved_query.py b/dbt_semantic_interfaces/validations/saved_query.py index 7d3716bb..cad7562c 100644 --- a/dbt_semantic_interfaces/validations/saved_query.py +++ b/dbt_semantic_interfaces/validations/saved_query.py @@ -49,14 +49,18 @@ class SavedQueryRule(SemanticManifestValidationRule[SemanticManifestT], Generic[ @staticmethod @validate_safely("Validate the group-by field in a saved query.") - def _check_group_bys(valid_group_by_element_names: Set[str], saved_query: SavedQuery) -> Sequence[ValidationIssue]: + def _check_group_bys( + valid_group_by_element_names: Set[str], saved_query: SavedQuery, custom_granularity_names: Sequence[str] + ) -> Sequence[ValidationIssue]: issues: List[ValidationIssue] = [] for group_by_item in saved_query.query_params.group_by: # TODO: Replace with more appropriate abstractions once available. parameter_sets: FilterCallParameterSets try: - parameter_sets = WhereFilterParser.parse_call_parameter_sets("{{" + group_by_item + "}}") + parameter_sets = WhereFilterParser.parse_call_parameter_sets( + where_sql_template="{{" + group_by_item + "}}", custom_granularity_names=custom_granularity_names + ) except Exception as e: issues.append( generate_exception_issue( @@ -112,33 +116,6 @@ def _check_metrics(valid_metric_names: Set[str], saved_query: SavedQuery) -> Seq ) return issues - @staticmethod - @validate_safely("Validate the where field in a saved query.") - def _check_where(saved_query: SavedQuery) -> Sequence[ValidationIssue]: - issues: List[ValidationIssue] = [] - if saved_query.query_params.where is None: - return issues - for where_filter in saved_query.query_params.where.where_filters: - try: - where_filter.call_parameter_sets - except Exception as e: - issues.append( - generate_exception_issue( - what_was_being_done=f"trying to parse a filter in saved query `{saved_query.name}`", - e=e, - context=SavedQueryContext( - file_context=FileContext.from_metadata(metadata=saved_query.metadata), - element_type=SavedQueryElementType.WHERE, - element_value=where_filter.where_sql_template, - ), - extras={ - "traceback": "".join(traceback.format_tb(e.__traceback__)), - }, - ) - ) - - return issues - @staticmethod def _parse_query_item( saved_query: SavedQuery, @@ -272,6 +249,11 @@ def _check_limit(saved_query: SavedQuery) -> Sequence[ValidationIssue]: @validate_safely("Validate all saved queries in a semantic manifest.") def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D issues: List[ValidationIssue] = [] + custom_granularity_names = [ + granularity.name + for time_spine in semantic_manifest.project_configuration.time_spines + for granularity in time_spine.custom_granularities + ] valid_metric_names = {metric.name for metric in semantic_manifest.metrics} valid_group_by_element_names = valid_metric_names.union({METRIC_TIME_ELEMENT_NAME}) for semantic_model in semantic_manifest.semantic_models: @@ -288,8 +270,8 @@ def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[Validati issues += SavedQueryRule._check_group_bys( valid_group_by_element_names=valid_group_by_element_names, saved_query=saved_query, + custom_granularity_names=custom_granularity_names, ) - issues += SavedQueryRule._check_where(saved_query) issues += SavedQueryRule._check_order_by(saved_query) issues += SavedQueryRule._check_limit(saved_query) return issues diff --git a/dbt_semantic_interfaces/validations/semantic_manifest_validator.py b/dbt_semantic_interfaces/validations/semantic_manifest_validator.py index 26f81452..44de62cf 100644 --- a/dbt_semantic_interfaces/validations/semantic_manifest_validator.py +++ b/dbt_semantic_interfaces/validations/semantic_manifest_validator.py @@ -27,7 +27,6 @@ ConversionMetricRule, CumulativeMetricRule, DerivedMetricRule, - WhereFiltersAreParseable, ) from dbt_semantic_interfaces.validations.non_empty import NonEmptyRule from dbt_semantic_interfaces.validations.primary_entity import PrimaryEntityRule @@ -47,6 +46,7 @@ SemanticManifestValidationResults, SemanticManifestValidationRule, ) +from dbt_semantic_interfaces.validations.where_filters import WhereFiltersAreParseable logger = logging.getLogger(__name__) diff --git a/dbt_semantic_interfaces/validations/where_filters.py b/dbt_semantic_interfaces/validations/where_filters.py new file mode 100644 index 00000000..d01dde39 --- /dev/null +++ b/dbt_semantic_interfaces/validations/where_filters.py @@ -0,0 +1,290 @@ +import traceback +from enum import Enum +from typing import Generic, List, Sequence, Tuple + +from dbt_semantic_interfaces.call_parameter_sets import FilterCallParameterSets +from dbt_semantic_interfaces.protocols import Metric, SemanticManifestT +from dbt_semantic_interfaces.protocols.saved_query import SavedQuery +from dbt_semantic_interfaces.references import MetricModelReference +from dbt_semantic_interfaces.type_enums import TimeGranularity +from dbt_semantic_interfaces.validations.validator_helpers import ( + FileContext, + MetricContext, + SavedQueryContext, + SavedQueryElementType, + SemanticManifestValidationRule, + ValidationContext, + ValidationIssue, + ValidationWarning, + generate_exception_issue, + validate_safely, +) + + +class SemanticManifestNodeType(Enum): + """Types of objects to validate (used for validation messages).""" + + SAVED_QUERY = "saved query" + METRIC = "metric" + + +class WhereFiltersAreParseable(SemanticManifestValidationRule[SemanticManifestT], Generic[SemanticManifestT]): + """Validates that all WhereFilters are parseable.""" + + @staticmethod + def _validate_time_granularity_names( + element_name: str, + object_type: SemanticManifestNodeType, + context: ValidationContext, + filter_call_param_sets: FilterCallParameterSets, + valid_granularity_names: List[str], + ) -> Sequence[ValidationIssue]: + issues: List[ValidationIssue] = [] + + for time_dim_call_parameter_set in filter_call_param_sets.time_dimension_call_parameter_sets: + if not time_dim_call_parameter_set.time_granularity_name: + continue + if time_dim_call_parameter_set.time_granularity_name.lower() not in valid_granularity_names: + issues.append( + ValidationWarning( + context=context, + message=f"Filter for {object_type} `{element_name}` is not valid. " + f"`{time_dim_call_parameter_set.time_granularity_name}` is not a valid granularity name. " + f"Valid granularity options: {valid_granularity_names}", + ) + ) + return issues + + @staticmethod + def _validate_time_granularity_names_for_saved_query( + saved_query: SavedQuery, valid_granularity_names: List[str] + ) -> Sequence[ValidationIssue]: + where_param = saved_query.query_params.where + if where_param is None: + return [] + + issues: List[ValidationIssue] = [] + for where_filter in where_param.where_filters: + issues += WhereFiltersAreParseable._validate_time_granularity_names( + element_name=saved_query.name, + object_type=SemanticManifestNodeType.SAVED_QUERY, + context=SavedQueryContext( + file_context=FileContext.from_metadata(metadata=saved_query.metadata), + element_type=SavedQueryElementType.WHERE, + element_value=where_filter.where_sql_template, + ), + filter_call_param_sets=where_filter.call_parameter_sets( + custom_granularity_names=valid_granularity_names + ), + valid_granularity_names=valid_granularity_names, + ) + + return issues + + @staticmethod + def _validate_time_granularity_names_for_metric( + context: MetricContext, + filter_expression_parameter_sets: Sequence[Tuple[str, FilterCallParameterSets]], + valid_granularity_names: List[str], + ) -> Sequence[ValidationIssue]: + issues: List[ValidationIssue] = [] + for _, param_set in filter_expression_parameter_sets: + issues += WhereFiltersAreParseable._validate_time_granularity_names( + element_name=context.metric.metric_name, + object_type=SemanticManifestNodeType.METRIC, + context=context, + filter_call_param_sets=param_set, + valid_granularity_names=valid_granularity_names, + ) + return issues + + @staticmethod + @validate_safely("validating the where field in a saved query.") + def _validate_saved_query(saved_query: SavedQuery, valid_granularity_names: List[str]) -> Sequence[ValidationIssue]: + issues: List[ValidationIssue] = [] + if saved_query.query_params.where is None: + return issues + for where_filter in saved_query.query_params.where.where_filters: + try: + where_filter.call_parameter_sets(custom_granularity_names=valid_granularity_names) + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse a filter in saved query `{saved_query.name}`", + e=e, + context=SavedQueryContext( + file_context=FileContext.from_metadata(metadata=saved_query.metadata), + element_type=SavedQueryElementType.WHERE, + element_value=where_filter.where_sql_template, + ), + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_saved_query( + saved_query, valid_granularity_names + ) + + return issues + + @staticmethod + @validate_safely( + whats_being_done="running model validation ensuring a metric's filter properties are configured properly" + ) + def _validate_metric(metric: Metric, valid_granularity_names: List[str]) -> Sequence[ValidationIssue]: # noqa: D + issues: List[ValidationIssue] = [] + context = MetricContext( + file_context=FileContext.from_metadata(metadata=metric.metadata), + metric=MetricModelReference(metric_name=metric.name), + ) + + if metric.filter is not None: + try: + metric.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse filter of metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), + valid_granularity_names=valid_granularity_names, + ) + + if metric.type_params: + measure = metric.type_params.measure + if measure is not None and measure.filter is not None: + try: + measure.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse filter of measure input `{measure.name}` " + f"on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=measure.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), + valid_granularity_names=valid_granularity_names, + ) + + numerator = metric.type_params.numerator + if numerator is not None and numerator.filter is not None: + try: + numerator.filter.filter_expression_parameter_sets(custom_granularity_names=valid_granularity_names) + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse the numerator filter on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=numerator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), + valid_granularity_names=valid_granularity_names, + ) + + denominator = metric.type_params.denominator + if denominator is not None and denominator.filter is not None: + try: + denominator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ) + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse the denominator filter on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=denominator.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), + valid_granularity_names=valid_granularity_names, + ) + + for input_metric in metric.type_params.metrics or []: + if input_metric.filter is not None: + try: + input_metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ) + except Exception as e: + issues.append( + generate_exception_issue( + what_was_being_done=f"trying to parse filter for input metric `{input_metric.name}` " + f"on metric `{metric.name}`", + e=e, + context=context, + extras={ + "traceback": "".join(traceback.format_tb(e.__traceback__)), + }, + ) + ) + else: + issues += WhereFiltersAreParseable._validate_time_granularity_names_for_metric( + context=context, + filter_expression_parameter_sets=input_metric.filter.filter_expression_parameter_sets( + custom_granularity_names=valid_granularity_names + ), + valid_granularity_names=valid_granularity_names, + ) + return issues + + @staticmethod + @validate_safely(whats_being_done="running manifest validation ensuring all metric where filters are parseable") + def validate_manifest(semantic_manifest: SemanticManifestT) -> Sequence[ValidationIssue]: # noqa: D + issues: List[ValidationIssue] = [] + custom_granularity_names = [ + granularity.name + for time_spine in semantic_manifest.project_configuration.time_spines + for granularity in time_spine.custom_granularities + ] + valid_granularity_names = [ + standard_granularity.value for standard_granularity in TimeGranularity + ] + custom_granularity_names + + for metric in semantic_manifest.metrics or []: + issues += WhereFiltersAreParseable._validate_metric( + metric=metric, valid_granularity_names=valid_granularity_names + ) + for saved_query in semantic_manifest.saved_queries: + issues += WhereFiltersAreParseable._validate_saved_query(saved_query, valid_granularity_names) + + return issues diff --git a/tests/example_project_configuration.py b/tests/example_project_configuration.py index 71f407bf..c107165e 100644 --- a/tests/example_project_configuration.py +++ b/tests/example_project_configuration.py @@ -51,6 +51,14 @@ primary_column: name: ds_day time_granularity: day + - node_relation: + schema_name: stuffs + alias: week_time_spine + primary_column: + name: ds + time_granularity: week + custom_granularities: + - name: martian_week """ ), ) diff --git a/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml b/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml index 80c6f34a..efce1c6a 100644 --- a/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml +++ b/tests/fixtures/semantic_manifest_yamls/simple_semantic_manifest/project_configuration.yaml @@ -11,3 +11,11 @@ project_configuration: primary_column: name: ds_day time_granularity: day + - node_relation: + alias: mf_time_spine + schema_name: stufffs + primary_column: + name: ds + time_granularity: day + custom_granularities: + - name: martian_day diff --git a/tests/implementations/where_filter/test_parse_calls.py b/tests/implementations/where_filter/test_parse_calls.py index 2a7f9e89..02f284f5 100644 --- a/tests/implementations/where_filter/test_parse_calls.py +++ b/tests/implementations/where_filter/test_parse_calls.py @@ -34,7 +34,7 @@ def test_extract_dimension_call_parameter_sets() -> None: # noqa: D AND {{ Dimension('user__country', entity_path=['listing']) }} == 'US'\ """ ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=( @@ -61,7 +61,7 @@ def test_extract_dimension_with_grain_call_parameter_sets() -> None: # noqa: D {{ Dimension('metric_time').grain('WEEK') }} > 2023-09-18 """ ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -81,7 +81,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ TimeDimension('user__created_at', 'month', entity_path=['listing']) }} = '2020-01-01'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -100,7 +100,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ TimeDimension('user__created_at__month', entity_path=['listing']) }} = '2020-01-01'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -119,7 +119,7 @@ def test_extract_time_dimension_call_parameter_sets() -> None: # noqa: D def test_extract_metric_time_dimension_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template="""{{ TimeDimension('metric_time', 'month') }} = '2020-01-01'""" - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -137,7 +137,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D where_sql_template=( """{{ Entity('listing') }} AND {{ Entity('user', entity_path=['listing']) }} == 'TEST_USER_ID'""" ) - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -157,7 +157,7 @@ def test_extract_entity_call_parameter_sets() -> None: # noqa: D def test_extract_metric_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=("{{ Metric('bookings', group_by=['listing']) }} > 2") - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -172,7 +172,7 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template=("{{ Metric('bookings', group_by=['listing', 'metric_time']) }} > 2") - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -186,7 +186,9 @@ def test_extract_metric_call_parameter_sets() -> None: # noqa: D ) with pytest.raises(ParseWhereFilterException): - PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets + PydanticWhereFilter(where_sql_template=("{{ Metric('bookings') }} > 2")).call_parameter_sets( + custom_granularity_names=() + ) def test_invalid_entity_name_error() -> None: @@ -194,7 +196,7 @@ def test_invalid_entity_name_error() -> None: bad_entity_filter = PydanticWhereFilter(where_sql_template="{{ Entity('is_food_order__day' )}}") with pytest.raises(ParseWhereFilterException, match="Name is in an incorrect format"): - bad_entity_filter.call_parameter_sets + bad_entity_filter.call_parameter_sets(custom_granularity_names=()) def test_where_filter_interesection_extract_call_parameter_sets() -> None: @@ -209,7 +211,7 @@ def test_where_filter_interesection_extract_call_parameter_sets() -> None: ) filter_intersection = PydanticWhereFilterIntersection(where_filters=[time_filter, entity_filter]) - parse_result = dict(filter_intersection.filter_expression_parameter_sets) + parse_result = dict(filter_intersection.filter_expression_parameter_sets(custom_granularity_names=())) assert parse_result.get(time_filter.where_sql_template) == FilterCallParameterSets( time_dimension_call_parameter_sets=( @@ -250,7 +252,7 @@ def test_where_filter_intersection_error_collection() -> None: ) with pytest.raises(ParseWhereFilterException) as exc_info: - filter_intersection.filter_expression_parameter_sets + filter_intersection.filter_expression_parameter_sets(custom_granularity_names=()) error_string = str(exc_info.value) # These are a little too implementation-specific, but it demonstrates that we are collecting the errors we find. @@ -261,7 +263,7 @@ def test_where_filter_intersection_error_collection() -> None: def test_time_dimension_without_granularity() -> None: # noqa: D parse_result = PydanticWhereFilter( where_sql_template="{{ TimeDimension('booking__created_at') }} > 2023-09-18" - ).call_parameter_sets + ).call_parameter_sets(custom_granularity_names=()) assert parse_result == FilterCallParameterSets( dimension_call_parameter_sets=(), @@ -274,3 +276,21 @@ def test_time_dimension_without_granularity() -> None: # noqa: D ), entity_call_parameter_sets=(), ) + + +def test_time_dimension_with_custom_granularity() -> None: # noqa: D + parse_result = PydanticWhereFilter( + where_sql_template="{{ TimeDimension('booking__created_at', 'martian_week') }} > 2023-09-18" + ).call_parameter_sets(custom_granularity_names=("martian_week",)) + + assert parse_result == FilterCallParameterSets( + dimension_call_parameter_sets=(), + time_dimension_call_parameter_sets=( + TimeDimensionCallParameterSet( + entity_path=(EntityReference("booking"),), + time_dimension_reference=TimeDimensionReference(element_name="created_at"), + time_granularity_name="martian_week", + ), + ), + entity_call_parameter_sets=(), + ) diff --git a/tests/parsing/test_saved_query_parsing.py b/tests/parsing/test_saved_query_parsing.py index 20b8f4b3..2bc04d11 100644 --- a/tests/parsing/test_saved_query_parsing.py +++ b/tests/parsing/test_saved_query_parsing.py @@ -120,6 +120,35 @@ def test_saved_query_group_by() -> None: ) +def test_saved_query_group_by_with_custom_grain() -> None: + """Test for parsing group_bys in a saved query.""" + yaml_contents = textwrap.dedent( + """\ + saved_query: + name: test_saved_query_group_bys + query_params: + metrics: + - test_metric_a + group_by: + - TimeDimension('test_entity__metric_time', 'martian_week') + - Dimension('test_entity__metric_time__martian_week') + + """ + ) + file = YamlConfigFile(filepath="test_dir/inline_for_test", contents=yaml_contents) + + build_result = parse_yaml_files_to_semantic_manifest(files=[file, EXAMPLE_PROJECT_CONFIGURATION_YAML_CONFIG_FILE]) + + assert len(build_result.semantic_manifest.saved_queries) == 1 + saved_query = build_result.semantic_manifest.saved_queries[0] + assert len(saved_query.query_params.group_by) == 2 + print(saved_query.query_params.group_by) + assert { + "TimeDimension('test_entity__metric_time', 'martian_week')", + "Dimension('test_entity__metric_time__martian_week')", + } == set(saved_query.query_params.group_by) + + def test_saved_query_where() -> None: """Test for parsing where clause in a saved query.""" where = "Dimension(test_entity__test_dimension) == true" diff --git a/tests/parsing/test_where_filter_parsing.py b/tests/parsing/test_where_filter_parsing.py index 60a37e0a..8764a97a 100644 --- a/tests/parsing/test_where_filter_parsing.py +++ b/tests/parsing/test_where_filter_parsing.py @@ -165,14 +165,14 @@ def test_where_filter_intersection_from_partially_deserialized_list_of_strings() ], ) def test_time_dimension_date_part(where: str) -> None: # noqa - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR def test_dimension_date_part() -> None: # noqa where = "{{ Dimension('metric_time').grain('DAY').date_part('YEAR') }} > '2023-01-01'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0].date_part == DatePart.YEAR @@ -196,20 +196,36 @@ def test_dimension_date_part() -> None: # noqa time_granularity_name=TimeGranularity.WEEK.value, ), ), + ( + "{{ TimeDimension('metric_time__martian_week') }} > '2023-01-01'", + TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference("metric_time"), + entity_path=(), + time_granularity_name="martian_week", + ), + ), + ( + "{{ TimeDimension('metric_time', time_granularity_name='martian_week') }} > '2023-01-01'", + TimeDimensionCallParameterSet( + time_dimension_reference=TimeDimensionReference("metric_time"), + entity_path=(), + time_granularity_name="martian_week", + ), + ), ], ) def test_time_dimension_grain( # noqa where_and_expected_call_params: Tuple[str, Union[TimeDimensionCallParameterSet, DimensionCallParameterSet]] ) -> None: where, expected_call_params = where_and_expected_call_params - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=("martian_week",)) assert len(param_sets.time_dimension_call_parameter_sets) == 1 assert param_sets.time_dimension_call_parameter_sets[0] == expected_call_params def test_entity_without_primary_entity_prefix() -> None: # noqa where = "{{ Entity('non_primary_entity') }} = '1'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.entity_call_parameter_sets) == 1 assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( entity_path=(), @@ -219,7 +235,7 @@ def test_entity_without_primary_entity_prefix() -> None: # noqa def test_entity() -> None: # noqa where = "{{ Entity('entity_1__entity_2', entity_path=['entity_0']) }} = '1'" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.entity_call_parameter_sets) == 1 assert param_sets.entity_call_parameter_sets[0] == EntityCallParameterSet( entity_path=( @@ -232,7 +248,7 @@ def test_entity() -> None: # noqa def test_metric() -> None: # noqa where = "{{ Metric('metric', group_by=['dimension']) }} = 10" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.metric_call_parameter_sets) == 1 assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet( group_by=(LinkableElementReference(element_name="dimension"),), @@ -241,7 +257,7 @@ def test_metric() -> None: # noqa # Without kwarg syntax where = "{{ Metric('metric', ['dimension']) }} = 10" - param_sets = WhereFilterParser.parse_call_parameter_sets(where) + param_sets = WhereFilterParser.parse_call_parameter_sets(where, custom_granularity_names=()) assert len(param_sets.metric_call_parameter_sets) == 1 assert param_sets.metric_call_parameter_sets[0] == MetricCallParameterSet( group_by=(LinkableElementReference(element_name="dimension"),), diff --git a/tests/validations/test_metrics.py b/tests/validations/test_metrics.py index efafeef0..c96d23b3 100644 --- a/tests/validations/test_metrics.py +++ b/tests/validations/test_metrics.py @@ -1,4 +1,3 @@ -from copy import deepcopy from typing import List, Tuple import pytest @@ -31,7 +30,6 @@ TimeDimensionReference, ) from dbt_semantic_interfaces.test_utils import ( - find_metric_with, metric_with_guaranteed_meta, semantic_model_with_guaranteed_meta, ) @@ -48,7 +46,6 @@ CumulativeMetricRule, DerivedMetricRule, MetricTimeGranularityRule, - WhereFiltersAreParseable, ) from dbt_semantic_interfaces.validations.semantic_manifest_validator import ( SemanticManifestValidator, @@ -345,145 +342,6 @@ def test_derived_metric() -> None: # noqa: D check_error_in_issues(error_substrings=expected_substrings, issues=build_issues) -def test_where_filter_validations_happy( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - results = validator.validate_semantic_manifest(simple_semantic_manifest__with_primary_transforms) - assert not results.has_blocking_issues - - -def test_where_filter_validations_bad_base_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with(manifest, lambda metric: metric.filter is not None) - assert metric.filter is not None - assert len(metric.filter.where_filters) > 0 - metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_measure_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, lambda metric: metric.type_params is not None and metric.type_params.measure is not None - ) - assert metric.type_params.measure is not None - metric.type_params.measure.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, - match=f"trying to parse filter of measure input `{metric.type_params.measure.name}` on metric `{metric.name}`", - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_numerator_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, lambda metric: metric.type_params is not None and metric.type_params.numerator is not None - ) - assert metric.type_params.numerator is not None - metric.type_params.numerator.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, match=f"trying to parse the numerator filter on metric `{metric.name}`" - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_denominator_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, lambda metric: metric.type_params is not None and metric.type_params.denominator is not None - ) - assert metric.type_params.denominator is not None - metric.type_params.denominator.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, match=f"trying to parse the denominator filter on metric `{metric.name}`" - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_bad_input_metric_filter( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, - lambda metric: metric.type_params is not None - and metric.type_params.metrics is not None - and len(metric.type_params.metrics) > 0, - ) - assert metric.type_params.metrics is not None - input_metric = metric.type_params.metrics[0] - input_metric.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - with pytest.raises( - SemanticManifestValidationException, - match=f"trying to parse filter for input metric `{input_metric.name}` on metric `{metric.name}`", - ): - validator.checked_validations(manifest) - - -def test_where_filter_validations_invalid_granularity( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = deepcopy(simple_semantic_manifest__with_primary_transforms) - - metric, _ = find_metric_with( - manifest, - lambda metric: metric.type_params is not None - and metric.type_params.metrics is not None - and len(metric.type_params.metrics) > 0, - ) - assert metric.type_params.metrics is not None - input_metric = metric.type_params.metrics[0] - input_metric.filter = PydanticWhereFilterIntersection( - where_filters=[ - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'month') }}"), - PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'MONTH') }}"), - ] - ) - validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) - issues = validator.validate_semantic_manifest(manifest) - assert not issues.has_blocking_issues - assert len(issues.warnings) == 1 - assert "`cool` is not a valid granularity name" in issues.warnings[0].message - - def test_conversion_metrics() -> None: # noqa: D base_measure_name = "base_measure" conversion_measure_name = "conversion_measure" diff --git a/tests/validations/test_saved_query.py b/tests/validations/test_saved_query.py index e07288e5..566407f9 100644 --- a/tests/validations/test_saved_query.py +++ b/tests/validations/test_saved_query.py @@ -12,32 +12,15 @@ from dbt_semantic_interfaces.implementations.semantic_manifest import ( PydanticSemanticManifest, ) +from dbt_semantic_interfaces.test_utils import check_only_one_error_with_message from dbt_semantic_interfaces.validations.saved_query import SavedQueryRule from dbt_semantic_interfaces.validations.semantic_manifest_validator import ( SemanticManifestValidator, ) -from dbt_semantic_interfaces.validations.validator_helpers import ( - SemanticManifestValidationResults, -) logger = logging.getLogger(__name__) -def check_only_one_error_with_message( # noqa: D - results: SemanticManifestValidationResults, target_message: str -) -> None: - assert len(results.warnings) == 0 - assert len(results.errors) == 1 - assert len(results.future_errors) == 0 - - found_match = results.errors[0].message.find(target_message) != -1 - # Adding this dict to the assert so that when it does not match, pytest prints the expected and actual values. - assert { - "expected": target_message, - "actual": results.errors[0].message, - } and found_match - - def test_invalid_metric_in_saved_query( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: @@ -62,31 +45,6 @@ def test_invalid_metric_in_saved_query( # noqa: D ) -def test_invalid_where_in_saved_query( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) - manifest.saved_queries = [ - PydanticSavedQuery( - name="Example Saved Query", - description="Example description.", - query_params=PydanticSavedQueryQueryParams( - metrics=["bookings"], - group_by=["Dimension('booking__is_instant')"], - where=PydanticWhereFilterIntersection( - where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], - ), - ), - ), - ] - - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([SavedQueryRule()]) - check_only_one_error_with_message( - manifest_validator.validate_semantic_manifest(manifest), - "trying to parse a filter in saved query", - ) - - def test_invalid_group_by_element_in_saved_query( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: @@ -137,30 +95,6 @@ def test_invalid_group_by_format_in_saved_query( # noqa: D ) -def test_metric_filter_error( # noqa: D - simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, -) -> None: - manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) - manifest.saved_queries = [ - PydanticSavedQuery( - name="Example Saved Query", - description="Example description.", - query_params=PydanticSavedQueryQueryParams( - metrics=["listings"], - where=PydanticWhereFilterIntersection( - where_filters=[PydanticWhereFilter(where_sql_template="{{ Metric('bookings') }} > 2")], - ), - ), - ), - ] - - manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([SavedQueryRule()]) - check_only_one_error_with_message( - manifest_validator.validate_semantic_manifest(manifest), - "An error occurred while trying to parse a filter in saved query", - ) - - def test_metric_filter_success( # noqa: D simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, ) -> None: diff --git a/tests/validations/test_where_filters_are_parseable.py b/tests/validations/test_where_filters_are_parseable.py new file mode 100644 index 00000000..3bedd8ee --- /dev/null +++ b/tests/validations/test_where_filters_are_parseable.py @@ -0,0 +1,309 @@ +import copy +import logging + +import pytest + +from dbt_semantic_interfaces.implementations.filters.where_filter import ( + PydanticWhereFilter, + PydanticWhereFilterIntersection, +) +from dbt_semantic_interfaces.implementations.saved_query import ( + PydanticSavedQuery, + PydanticSavedQueryQueryParams, +) +from dbt_semantic_interfaces.implementations.semantic_manifest import ( + PydanticSemanticManifest, +) +from dbt_semantic_interfaces.test_utils import ( + check_no_errors_or_warnings, + check_only_one_error_with_message, + check_only_one_warning_with_message, + find_metric_with, +) +from dbt_semantic_interfaces.validations.semantic_manifest_validator import ( + SemanticManifestValidator, +) +from dbt_semantic_interfaces.validations.validator_helpers import ( + SemanticManifestValidationException, +) +from dbt_semantic_interfaces.validations.where_filters import WhereFiltersAreParseable + +logger = logging.getLogger(__name__) + + +# ------------------------------------------------------------------------------ +# Metric validations +# ------------------------------------------------------------------------------ + + +def test_metric_where_filter_validations_happy( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + results = validator.validate_semantic_manifest(simple_semantic_manifest__with_primary_transforms) + assert not results.has_blocking_issues + + +def test_where_filter_validations_bad_base_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with(manifest, lambda metric: metric.filter is not None) + assert metric.filter is not None + assert len(metric.filter.where_filters) > 0 + metric.filter.where_filters[0].where_sql_template = "{{ dimension('too', 'many', 'variables', 'to', 'handle') }}" + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + with pytest.raises(SemanticManifestValidationException, match=f"trying to parse filter of metric `{metric.name}`"): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_measure_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, lambda metric: metric.type_params is not None and metric.type_params.measure is not None + ) + assert metric.type_params.measure is not None + metric.type_params.measure.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + with pytest.raises( + SemanticManifestValidationException, + match=f"trying to parse filter of measure input `{metric.type_params.measure.name}` on metric `{metric.name}`", + ): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_numerator_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, lambda metric: metric.type_params is not None and metric.type_params.numerator is not None + ) + assert metric.type_params.numerator is not None + metric.type_params.numerator.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + with pytest.raises( + SemanticManifestValidationException, match=f"trying to parse the numerator filter on metric `{metric.name}`" + ): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_denominator_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, lambda metric: metric.type_params is not None and metric.type_params.denominator is not None + ) + assert metric.type_params.denominator is not None + metric.type_params.denominator.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + with pytest.raises( + SemanticManifestValidationException, match=f"trying to parse the denominator filter on metric `{metric.name}`" + ): + validator.checked_validations(manifest) + + +def test_where_filter_validations_bad_input_metric_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, + lambda metric: metric.type_params is not None + and metric.type_params.metrics is not None + and len(metric.type_params.metrics) > 0, + ) + assert metric.type_params.metrics is not None + input_metric = metric.type_params.metrics[0] + input_metric.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ dimension('too', 'many', 'variables', 'to', 'handle') }}") + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + with pytest.raises( + SemanticManifestValidationException, + match=f"trying to parse filter for input metric `{input_metric.name}` on metric `{metric.name}`", + ): + validator.checked_validations(manifest) + + +def test_metric_where_filter_validations_invalid_granularity( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + metric, _ = find_metric_with( + manifest, + lambda metric: metric.type_params is not None + and metric.type_params.metrics is not None + and len(metric.type_params.metrics) > 0, + ) + assert metric.type_params.metrics is not None + input_metric = metric.type_params.metrics[0] + input_metric.filter = PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'month') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'MONTH') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'martian_day') }}"), + ] + ) + validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + issues = validator.validate_semantic_manifest(manifest) + assert not issues.has_blocking_issues + assert len(issues.warnings) == 1 + assert "`cool` is not a valid granularity name" in issues.warnings[0].message + + +# ------------------------------------------------------------------------------ +# Saved Query validations +# ------------------------------------------------------------------------------ + + +def test_saved_query_with_happy_filter( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'hour') }}"), + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'martian_day') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + check_no_errors_or_warnings(manifest_validator.validate_semantic_manifest(manifest)) + + +def test_saved_query_validates_granularity_name_despite_case( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'DAY') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + check_no_errors_or_warnings(manifest_validator.validate_semantic_manifest(manifest)) + + +def test_invalid_where_in_saved_query( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ invalid_jinja }}")], + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + check_only_one_error_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "trying to parse a filter in saved query", + ) + + +def test_saved_query_where_filter_validations_invalid_granularity( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["bookings"], + group_by=["Dimension('booking__is_instant')"], + where=PydanticWhereFilterIntersection( + where_filters=[ + PydanticWhereFilter(where_sql_template="{{ TimeDimension('metric_time', 'cool') }}"), + ] + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + check_only_one_warning_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "is not a valid granularity name", + ) + + +def test_metric_filter_error( # noqa: D + simple_semantic_manifest__with_primary_transforms: PydanticSemanticManifest, +) -> None: + manifest = copy.deepcopy(simple_semantic_manifest__with_primary_transforms) + manifest.saved_queries = [ + PydanticSavedQuery( + name="Example Saved Query", + description="Example description.", + query_params=PydanticSavedQueryQueryParams( + metrics=["listings"], + where=PydanticWhereFilterIntersection( + where_filters=[PydanticWhereFilter(where_sql_template="{{ Metric('bookings') }} > 2")], + ), + ), + ), + ] + + manifest_validator = SemanticManifestValidator[PydanticSemanticManifest]([WhereFiltersAreParseable()]) + check_only_one_error_with_message( + manifest_validator.validate_semantic_manifest(manifest), + "An error occurred while trying to parse a filter in saved query", + )