diff --git a/app/api/api_v1/routers/geographies.py b/app/api/api_v1/routers/geographies.py deleted file mode 100644 index 6eff0cda..00000000 --- a/app/api/api_v1/routers/geographies.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -from fastapi import APIRouter, Depends, HTTPException, status - -from app.clients.db.session import get_db -from app.errors import RepositoryError -from app.models.geography import GeographyStatsDTO -from app.repository.geography import get_world_map_stats - -_LOGGER = logging.getLogger(__file__) - -geographies_router = APIRouter() - - -@geographies_router.get("/geographies", response_model=list[GeographyStatsDTO]) -async def geographies(db=Depends(get_db)): - """Get a summary of fam stats for all geographies for world map.""" - _LOGGER.info("Getting detailed information on all geographies") - - try: - world_map_stats = get_world_map_stats(db) - - if world_map_stats == []: - _LOGGER.error("No stats for world map found") - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="No stats for world map found", - ) - - return world_map_stats - except RepositoryError as e: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message - ) diff --git a/app/api/api_v1/routers/world_map.py b/app/api/api_v1/routers/world_map.py new file mode 100644 index 00000000..78e7fa4c --- /dev/null +++ b/app/api/api_v1/routers/world_map.py @@ -0,0 +1,51 @@ +import logging +from typing import Annotated + +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status + +from app.clients.db.session import get_db +from app.errors import RepositoryError, ValidationError +from app.models.geography import GeographyStatsDTO +from app.service.custom_app import AppTokenFactory +from app.service.world_map import get_world_map_stats + +_LOGGER = logging.getLogger(__file__) + +world_map_router = APIRouter() + + +@world_map_router.get("/geographies", response_model=list[GeographyStatsDTO]) +async def world_map_stats( + request: Request, app_token: Annotated[str, Header()], db=Depends(get_db) +): + """Get a summary of family counts for all geographies for world map.""" + _LOGGER.info( + "Getting world map counts for all geographies", + extra={ + "props": {"app_token": str(app_token)}, + }, + ) + + # Decode the app token and validate it. + token = AppTokenFactory() + token.decode_and_validate(db, request, app_token) + + try: + world_map_stats = get_world_map_stats(db, token.allowed_corpora_ids) + + if world_map_stats == []: + _LOGGER.error("No stats for world map found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No stats for world map found", + ) + + return world_map_stats + except RepositoryError as e: + _LOGGER.error(e) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message + ) + except ValidationError as e: + _LOGGER.error(e) + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) diff --git a/app/main.py b/app/main.py index e3e1296f..09438b37 100644 --- a/app/main.py +++ b/app/main.py @@ -16,11 +16,11 @@ from app.api.api_v1.routers.admin import admin_document_router from app.api.api_v1.routers.auth import auth_router from app.api.api_v1.routers.documents import documents_router -from app.api.api_v1.routers.geographies import geographies_router from app.api.api_v1.routers.lookups import lookups_router from app.api.api_v1.routers.pipeline_trigger import pipeline_trigger_router from app.api.api_v1.routers.search import search_router from app.api.api_v1.routers.summaries import summary_router +from app.api.api_v1.routers.world_map import world_map_router from app.clients.db.session import SessionLocal, engine from app.service.auth import get_superuser_details from app.service.health import is_database_online @@ -158,7 +158,7 @@ async def root(): summary_router, prefix="/api/v1", tags=["Summaries"], include_in_schema=False ) app.include_router( - geographies_router, prefix="/api/v1", tags=["Geographies"], include_in_schema=False + world_map_router, prefix="/api/v1", tags=["Geographies"], include_in_schema=False ) # add pagination support to all routes that ask for it diff --git a/app/repository/download.py b/app/repository/download.py index 33592cc0..85dd1c47 100644 --- a/app/repository/download.py +++ b/app/repository/download.py @@ -20,7 +20,7 @@ def get_whole_database_dump( """Get whole database dump and bind variables. :param str ingest_cycle_start: The current ingest cycle date. - :param list[str] corpora_ids: The corpora from which we + :param list[str] allowed_corpora_ids: The corpora from which we should allow the data to be dumped. :return pd.DataFrame: A DataFrame containing the results of the SQL query that gets the whole database dump in our desired format. diff --git a/app/repository/geography.py b/app/repository/geography.py index e4b23a37..5b8c10b6 100644 --- a/app/repository/geography.py +++ b/app/repository/geography.py @@ -1,21 +1,18 @@ """Functions to support the geographies endpoint.""" import logging +import os from typing import Optional, Sequence -from db_client.models.dfce.family import ( - Family, - FamilyDocument, - FamilyGeography, - FamilyStatus, -) +from db_client.models.dfce.family import Family, FamilyDocument, FamilyGeography from db_client.models.dfce.geography import Geography -from sqlalchemy import func -from sqlalchemy.exc import OperationalError +from sqlalchemy import bindparam, text from sqlalchemy.orm import Query, Session +from sqlalchemy.types import ARRAY, String -from app.errors import RepositoryError +from app.errors import ValidationError from app.models.geography import GeographyStatsDTO +from app.repository.helpers import get_query_template _LOGGER = logging.getLogger(__file__) @@ -63,74 +60,32 @@ def get_geo_subquery( return geo_subquery.subquery("geo_subquery") -def _db_count_fams_in_category_and_geo(db: Session) -> Query: - """ - Query the database for the fam count per category per geo. - - NOTE: SqlAlchemy will make a complete hash of query generation if - columns are used in the query() call. Therefore, entire objects are - returned. +def count_families_per_category_in_each_geo( + db: Session, allowed_corpora: list[str] +) -> list[GeographyStatsDTO]: + """Query the database for the family count per category per geo. :param Session db: DB Session to perform query on. - :return Query: A Query object containing the queries to perform. + :param list[str] allowed_corpora: The list of allowed corpora IDs to + filter on. + :return list[GeographyStatsDTO]: A list of counts of families by + category per geography. """ - # Get the required Geography information and cross join each with all of the unique - # family_category values (so if some geographies have no documents for a particular - # family_category, we can set the count for that category to 0). - family_categories = db.query(Family.family_category).distinct().subquery() - geo_family_combinations = db.query( - Geography.id.label("geography_id"), - Geography.display_value, - Geography.slug, - Geography.value, - family_categories.c.family_category, - ).subquery("geo_family_combinations") - - # Get a count of documents in each present family_category for each geography. - counts = ( - db.query( - Family.family_category, - FamilyGeography.geography_id, - func.count().label("records_count"), - ) - .join(FamilyGeography, Family.import_id == FamilyGeography.family_import_id) - .filter(Family.family_status == FamilyStatus.PUBLISHED) - .group_by(Family.family_category, FamilyGeography.geography_id) - .subquery("counts") - ) + if allowed_corpora in [None, []]: + raise ValidationError("No allowed corpora provided") - # Aggregate family_category counts per geography into a JSONB object, and if a - # family_category count is missing, set the count for that category to 0 so each - # geography will always have a count for all family_category values. - query = ( - db.query( - geo_family_combinations.c.display_value.label("display_value"), - geo_family_combinations.c.slug.label("slug"), - geo_family_combinations.c.value.label("value"), - func.jsonb_object_agg( - geo_family_combinations.c.family_category, - func.coalesce(counts.c.records_count, 0), - ).label("counts"), - ) - .select_from( - geo_family_combinations.join( - counts, - (geo_family_combinations.c.geography_id == counts.c.geography_id) - & ( - geo_family_combinations.c.family_category - == counts.c.family_category - ), - isouter=True, - ) - ) - .group_by( - geo_family_combinations.c.display_value, - geo_family_combinations.c.slug, - geo_family_combinations.c.value, - ) - .order_by(geo_family_combinations.c.display_value) + query_template = text( + get_query_template(os.path.join("app", "repository", "sql", "world_map.sql")) ) - return query + query_template = query_template.bindparams( + bindparam("allowed_corpora_ids", value=allowed_corpora, type_=ARRAY(String)), + ) + + family_geo_stats = db.execute( + query_template, {"allowed_corpora_ids": allowed_corpora} + ).all() + results = [_to_dto(fgs) for fgs in family_geo_stats] + return results def _to_dto(family_doc_geo_stats) -> GeographyStatsDTO: @@ -148,23 +103,3 @@ def _to_dto(family_doc_geo_stats) -> GeographyStatsDTO: slug=family_doc_geo_stats.slug, family_counts=family_doc_geo_stats.counts, ) - - -def get_world_map_stats(db: Session) -> list[GeographyStatsDTO]: - """ - Get a count of fam per category per geography for all geographies. - - :param db Session: The database session. - :return list[GeographyStatsDTO]: A list of Geography stats objects - """ - try: - family_geo_stats = _db_count_fams_in_category_and_geo(db).all() - except OperationalError as e: - _LOGGER.error(e) - raise RepositoryError("Error querying the database for geography stats") - - if not family_geo_stats: - return [] - - result = [_to_dto(fgs) for fgs in family_geo_stats] - return result diff --git a/app/repository/sql/download.sql b/app/repository/sql/download.sql index 15bbf8ac..c5ab3ab4 100644 --- a/app/repository/sql/download.sql +++ b/app/repository/sql/download.sql @@ -227,7 +227,7 @@ SELECT n3.event_type_names AS "Full timeline of events (types)", n3.event_dates AS "Full timeline of events (dates)", d.created::DATE AS "Date Added to System", - f.last_modified::DATE AS "Last ModIFied on System", + f.last_modified::DATE AS "Last Modified on System", d.import_id AS "Internal Document ID", f.import_id AS "Internal Family ID", n1.collection_import_ids AS "Internal Collection ID(s)", @@ -237,6 +237,7 @@ SELECT type,0}') AS "Document Type", CASE WHEN f.family_category = 'UNFCCC' THEN 'UNFCCC' + WHEN f.family_category = 'MCF' THEN 'MCF' ELSE INITCAP(f.family_category::TEXT) END AS "Category", ARRAY_TO_STRING( diff --git a/app/repository/sql/world_map.sql b/app/repository/sql/world_map.sql new file mode 100644 index 00000000..42db7a4d --- /dev/null +++ b/app/repository/sql/world_map.sql @@ -0,0 +1,93 @@ +WITH counts AS ( + SELECT + family.family_category, + family_geography.geography_id, + COUNT(*) AS records_count + FROM + family + INNER JOIN + family_corpus + ON family.import_id = family_corpus.family_import_id + INNER JOIN corpus ON family_corpus.corpus_import_id = corpus.import_id + INNER JOIN + family_geography + ON family.import_id = family_geography.family_import_id + WHERE + family_corpus.corpus_import_id = ANY(:allowed_corpora_ids) + AND CASE + WHEN ( + NOT ( + EXISTS ( + SELECT + 1 + FROM + family_document + WHERE + family.import_id = family_document.family_import_id + ) + ) + ) THEN 'Created' + WHEN ( + ( + SELECT + COUNT(family_document.document_status) AS count_1 + FROM + family_document + WHERE + family_document.family_import_id = family.import_id + AND family_document.document_status = 'PUBLISHED' + ) > 0 + ) THEN 'Published' + WHEN ( + ( + SELECT + COUNT(family_document.document_status) AS count_2 + FROM + family_document + WHERE + family_document.family_import_id = family.import_id + AND family_document.document_status = 'CREATED' + ) > 0 + ) THEN 'Created' + ELSE 'Deleted' + END = 'Published' + GROUP BY + family.family_category, + family_geography.geography_id + ) + +SELECT + geo_family_combinations.display_value, + geo_family_combinations.slug, + geo_family_combinations.value, + JSONB_OBJECT_AGG( + geo_family_combinations.family_category, + COALESCE(counts.records_count, 0) + ) AS counts +FROM + ( + SELECT + geography.id AS geography_id, + geography.display_value, + geography.slug, + geography.value, + anon_1.family_category + FROM + geography, + ( + SELECT DISTINCT + family.family_category + FROM + family + ) AS anon_1 + ) AS geo_family_combinations + LEFT OUTER JOIN + counts + ON geo_family_combinations.geography_id = counts.geography_id + AND geo_family_combinations.family_category = counts.family_category +GROUP BY + geo_family_combinations.display_value, + geo_family_combinations.slug, + geo_family_combinations.value +ORDER BY + geo_family_combinations.display_value; diff --git a/app/service/world_map.py b/app/service/world_map.py new file mode 100644 index 00000000..dc8735ce --- /dev/null +++ b/app/service/world_map.py @@ -0,0 +1,38 @@ +"""Functions to support the geographies endpoint.""" + +import logging +from typing import Optional + +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import Session + +from app.errors import RepositoryError, ValidationError +from app.models.geography import GeographyStatsDTO +from app.repository.geography import count_families_per_category_in_each_geo + +_LOGGER = logging.getLogger(__file__) + + +def get_world_map_stats( + db: Session, allowed_corpora: Optional[list[str]] +) -> list[GeographyStatsDTO]: + """ + Get a count of fam per category per geography for all geographies. + + :param db Session: The database session. + :param Optional[list[str]] allowed_corpora: The list of allowed + corpora IDs to filter on. + :return list[GeographyStatsDTO]: A list of Geography stats objects + """ + if allowed_corpora is None or allowed_corpora == []: + raise ValidationError("No allowed corpora provided") + + try: + family_geo_stats = count_families_per_category_in_each_geo(db, allowed_corpora) + except OperationalError as e: + _LOGGER.error(e) + raise RepositoryError("Error querying the database for geography stats") + + if not family_geo_stats: + return [] + return family_geo_stats diff --git a/pyproject.toml b/pyproject.toml index 849d3b4b..b5593f7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "navigator_backend" -version = "1.19.15" +version = "1.19.16" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/conftest.py b/tests/conftest.py index 6d6fb4cb..a60cfa3f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -117,7 +117,7 @@ def valid_token(monkeypatch): def mock_return(_, __, ___): return True - corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0,CCLW.corpus.i00000001.n0000" + corpora_ids = "CCLW.corpus.1.0,CCLW.corpus.2.0,CCLW.corpus.i00000001.n0000,UNFCCC.corpus.i00000001.n0000" subject = "CCLW" audience = "localhost" input_str = f"{corpora_ids};{subject};{audience}" @@ -142,7 +142,7 @@ def mock_return(_, __, ___): return True corpora_ids = "UNFCCC.corpus.i00000001.n0000" - subject = "CPR" + subject = "CCLW" audience = "localhost" input_str = f"{corpora_ids};{subject};{audience}" diff --git a/tests/non_search/conftest.py b/tests/non_search/conftest.py deleted file mode 100644 index 114d68d9..00000000 --- a/tests/non_search/conftest.py +++ /dev/null @@ -1,45 +0,0 @@ -import pytest -from db_client.models.dfce import Geography -from db_client.models.organisation import Organisation - - -@pytest.fixture -def summary_geography_family_data(test_db): - geos = [ - Geography( - display_value="A place on the land", slug="a-place-on-the-land", value="XXX" - ), - Geography( - display_value="A place in the sea", slug="a-place-in-the-sea", value="YYY" - ), - ] - organisations = [Organisation(name="test org")] - - test_db.add_all(geos) - test_db.add_all(organisations) - test_db.flush() - - # Now setup the Documents/Families - - ## WORKING HERE - documents = [] - families = [] - - test_db.add_all(documents) - test_db.add_all(families) - test_db.flush() - - # Now some events - events = [] - - test_db.add_all(events) - - test_db.commit() - yield { - "db": test_db, - "docs": documents, - "families": documents, - "geos": geos, - "organisations": organisations, - "events": events, - } diff --git a/tests/non_search/routers/geographies/setup_world_map_helpers.py b/tests/non_search/routers/geographies/setup_world_map_helpers.py index 1752d0e2..b46f9f91 100644 --- a/tests/non_search/routers/geographies/setup_world_map_helpers.py +++ b/tests/non_search/routers/geographies/setup_world_map_helpers.py @@ -1,4 +1,7 @@ +from typing import Optional + from db_client.functions.dfce_helpers import add_collections, add_document, add_families +from fastapi import status from sqlalchemy.orm import Session from tests.non_search.setup_helpers import ( @@ -6,6 +9,26 @@ get_default_documents, ) +WORLD_MAP_STATS_ENDPOINT = "/api/v1/geographies" +TEST_HOST = "http://localhost:3000/" + + +def _make_world_map_lookup_request( + client, + token, + expected_status_code: int = status.HTTP_200_OK, + origin: Optional[str] = TEST_HOST, +): + headers = ( + {"app-token": token} + if origin is None + else {"app-token": token, "origin": origin} + ) + + response = client.get(f"{WORLD_MAP_STATS_ENDPOINT}", headers=headers) + assert response.status_code == expected_status_code, response.text + return response.json() + def _add_published_fams_and_docs(db: Session): # Collection @@ -13,6 +36,27 @@ def _add_published_fams_and_docs(db: Session): add_collections(db, collections=[collection1]) # Family + Document + events + document0 = { + "title": "UnfcccDocument0", + "slug": "UnfcccDocSlug0", + "md5_sum": None, + "url": "http://another_somewhere", + "content_type": None, + "import_id": "UNFCCC.executive.3.3", + "language_variant": None, + "status": "PUBLISHED", + "metadata": {"role": ["MAIN"], "type": ["Order"]}, + "languages": [], + "events": [ + { + "import_id": "UNFCCC.Event.3.0", + "title": "Published", + "date": "2019-12-25", + "type": "Passed/Approved", + "status": "OK", + } + ], + } document1, document2 = get_default_documents() document3 = { "title": "Document3", @@ -36,6 +80,20 @@ def _add_published_fams_and_docs(db: Session): ], } + unfccc_fam = { + "import_id": "UNFCCC.family.0000.0", + "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", + "title": "UnfcccFam0", + "slug": "UnfcccFamSlug0", + "description": "Summary0", + "geography_id": [2, 5], + "category": "UNFCCC", + "documents": [], + "metadata": { + "size": "small", + "color": "blue", + }, + } family0 = { "import_id": "CCLW.family.0000.0", "corpus_import_id": "CCLW.corpus.i00000001.n0000", @@ -93,11 +151,12 @@ def _add_published_fams_and_docs(db: Session): }, } + unfccc_fam["documents"] = [document0] family0["documents"] = [] family1["documents"] = [document1] family2["documents"] = [document2] family3["documents"] = [document3] - add_families(db, families=[family0, family1, family2, family3]) + add_families(db, families=[family0, family1, family2, family3, unfccc_fam]) def setup_all_docs_published_world_map(db: Session): diff --git a/tests/non_search/routers/geographies/test_world_map_summary.py b/tests/non_search/routers/geographies/test_world_map_summary.py index 92bea770..3155a94b 100644 --- a/tests/non_search/routers/geographies/test_world_map_summary.py +++ b/tests/non_search/routers/geographies/test_world_map_summary.py @@ -1,9 +1,16 @@ import pytest -from db_client.models.dfce.family import Family, FamilyGeography, FamilyStatus +from db_client.models.dfce.family import ( + Corpus, + Family, + FamilyCorpus, + FamilyGeography, + FamilyStatus, +) from db_client.models.dfce.geography import Geography from fastapi import status from tests.non_search.routers.geographies.setup_world_map_helpers import ( + _make_world_map_lookup_request, setup_all_docs_published_world_map, setup_mixed_doc_statuses_world_map, ) @@ -11,12 +18,8 @@ EXPECTED_NUM_FAM_CATEGORIES = 3 -def _get_expected_keys(): - return ["display_name", "iso_code", "slug", "family_counts"] - - -def _url_under_test() -> str: - return "/api/v1/geographies" +def _test_has_expected_keys(keys: list[str]) -> bool: + return set(["display_name", "iso_code", "slug", "family_counts"]) == set(keys) def _find_geography_index(lst, key, value): @@ -34,8 +37,8 @@ def test_geo_table_populated(data_db): @pytest.mark.parametrize( ("geo_display_value", "expected_exec", "expected_leg", "expected_unfccc"), [ - ("India", 1, 1, 1), - ("Afghanistan", 0, 0, 1), + ("India", 1, 1, 2), + ("Afghanistan", 0, 0, 2), ], ) def test_endpoint_returns_ok_all_docs_per_family_published( @@ -45,20 +48,17 @@ def test_endpoint_returns_ok_all_docs_per_family_published( expected_exec, expected_leg, expected_unfccc, + valid_token, ): - """Check endpoint returns 200 on success""" setup_all_docs_published_world_map(data_db) - response = data_client.get(_url_under_test()) - assert response.status_code == status.HTTP_200_OK - resp_json = response.json() + + resp_json = _make_world_map_lookup_request(data_client, valid_token) assert len(resp_json) > 1 idx = _find_geography_index(resp_json, "display_name", geo_display_value) resp = resp_json[idx] - assert set(["display_name", "iso_code", "slug", "family_counts"]) == set( - resp.keys() - ) + assert _test_has_expected_keys(resp.keys()) family_geos = ( data_db.query(Family) @@ -72,20 +72,20 @@ def test_endpoint_returns_ok_all_docs_per_family_published( assert len(resp["family_counts"]) == EXPECTED_NUM_FAM_CATEGORIES assert sum(resp["family_counts"].values()) == len(family_geos) + assert resp["family_counts"]["EXECUTIVE"] == expected_exec + assert resp["family_counts"]["LEGISLATIVE"] == expected_leg + assert resp["family_counts"]["UNFCCC"] == expected_unfccc assert ( sum(resp["family_counts"].values()) == expected_exec + expected_leg + expected_unfccc ) - assert resp["family_counts"]["EXECUTIVE"] == expected_exec - assert resp["family_counts"]["LEGISLATIVE"] == expected_leg - assert resp["family_counts"]["UNFCCC"] == expected_unfccc @pytest.mark.parametrize( ("geo_display_value", "expected_exec", "expected_leg", "expected_unfccc"), [ - ("India", 1, 1, 2), - ("Afghanistan", 0, 0, 1), + ("India", 1, 1, 3), + ("Afghanistan", 0, 0, 2), ], ) def test_endpoint_returns_ok_some_docs_per_published_family_unpublished( @@ -95,20 +95,18 @@ def test_endpoint_returns_ok_some_docs_per_published_family_unpublished( expected_exec, expected_leg, expected_unfccc, + valid_token, ): """Check endpoint returns 200 & discounts CREATED & DELETED docs""" setup_mixed_doc_statuses_world_map(data_db) - response = data_client.get(_url_under_test()) - assert response.status_code == status.HTTP_200_OK - resp_json = response.json() + + resp_json = _make_world_map_lookup_request(data_client, valid_token) assert len(resp_json) > 1 idx = _find_geography_index(resp_json, "display_name", geo_display_value) resp = resp_json[idx] - assert set(["display_name", "iso_code", "slug", "family_counts"]) == set( - resp.keys() - ) + assert _test_has_expected_keys(resp.keys()) fams = ( data_db.query(Family) @@ -122,27 +120,88 @@ def test_endpoint_returns_ok_some_docs_per_published_family_unpublished( assert len(resp["family_counts"]) == EXPECTED_NUM_FAM_CATEGORIES assert sum(resp["family_counts"].values()) == len(fams) + assert resp["family_counts"]["EXECUTIVE"] == expected_exec + assert resp["family_counts"]["LEGISLATIVE"] == expected_leg + assert resp["family_counts"]["UNFCCC"] == expected_unfccc assert ( sum(resp["family_counts"].values()) == expected_exec + expected_leg + expected_unfccc ) - assert resp["family_counts"]["EXECUTIVE"] == expected_exec - assert resp["family_counts"]["LEGISLATIVE"] == expected_leg - assert resp["family_counts"]["UNFCCC"] == expected_unfccc -def test_endpoint_returns_404_when_not_found(data_client): +def test_endpoint_returns_404_when_not_found(data_client, valid_token): """Test the endpoint returns a 404 when no world map stats found""" - response = data_client.get(_url_under_test()) - assert response.status_code == status.HTTP_404_NOT_FOUND - data = response.json() + data = _make_world_map_lookup_request( + data_client, valid_token, expected_status_code=status.HTTP_404_NOT_FOUND + ) assert data["detail"] == "No stats for world map found" @pytest.mark.skip(reason="Bad repo and rollback mocks need rewriting") -def test_endpoint_returns_503_when_error(data_client): +def test_endpoint_returns_503_when_error(data_client, valid_token): """Test the endpoint returns a 503 on db error""" - response = data_client.get(_url_under_test()) - assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE - data = response.json() + data = _make_world_map_lookup_request( + data_client, + valid_token, + expected_status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + ) assert data["detail"] == "Database error" + + +@pytest.mark.parametrize( + ("geo_display_value", "expected_exec", "expected_leg", "expected_unfccc"), + [ + ("India", 0, 0, 1), + ("Afghanistan", 0, 0, 1), + ], +) +def test_endpoint_returns_different_results_with_alt_token( + data_db, + data_client, + geo_display_value, + expected_exec, + expected_leg, + expected_unfccc, + alternative_token, +): + """Check endpoint returns 200 & only counts UNFCCC docs""" + setup_all_docs_published_world_map(data_db) + + fam = ( + data_db.query(Family, FamilyCorpus.corpus_import_id) + .filter(Family.import_id == "UNFCCC.family.0000.0") + .join(FamilyCorpus, Family.import_id == FamilyCorpus.family_import_id) + .one() + ) + assert fam + + resp_json = _make_world_map_lookup_request(data_client, alternative_token) + assert len(resp_json) > 1 + + idx = _find_geography_index(resp_json, "display_name", geo_display_value) + resp = resp_json[idx] + + assert _test_has_expected_keys(resp.keys()) + + fams = ( + data_db.query(Family.import_id) + .filter(Family.family_status == FamilyStatus.PUBLISHED) + .filter(Geography.display_value == geo_display_value) + .join(FamilyGeography, Family.import_id == FamilyGeography.family_import_id) + .join(FamilyCorpus, Family.import_id == FamilyCorpus.family_import_id) + .join(Corpus, Corpus.import_id == FamilyCorpus.corpus_import_id) + .join(Geography, Geography.id == FamilyGeography.geography_id) + .filter(Corpus.import_id == "UNFCCC.corpus.i00000001.n0000") + .all() + ) + + assert len(resp["family_counts"]) == EXPECTED_NUM_FAM_CATEGORIES + assert sum(resp["family_counts"].values()) == len(fams) + + assert resp["family_counts"]["EXECUTIVE"] == expected_exec + assert resp["family_counts"]["LEGISLATIVE"] == expected_leg + assert resp["family_counts"]["UNFCCC"] == expected_unfccc + assert ( + sum(resp["family_counts"].values()) + == expected_exec + expected_leg + expected_unfccc + )