Skip to content

Commit a24f05e

Browse files
committed
WIP
1 parent 5061ac3 commit a24f05e

File tree

5 files changed

+40
-31
lines changed

5 files changed

+40
-31
lines changed

app/api/api_v1/routers/family.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]:
9292

9393
query_params = set_default_query_params(query_params)
9494

95-
VALID_PARAMS = ["q", "title", "summary", "geography", "status", "max_results"]
95+
VALID_PARAMS = ["q", "title", "summary", "geographies", "status", "max_results"]
9696
validate_query_params(query_params, VALID_PARAMS)
9797

9898
try:

app/model/family.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import Optional, Union
2+
from typing import Optional
33

44
from pydantic import BaseModel
55

@@ -12,7 +12,7 @@ class FamilyReadDTO(BaseModel):
1212
import_id: str
1313
title: str
1414
summary: str
15-
geography: str
15+
geographies: list[str]
1616
category: str
1717
status: str
1818
metadata: Json
@@ -41,7 +41,7 @@ class FamilyWriteDTO(BaseModel):
4141

4242
title: str
4343
summary: str
44-
geography: str
44+
geographies: list[str]
4545
category: str
4646
metadata: Json
4747
collections: list[str]
@@ -60,7 +60,7 @@ class FamilyCreateDTO(BaseModel):
6060
import_id: Optional[str] = None
6161
title: str
6262
summary: str
63-
geography: Union[str, list[str]]
63+
geographies: list[str]
6464
category: str
6565
metadata: Json
6666
collections: list[str]

app/repository/family.py

+12-22
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from db_client.models.organisation.users import Organisation
2323
from sqlalchemy import Column, and_
2424
from sqlalchemy import delete as db_delete
25-
from sqlalchemy import desc, func, or_
25+
from sqlalchemy import desc, or_
2626
from sqlalchemy import update as db_update
2727
from sqlalchemy.exc import NoResultFound, OperationalError
2828
from sqlalchemy.orm import Query, Session
@@ -34,30 +34,24 @@
3434

3535
_LOGGER = logging.getLogger(__name__)
3636

37-
FamilyGeoMetaOrg = Tuple[Family, str, FamilyMetadata, Corpus, Organisation]
37+
FamilyGeoMetaOrg = Tuple[Family, Geography, FamilyMetadata, Corpus, Organisation]
3838

3939

4040
def _get_query(db: Session) -> Query:
4141
# NOTE: SqlAlchemy will make a complete hash of query generation
4242
# if columns are used in the query() call. Therefore, entire
4343
# objects are returned.
44-
geo_subquery = (
45-
db.query(
46-
func.min(Geography.value).label("value"),
47-
FamilyGeography.family_import_id,
48-
)
49-
.join(FamilyGeography, FamilyGeography.geography_id == Geography.id)
50-
.filter(FamilyGeography.family_import_id == Family.import_id)
51-
.group_by(Geography.value, FamilyGeography.family_import_id)
52-
).subquery("geo_subquery")
53-
5444
return (
55-
db.query(Family, geo_subquery.c.value, FamilyMetadata, Corpus, Organisation) # type: ignore
45+
db.query(Family, Geography, FamilyMetadata, Corpus, Organisation) # type: ignore
46+
.join(FamilyGeography, FamilyGeography.family_import_id == Family.import_id)
47+
.join(
48+
Geography,
49+
Geography.id == FamilyGeography.geography_id,
50+
)
5651
.join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id)
5752
.join(FamilyCorpus, FamilyCorpus.family_import_id == Family.import_id)
5853
.join(Corpus, Corpus.import_id == FamilyCorpus.corpus_import_id)
5954
.join(Organisation, Corpus.organisation_id == Organisation.id)
60-
.filter(geo_subquery.c.family_import_id == Family.import_id) # type: ignore
6155
)
6256

6357

@@ -72,7 +66,7 @@ def _family_to_dto(
7266
import_id=str(fam.import_id),
7367
title=str(fam.title),
7468
summary=str(fam.description),
75-
geography=geo_value,
69+
geographies=[str(geo_value.display_value)],
7670
category=str(fam.family_category),
7771
status=str(fam.family_status),
7872
metadata=metadata,
@@ -201,13 +195,9 @@ def search(
201195
term = f"%{escape_like(search_params['summary'])}%"
202196
search.append(Family.description.ilike(term))
203197

204-
if "geography" in search_params.keys():
205-
term = cast(str, search_params["geography"])
206-
search.append(
207-
or_(
208-
Geography.display_value == term.title(), Geography.value == term.upper()
209-
)
210-
)
198+
if "geographies" in search_params.keys():
199+
term = cast(str, search_params["geographies"])
200+
search.append(Geography.display_value == term.title())
211201

212202
if "status" in search_params.keys():
213203
term = cast(str, search_params["status"])

tests/integration_tests/family/test_search.py

+19
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,25 @@
77
from tests.integration_tests.setup_db import setup_db
88

99

10+
def test_search_geographies(
11+
client: TestClient, data_db: Session, superuser_header_token
12+
):
13+
setup_db(data_db)
14+
response = client.get(
15+
"/api/v1/families/?geographies=zimbabwe",
16+
headers=superuser_header_token,
17+
)
18+
assert response.status_code == status.HTTP_200_OK
19+
data = response.json()
20+
assert isinstance(data, list)
21+
22+
ids_found = set([f["import_id"] for f in data])
23+
assert len(ids_found) == 1
24+
25+
# expected_ids = set(["A.0.0.2", "A.0.0.3"])
26+
# assert ids_found.symmetric_difference(expected_ids) == set([])
27+
28+
1029
def test_search_family_super(
1130
client: TestClient, data_db: Session, superuser_header_token
1231
):

tests/integration_tests/setup_db.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
"import_id": "A.0.0.1",
3232
"title": "apple",
3333
"summary": "",
34-
"geography": "Other",
34+
"geographies": ["AFG"],
3535
"category": "UNFCCC",
3636
"status": "Created",
3737
"metadata": {
@@ -57,7 +57,7 @@
5757
"import_id": "A.0.0.2",
5858
"title": "apple orange banana",
5959
"summary": "apple",
60-
"geography": "Other",
60+
"geographies": ["ZWE"],
6161
"category": "UNFCCC",
6262
"status": "Created",
6363
"metadata": {
@@ -83,7 +83,7 @@
8383
"import_id": "A.0.0.3",
8484
"title": "title",
8585
"summary": "orange peas",
86-
"geography": "Other",
86+
"geographies": ["AFG"],
8787
"category": "UNFCCC",
8888
"status": "Created",
8989
"metadata": {"author": "CPR", "author_type": "Party"},
@@ -490,7 +490,7 @@ def _setup_family_data(
490490

491491
geo_id = (
492492
test_db.query(Geography.id)
493-
.filter(Geography.value == data["geography"])
493+
.filter(Geography.value == data["geographies"][0])
494494
.scalar()
495495
)
496496

0 commit comments

Comments
 (0)