Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Geography filter on admin service does not work #256

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]:

query_params = set_default_query_params(query_params)

VALID_PARAMS = ["q", "title", "summary", "geography", "status", "max_results"]
VALID_PARAMS = ["q", "title", "summary", "geographies", "status", "max_results"]
validate_query_params(query_params, VALID_PARAMS)

try:
Expand Down
8 changes: 4 additions & 4 deletions app/model/family.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Optional, Union
from typing import Optional

from pydantic import BaseModel

Expand All @@ -12,7 +12,7 @@ class FamilyReadDTO(BaseModel):
import_id: str
title: str
summary: str
geography: str
geographies: list[str]
category: str
status: str
metadata: Json
Expand Down Expand Up @@ -41,7 +41,7 @@ class FamilyWriteDTO(BaseModel):

title: str
summary: str
geography: str
geographies: list[str]
category: str
metadata: Json
collections: list[str]
Expand All @@ -60,7 +60,7 @@ class FamilyCreateDTO(BaseModel):
import_id: Optional[str] = None
title: str
summary: str
geography: Union[str, list[str]]
geographies: list[str]
category: str
metadata: Json
collections: list[str]
Expand Down
106 changes: 65 additions & 41 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from db_client.models.organisation.users import Organisation
from sqlalchemy import Column, and_
from sqlalchemy import delete as db_delete
from sqlalchemy import desc, func, or_
from sqlalchemy import desc, or_
from sqlalchemy import update as db_update
from sqlalchemy.exc import NoResultFound, OperationalError
from sqlalchemy.orm import Query, Session
Expand All @@ -34,30 +34,24 @@

_LOGGER = logging.getLogger(__name__)

FamilyGeoMetaOrg = Tuple[Family, str, FamilyMetadata, Corpus, Organisation]
FamilyGeoMetaOrg = Tuple[Family, Geography, FamilyMetadata, Corpus, Organisation]


def _get_query(db: Session) -> Query:
# NOTE: SqlAlchemy will make a complete hash of query generation
# if columns are used in the query() call. Therefore, entire
# objects are returned.
geo_subquery = (
db.query(
func.min(Geography.value).label("value"),
FamilyGeography.family_import_id,
)
.join(FamilyGeography, FamilyGeography.geography_id == Geography.id)
.filter(FamilyGeography.family_import_id == Family.import_id)
.group_by(Geography.value, FamilyGeography.family_import_id)
).subquery("geo_subquery")

return (
db.query(Family, geo_subquery.c.value, FamilyMetadata, Corpus, Organisation) # type: ignore
db.query(Family, Geography, FamilyMetadata, Corpus, Organisation) # type: ignore
.join(FamilyGeography, FamilyGeography.family_import_id == Family.import_id)
.join(
Geography,
Geography.id == FamilyGeography.geography_id,
)
.join(FamilyMetadata, FamilyMetadata.family_import_id == Family.import_id)
.join(FamilyCorpus, FamilyCorpus.family_import_id == Family.import_id)
.join(Corpus, Corpus.import_id == FamilyCorpus.corpus_import_id)
.join(Organisation, Corpus.organisation_id == Organisation.id)
.filter(geo_subquery.c.family_import_id == Family.import_id) # type: ignore
)


