diff --git a/src/cpr_data_access/models/search.py b/src/cpr_data_access/models/search.py index e21b67d..968e5c0 100644 --- a/src/cpr_data_access/models/search.py +++ b/src/cpr_data_access/models/search.py @@ -72,6 +72,7 @@ class SearchParameters(BaseModel): query_string: Optional[str] = "" exact_match: bool = False all_results: bool = False + documents_only: bool = False limit: int = Field(ge=0, default=100) max_hits_per_family: int = Field( validation_alias=AliasChoices("max_passages_per_doc", "max_hits_per_family"), @@ -97,6 +98,10 @@ def validate(self): """Validate against mutually exclusive fields""" if self.exact_match and self.all_results: raise QueryError("`exact_match` and `all_results` are mutually exclusive") + if self.documents_only and not self.all_results: + raise QueryError( + "`documents_only` requires `all_results`, other queries are not supported" + ) return self @field_validator("continuation_tokens") diff --git a/src/cpr_data_access/yql_builder.py b/src/cpr_data_access/yql_builder.py index 17ff831..2e9c9d0 100644 --- a/src/cpr_data_access/yql_builder.py +++ b/src/cpr_data_access/yql_builder.py @@ -9,7 +9,7 @@ class YQLBuilder: yql_base = Template( """ - select * from sources family_document, document_passage + select * from sources $SOURCES where $WHERE_CLAUSE limit 0 | @@ -36,6 +36,13 @@ def __init__(self, params: SearchParameters, sensitive: bool = False) -> None: self.params = params self.sensitive = sensitive + def build_sources(self) -> str: + """Creates the part of the query that determines which sources to search""" + if self.params.documents_only: + return "family_document" + else: + return "family_document, document_passage" + def build_search_term(self) -> str: """Create the part of the query that matches a users search text""" if self.params.all_results: @@ -158,6 +165,7 @@ def build_max_hits_per_family(self) -> int: def to_str(self) -> str: """Assemble the yql from parts using the template""" yql = self.yql_base.substitute( + SOURCES=self.build_sources(), WHERE_CLAUSE=self.build_where_clause(), CONTINUATION=self.build_continuation(), LIMIT=self.build_limit(), diff --git a/tests/test_search_adaptors.py b/tests/test_search_adaptors.py index 70c4fcb..27cab22 100644 --- a/tests/test_search_adaptors.py +++ b/tests/test_search_adaptors.py @@ -7,6 +7,8 @@ SearchParameters, SearchResponse, sort_fields, + Document, + Passage, ) from conftest import VESPA_TEST_SEARCH_URL @@ -345,3 +347,25 @@ def test_vespa_search_adapter_sorting(fake_vespa_credentials, sort_by): ) assert ascend != descend + + +@pytest.mark.vespa +def test_vespa_search_no_passages_search(fake_vespa_credentials): + no_passages = vespa_search( + fake_vespa_credentials, + SearchParameters(all_results=True, documents_only=True), + ) + for family in no_passages.families: + for hit in family.hits: + assert isinstance(hit, Document) + + with_passages = vespa_search( + fake_vespa_credentials, + SearchParameters(all_results=True), + ) + found_a_passage = False + for family in with_passages.families: + for hit in family.hits: + if isinstance(hit, Passage): + found_a_passage = True + assert found_a_passage diff --git a/tests/test_search_requests.py b/tests/test_search_requests.py index 60a05a5..0e8ceaf 100644 --- a/tests/test_search_requests.py +++ b/tests/test_search_requests.py @@ -53,11 +53,26 @@ def test_whether_an_empty_query_string_does_all_result_search(): pytest.fail(f"{e.__class__.__name__}: {e}") +def test_wether_documents_only_without_all_results_raises_error(): + q = "Search" + with pytest.raises(QueryError) as excinfo: + SearchParameters(query_string=q, documents_only=True) + assert "Failed to build query" in str(excinfo.value) + assert "`documents_only` requires `all_results`" in str(excinfo.value) + + # They should be fine otherwise: + try: + SearchParameters(query_string=q, all_results=True, documents_only=True) + except Exception as e: + pytest.fail(f"{e.__class__.__name__}: {e}") + + def test_wether_combining_all_results_and_exact_match_raises_error(): q = "Search" with pytest.raises(QueryError) as excinfo: SearchParameters(query_string=q, exact_match=True, all_results=True) - assert "" in str(excinfo.value) + assert "Failed to build query" in str(excinfo.value) + assert "`exact_match` and `all_results`" in str(excinfo.value) # They should be fine independently: try: diff --git a/tests/test_yql_builder.py b/tests/test_yql_builder.py index f30c201..7d5097f 100644 --- a/tests/test_yql_builder.py +++ b/tests/test_yql_builder.py @@ -11,6 +11,16 @@ from cpr_data_access.yql_builder import YQLBuilder +def test_whether_document_only_search_ignores_passages_in_yql(): + params = SearchParameters( + all_results=True, + documents_only=True, + ) + yql = YQLBuilder(params).to_str() + assert "family_document" in yql + assert "document_passage" not in yql + + def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql(): filters = { "family_geography": ["SWE"],