Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Rename keyword_filters to filters
Browse files Browse the repository at this point in the history
The old term is no longer that descriptive, as it came from older
functionality. This also allows us to distinguish between the backend's
treatment of the argument and the way Vespa needs it. Finally, it's also
just simpler!
  • Loading branch information
olaughter committed Mar 27, 2024
1 parent 38705c4 commit ee6718b
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ Matching documents can also be filtered by keyword field, and by publication dat
```python
request = SearchParameters(
query_string="forest fires",
keyword_filters={
filters={
"language": ["English", "French"],
"category": ["Executive"],
},
Expand Down
4 changes: 2 additions & 2 deletions src/cpr_data_access/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
ID_PATTERN = re.compile(rf"{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}\.{_ID_ELEMENT}")


class KeywordFilters(BaseModel):
class Filters(BaseModel):
"""Filterable fields in a search request"""

family_geography: Sequence[str] = []
Expand Down Expand Up @@ -82,7 +82,7 @@ class SearchParameters(BaseModel):
family_ids: Optional[Sequence[str]] = None
document_ids: Optional[Sequence[str]] = None

keyword_filters: Optional[KeywordFilters] = None
filters: Optional[Filters] = None
year_range: Optional[tuple[Optional[int], Optional[int]]] = None

sort_by: Optional[str] = Field(
Expand Down
30 changes: 13 additions & 17 deletions src/cpr_data_access/yql_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from string import Template
from typing import Optional

from cpr_data_access.models.search import KeywordFilters, SearchParameters
from cpr_data_access.models.search import Filters, SearchParameters


class YQLBuilder:
Expand Down Expand Up @@ -91,15 +91,13 @@ def build_document_filter(self) -> Optional[str]:
return f"(document_import_id in({documents}))"
return None

def _inclusive_keyword_filters(
self, keyword_filters: KeywordFilters, field_name: str
):
values = getattr(keyword_filters, field_name)
filters = []
def _inclusive_filters(self, filters: Filters, field_name: str):
values = getattr(filters, field_name)
query_filters = []
for value in values:
filters.append(f'({field_name} contains "{value}")')
if filters:
return f"({' or '.join(filters)})"
query_filters.append(f'({field_name} contains "{value}")')
if query_filters:
return f"({' or '.join(query_filters)})"

def build_year_start_filter(self) -> Optional[str]:
"""Create the part of the query that filters on a year range"""
Expand All @@ -123,11 +121,11 @@ def build_where_clause(self) -> str:
filters.append(self.build_search_term())
filters.append(self.build_family_filter())
filters.append(self.build_document_filter())
if kf := self.params.keyword_filters:
filters.append(self._inclusive_keyword_filters(kf, "family_geography"))
filters.append(self._inclusive_keyword_filters(kf, "family_category"))
filters.append(self._inclusive_keyword_filters(kf, "document_languages"))
filters.append(self._inclusive_keyword_filters(kf, "family_source"))
if f := self.params.filters:
filters.append(self._inclusive_filters(f, "family_geography"))
filters.append(self._inclusive_filters(f, "family_category"))
filters.append(self._inclusive_filters(f, "document_languages"))
filters.append(self._inclusive_filters(f, "family_source"))
filters.append(self.build_year_start_filter())
filters.append(self.build_year_end_filter())
return " and ".join([f for f in filters if f]) # Remove empty
Expand Down Expand Up @@ -176,9 +174,7 @@ def to_str(self) -> str:
exact_match=False,
limit=10,
max_hits_per_family=10,
keyword_filters=KeywordFilters(
**{"document_languages": "value", "family_source": "value"}
),
filters=Filters(**{"document_languages": "value", "family_source": "value"}),
year_range=(2000, 2020),
continuation_tokens=None,
)
Expand Down
14 changes: 7 additions & 7 deletions tests/test_search_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from pydantic import ValidationError

from cpr_data_access.models.search import (
KeywordFilters,
Filters,
SearchParameters,
sort_orders,
sort_fields,
sort_orders,
)
from cpr_data_access.vespa import build_vespa_request_body
from cpr_data_access.exceptions import QueryError
Expand Down Expand Up @@ -189,16 +189,16 @@ def test_computed_vespa_sort_fields(sort_by, sort_order):
["family_geography", "family_category", "document_languages", "family_source"],
)
def test_whether_valid_filter_fields_are_accepted(field):
keyword_filters = KeywordFilters(**{field: ["value"]})
params = SearchParameters(query_string="test", keyword_filters=keyword_filters)
filters = Filters(**{field: ["value"]})
params = SearchParameters(query_string="test", filters=filters)
assert isinstance(params, SearchParameters)


def test_whether_an_invalid_filter_fields_raises_a_valueerror():
with pytest.raises(ValidationError) as excinfo:
SearchParameters(
query_string="test",
keyword_filters=KeywordFilters(**{"invalid_field": ["value"]}),
filters=Filters(**{"invalid_field": ["value"]}),
)
assert "Extra inputs are not permitted" in str(excinfo.value)

Expand All @@ -222,9 +222,9 @@ def test_whether_an_invalid_filter_fields_value_fixes_it_silently(
):
params = SearchParameters(
query_string="test",
keyword_filters=KeywordFilters(**{"family_source": input_filters}),
filters=Filters(**{"family_source": input_filters}),
)
assert params.keyword_filters.family_source == expected
assert params.filters.family_source == expected


@pytest.mark.parametrize(
Expand Down
12 changes: 6 additions & 6 deletions tests/test_yql_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from vespa.exceptions import VespaError

from cpr_data_access.models.search import (
KeywordFilters,
Filters,
SearchParameters,
sort_fields,
sort_orders,
Expand All @@ -12,20 +12,20 @@


def test_whether_single_filter_values_and_lists_of_filter_values_appear_in_yql():
keyword_filters = {
filters = {
"family_geography": ["SWE"],
"family_category": ["Executive"],
"document_languages": ["English", "Swedish"],
"family_source": ["CCLW"],
}
params = SearchParameters(
query_string="test",
keyword_filters=KeywordFilters(**keyword_filters),
filters=Filters(**filters),
)
yql = YQLBuilder(params).to_str()
assert isinstance(params.keyword_filters, KeywordFilters)
assert isinstance(params.filters, Filters)

for key, values in keyword_filters.items():
for key, values in filters.items():
for value in values:
assert key in yql
assert value in yql
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_yql_builder_build_where_clause():
assert query_string not in where_clause

params = SearchParameters(
query_string="climate", keyword_filters={"family_geography": ["SWE"]}
query_string="climate", filters={"family_geography": ["SWE"]}
)
where_clause = YQLBuilder(params).build_where_clause()
assert "SWE" in where_clause
Expand Down

0 comments on commit ee6718b

Please sign in to comment.