Expand All @@ -72,7 +66,7 @@ def _family_to_dto(
import_id=str(fam.import_id),
title=str(fam.title),
summary=str(fam.description),
geography=geo_value,
geographies=[str(geo_value.display_value)],
category=str(fam.family_category),
status=str(fam.family_status),
metadata=metadata,
Expand Down Expand Up @@ -100,7 +94,6 @@ def _update_intention(
db: Session,
import_id: str,
family: FamilyWriteDTO,
geo_id: int,
original_family: Family,
):
original_collections = [
Expand All @@ -111,17 +104,17 @@ def _update_intention(
]
update_collections = set(original_collections) != set(family.collections)
update_title = cast(str, original_family.title) != family.title
# TODO: PDCT-1406: Properly implement multi-geography support
update_geo = (
db.query(FamilyGeography)
.filter(FamilyGeography.family_import_id == import_id)
.one()
.geography_id
!= geo_id
)
original_geographies = [
geography.collection_import_id
for geography in db.query(Geography).filter(
original_family.import_id == Geography.family_import_id
)
]
update_geographies = set(original_geographies) != set(family.geographies)

update_basics = (
update_title
or update_geo
or update_geographies
or original_family.description != family.summary
or original_family.family_category != family.category
)
Expand All @@ -131,7 +124,13 @@ def _update_intention(
.one()
)
update_metadata = existing_metadata.value != family.metadata
return update_title, update_basics, update_metadata, update_collections
return (
update_title,
update_basics,
update_metadata,
update_collections,
update_geographies,
)


def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]:
Expand Down Expand Up @@ -201,13 +200,9 @@ def search(
term = f"%{escape_like(search_params['summary'])}%"
search.append(Family.description.ilike(term))

if "geography" in search_params.keys():
term = cast(str, search_params["geography"])
search.append(
or_(
Geography.display_value == term.title(), Geography.value == term.upper()
)
)
if "geographies" in search_params.keys():
term = cast(str, search_params["geographies"])
search.append(Geography.display_value == term.title())

if "status" in search_params.keys():
term = cast(str, search_params["status"])
Expand All @@ -231,7 +226,7 @@ def search(
return [_family_to_dto(db, f) for f in found]


def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) -> bool:
def update(db: Session, import_id: str, family: FamilyWriteDTO) -> bool:
"""
Updates a single entry with the new values passed.

Expand All @@ -257,7 +252,8 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
update_basics,
update_metadata,
update_collections,
) = _update_intention(db, import_id, family, geo_id, original_family)
update_geographies,
) = _update_intention(db, import_id, family, original_family)

# Return if nothing to do
if not (update_title or update_basics or update_metadata or update_collections):
Expand All @@ -276,12 +272,6 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
)
)
updates = result.rowcount # type: ignore
# TODO: PDCT-1406: Properly implement multi-geography support
result = db.execute(
db_update(FamilyGeography)
.where(FamilyGeography.family_import_id == import_id)
.values(geography_id=geo_id)
)

updates += result.rowcount # type: ignore
if updates == 0: # type: ignore
Expand Down Expand Up @@ -316,6 +306,40 @@ def update(db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int) ->
db.add(new_slug)
_LOGGER.info(f"Added a new slug for {import_id} of {new_slug.name}")

# update geographies if geographies changed
if update_geographies:
original_geographies = set(
[
geography.geography_id
for geography in db.query(FamilyGeography).filter(
original_family.import_id == FamilyGeography.family_import_id
)
]
)

# Remove any collections that were originally associated with the family but
# now aren't.
geographies_to_remove = set(original_geographies) - set(family.geographies)
for geography in geographies_to_remove:
result = db.execute(
db_delete(Geography).where(FamilyGeography.geography_id == geography)
)

if result.rowcount == 0: # type: ignore
msg = f"Could not remove family {import_id} from collection {geography}"
_LOGGER.error(msg)
raise RepositoryError(msg)

# Add any collections that weren't originally associated with the family.
geographies_to_add = set(family.geographies) - set(original_geographies)
for geography in geographies_to_add:
db.flush()
new_geography = FamilyGeography(
family_import_id=import_id,
geography_id=geography,
)
db.add(new_geography)

# Update collections if collections changed.
if update_collections:
original_collections = set(
Expand Down
4 changes: 4 additions & 0 deletions app/repository/geography.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@

def get_id_from_value(db: Session, geo_string: str) -> Optional[int]:
return db.query(Geography.id).filter_by(value=geo_string).scalar()


def get_ids_from_values(db: Session, geo_strings: list[str]) -> Optional[list[int]]:
return db.query(Geography.id).filter(Geography.value.in_(geo_strings)).scalar()
4 changes: 1 addition & 3 deletions app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def search(
...

@staticmethod
def update(
db: Session, import_id: str, family: FamilyWriteDTO, geo_id: int
) -> bool:
def update(db: Session, import_id: str, family: FamilyWriteDTO) -> bool:
"""Updates a family"""
...

Expand Down
12 changes: 6 additions & 6 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def update(
db = db_session.get_db()

# Validate geography
geo_id = geography.get_id(db, family_dto.geography)
# geo_ids = geography_repo.get_ids_from_values(db, family_dto.geographies)

# Validate family belongs to same org as current user.
entity_org_id: int = corpus.get_corpus_org_id(family.corpus_import_id, db)
Expand All @@ -153,7 +153,7 @@ def update(
raise ValidationError(msg)

try:
if family_repo.update(db, import_id, family_dto, geo_id):
if family_repo.update(db, import_id, family_dto):
db.commit()
else:
db.rollback()
Expand Down Expand Up @@ -185,10 +185,10 @@ def create(

# Validate geographies
geo_ids = []
if isinstance(family.geography, str):
geo_ids.append(geography.get_id(db, family.geography))
elif isinstance(family.geography, list):
for geo_id in family.geography:
if isinstance(family.geographies, str):
geo_ids.append(geography.get_id(db, family.geographies))
elif isinstance(family.geographies, list):
for geo_id in family.geographies:
geo_ids.append(geography.get_id(db, geo_id))

# Validate category
Expand Down
2 changes: 1 addition & 1 deletion app/service/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def save_families(
**fam, corpus_import_id=corpus_import_id
).to_family_create_dto(corpus_import_id)
geo_ids = []
for geo in dto.geography:
for geo in dto.geographies:
geo_ids.append(geography.get_id(db, geo))
import_id = family_repository.create(db, dto, geo_ids, org_id)
family_import_ids.append(import_id)
Expand Down
19 changes: 19 additions & 0 deletions tests/integration_tests/family/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@
from tests.integration_tests.setup_db import setup_db


def test_search_geographies(
client: TestClient, data_db: Session, superuser_header_token
):
setup_db(data_db)
response = client.get(
"/api/v1/families/?geographies=zimbabwe",
headers=superuser_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)

ids_found = set([f["import_id"] for f in data])
assert len(ids_found) == 1

# expected_ids = set(["A.0.0.2", "A.0.0.3"])
# assert ids_found.symmetric_difference(expected_ids) == set([])


def test_search_family_super(
client: TestClient, data_db: Session, superuser_header_token
):
Expand Down
8 changes: 4 additions & 4 deletions tests/integration_tests/setup_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"import_id": "A.0.0.1",
"title": "apple",
"summary": "",
"geography": "Other",
"geographies": ["AFG"],
"category": "UNFCCC",
"status": "Created",
"metadata": {
Expand All @@ -57,7 +57,7 @@
"import_id": "A.0.0.2",
"title": "apple orange banana",
"summary": "apple",
"geography": "Other",
"geographies": ["ZWE"],
"category": "UNFCCC",
"status": "Created",
"metadata": {
Expand All @@ -83,7 +83,7 @@
"import_id": "A.0.0.3",
"title": "title",
"summary": "orange peas",
"geography": "Other",
"geographies": ["AFG"],
"category": "UNFCCC",
"status": "Created",
"metadata": {"author": "CPR", "author_type": "Party"},
Expand Down Expand Up @@ -490,7 +490,7 @@ def _setup_family_data(

geo_id = (
test_db.query(Geography.id)
.filter(Geography.value == data["geography"])
.filter(Geography.value == data["geographies"][0])
.scalar()
)

Expand Down
Loading