From 6c9902650d906bb6e7033265c930f06e8e5e15a0 Mon Sep 17 00:00:00 2001 From: Katy Baulch <46493669+katybaulch@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:41:38 +0100 Subject: [PATCH] Pdct 1110 only show entity count per org (#138) * Update service count unit tests to only count entities in user org * Only show count of entities in user org unless superuser * Move is_superuser out of repo layer * Fix ordering of function calls to fix tests * Only call restrict_entities_to_user_org once & call repo directly * Remove incorrect param from docstring --- app/api/api_v1/routers/analytics.py | 15 ++-- app/api/api_v1/routers/family.py | 23 ++---- app/repository/collection.py | 8 ++- app/repository/document.py | 9 ++- app/repository/event.py | 9 ++- app/repository/family.py | 12 ++-- app/repository/protocols.py | 5 +- app/service/analytics.py | 37 +++++----- app/service/app_user.py | 10 +++ app/service/collection.py | 15 ---- app/service/document.py | 15 ---- app/service/event.py | 15 ---- app/service/family.py | 20 +----- tests/integration_tests/analytics/test_get.py | 50 ++++++++++++- tests/mocks/repos/app_user_repo.py | 20 ++++-- tests/mocks/repos/bad_collection_repo.py | 5 +- tests/mocks/repos/bad_document_repo.py | 5 +- tests/mocks/repos/bad_event_repo.py | 5 +- tests/mocks/repos/bad_family_repo.py | 5 +- tests/mocks/repos/collection_repo.py | 12 ++-- tests/mocks/repos/config_repo.py | 1 - tests/mocks/repos/document_repo.py | 12 ++-- tests/mocks/repos/event_repo.py | 10 ++- tests/mocks/repos/family_repo.py | 13 ++-- tests/mocks/services/analytics_service.py | 2 +- tests/mocks/services/app_user_service.py | 15 ++++ tests/mocks/services/collection_service.py | 9 --- tests/mocks/services/config_service.py | 1 - tests/mocks/services/document_service.py | 9 --- tests/mocks/services/event_service.py | 9 --- tests/mocks/services/family_service.py | 9 --- .../analytics/test_analytics_service.py | 72 ++++++++++++++++--- .../test_count_collection_service.py | 29 -------- .../document/test_count_document_service.py | 29 -------- .../service/event/test_count_event_service.py | 29 -------- .../family/test_count_family_service.py | 38 ---------- .../service/family/test_get_family_service.py | 4 -- 37 files changed, 262 insertions(+), 324 deletions(-) delete mode 100644 tests/unit_tests/service/collection/test_count_collection_service.py delete mode 100644 tests/unit_tests/service/document/test_count_document_service.py delete mode 100644 tests/unit_tests/service/event/test_count_event_service.py delete mode 100644 tests/unit_tests/service/family/test_count_family_service.py diff --git a/app/api/api_v1/routers/analytics.py b/app/api/api_v1/routers/analytics.py index fa24a7a8..f734fb0b 100644 --- a/app/api/api_v1/routers/analytics.py +++ b/app/api/api_v1/routers/analytics.py @@ -2,7 +2,7 @@ import logging -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, Request, status import app.service.analytics as analytics_service from app.errors import RepositoryError @@ -13,11 +13,8 @@ _LOGGER = logging.getLogger(__name__) -@r.get( - "/analytics/summary", - response_model=SummaryDTO, -) -async def get_analytics_summary() -> SummaryDTO: +@r.get("/analytics/summary", response_model=SummaryDTO) +async def get_analytics_summary(request: Request) -> SummaryDTO: """ Returns an analytics summary. @@ -25,16 +22,18 @@ async def get_analytics_summary() -> SummaryDTO: data in key (str): value (int) form. """ try: - summary_dto = analytics_service.summary() + summary_dto = analytics_service.summary(request.state.user.email) except RepositoryError as e: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message ) if any(summary_value is None for _, summary_value in summary_dto): + msg = "Analytics summary not found" + _LOGGER.error(msg) raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail="Analytics summary not found", + detail=msg, ) return summary_dto diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index 2644c8bc..8606086a 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -29,9 +29,7 @@ "/families/{import_id}", response_model=FamilyReadDTO, ) -async def get_family( - import_id: str, -) -> FamilyReadDTO: +async def get_family(import_id: str) -> FamilyReadDTO: """ Returns a specific family given the import id. @@ -57,10 +55,7 @@ async def get_family( return family -@r.get( - "/families", - response_model=list[FamilyReadDTO], -) +@r.get("/families", response_model=list[FamilyReadDTO]) async def get_all_families(request: Request) -> list[FamilyReadDTO]: """ Returns all families @@ -75,10 +70,7 @@ async def get_all_families(request: Request) -> list[FamilyReadDTO]: ) -@r.get( - "/families/", - response_model=list[FamilyReadDTO], -) +@r.get("/families/", response_model=list[FamilyReadDTO]) async def search_family(request: Request) -> list[FamilyReadDTO]: """ Searches for families matching URL parameters ("q" by default). @@ -122,10 +114,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]: return families -@r.put( - "/families/{import_id}", - response_model=FamilyReadDTO, -) +@r.put("/families/{import_id}", response_model=FamilyReadDTO) async def update_family( request: Request, import_id: str, @@ -179,9 +168,7 @@ async def create_family( return family -@r.delete( - "/families/{import_id}", -) +@r.delete("/families/{import_id}") async def delete_family( import_id: str, ) -> None: diff --git a/app/repository/collection.py b/app/repository/collection.py index 65817bf7..57acfa15 100644 --- a/app/repository/collection.py +++ b/app/repository/collection.py @@ -259,15 +259,19 @@ def delete(db: Session, import_id: str) -> bool: return result.rowcount > 0 # type: ignore -def count(db: Session) -> Optional[int]: +def count(db: Session, org_id: Optional[int]) -> Optional[int]: """ Counts the number of collections in the repository. :param db Session: the database connection + :param org_id Optional[int]: the ID of the organisation the user belongs to :return Optional[int]: The number of collections in the repository or none. """ try: - n_collections = _get_query(db).count() + if org_id is None: + n_collections = _get_query(db).count() + else: + n_collections = _get_query(db).filter(Organisation.id == org_id).count() except Exception as e: _LOGGER.error(e) return diff --git a/app/repository/document.py b/app/repository/document.py index 92e1453c..7cfd9e77 100644 --- a/app/repository/document.py +++ b/app/repository/document.py @@ -9,6 +9,7 @@ PhysicalDocument, PhysicalDocumentLanguage, ) +from db_client.models.organisation import Organisation from db_client.models.organisation.counters import CountedEntity from sqlalchemy import Column, and_ from sqlalchemy import delete as db_delete @@ -434,15 +435,19 @@ def delete(db: Session, import_id: str) -> bool: return True -def count(db: Session) -> Optional[int]: +def count(db: Session, org_id: Optional[int]) -> Optional[int]: """ Counts the number of documents in the repository. :param db Session: the database connection + :param org_id Optional[int]: the ID of the organisation the user belongs to :return Optional[int]: The number of documents in the repository or none. """ try: - n_documents = db.query(FamilyDocument).count() + if org_id is None: + n_documents = _get_query(db).count() + else: + n_documents = _get_query(db).filter(Organisation.id == org_id).count() except NoResultFound as e: _LOGGER.error(e) return diff --git a/app/repository/event.py b/app/repository/event.py index 7b707b50..d26c90c1 100644 --- a/app/repository/event.py +++ b/app/repository/event.py @@ -5,6 +5,7 @@ from typing import Optional, Tuple, Union, cast from db_client.models.dfce import EventStatus, Family, FamilyDocument, FamilyEvent +from db_client.models.organisation import Organisation from db_client.models.organisation.counters import CountedEntity from sqlalchemy import Column, and_ from sqlalchemy import delete as db_delete @@ -241,16 +242,20 @@ def delete(db: Session, import_id: str) -> bool: return True -def count(db: Session) -> Optional[int]: +def count(db: Session, org_id: Optional[int]) -> Optional[int]: """ Counts the number of family events in the repository. :param db Session: The database connection. + :param org_id Optional[int]: the ID of the organisation the user belongs to :return Optional[int]: The number of family events in the repository or nothing. """ try: - n_events = _get_query(db).count() + if org_id is None: + n_events = _get_query(db).count() + else: + n_events = _get_query(db).filter(Organisation.id == org_id).count() except NoResultFound as e: _LOGGER.error(e) return diff --git a/app/repository/family.py b/app/repository/family.py index 4cc03dc2..05a81687 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -116,14 +116,14 @@ def _update_intention( return update_title, update_basics, update_metadata, update_collections -def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]: +def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]: """ Returns all the families. :param db Session: the database connection :return Optional[FamilyResponse]: All of things """ - if is_superuser: + if org_id is None: family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all() else: family_geo_metas = ( @@ -499,15 +499,19 @@ def get_organisation(db: Session, family_import_id: str) -> Optional[Organisatio ) -def count(db: Session) -> Optional[int]: +def count(db: Session, org_id: Optional[int]) -> Optional[int]: """ Counts the number of families in the repository. :param db Session: the database connection + :param org_id int: the ID of the organisation the user belongs to :return Optional[int]: The number of families in the repository or none. """ try: - n_families = _get_query(db).count() + if org_id is None: + n_families = _get_query(db).count() + else: + n_families = _get_query(db).filter(Organisation.id == org_id).count() except NoResultFound as e: _LOGGER.error(e) return diff --git a/app/repository/protocols.py b/app/repository/protocols.py index 904beb6e..f6122107 100644 --- a/app/repository/protocols.py +++ b/app/repository/protocols.py @@ -12,9 +12,10 @@ class FamilyRepo(Protocol): return_empty: bool = False throw_repository_error: bool = False throw_timeout_error: bool = False + is_superuser: bool = False @staticmethod - def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]: + def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]: """Returns all the families""" ... @@ -48,6 +49,6 @@ def delete(db: Session, import_id: str) -> bool: ... @staticmethod - def count(db: Session) -> Optional[int]: + def count(db: Session, org_id: Optional[int]) -> Optional[int]: """Counts all the families""" ... diff --git a/app/service/analytics.py b/app/service/analytics.py index 7cc1c586..f9e7c1f4 100644 --- a/app/service/analytics.py +++ b/app/service/analytics.py @@ -10,10 +10,12 @@ from pydantic import ConfigDict, validate_call from sqlalchemy import exc -import app.service.collection as collection_service -import app.service.document as document_service -import app.service.event as event_service -import app.service.family as family_service +import app.clients.db.session as db_session +import app.repository.collection as collection_repo +import app.repository.document as document_repo +import app.repository.event as event_repo +import app.repository.family as family_repo +import app.service.app_user as app_user_service from app.errors import RepositoryError from app.model.analytics import SummaryDTO @@ -21,24 +23,27 @@ @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def summary() -> SummaryDTO: +def summary(user_email: str) -> SummaryDTO: """ Gets an analytics summary from the repository. + :param user_email str: The email address of the current user. :return SummaryDTO: The analytics summary found. """ try: - n_collections = collection_service.count() - n_families = family_service.count() - n_documents = document_service.count() - n_events = event_service.count() - - return SummaryDTO( - n_documents=n_documents, - n_families=n_families, - n_collections=n_collections, - n_events=n_events, - ) + with db_session.get_db() as db: + org_id = app_user_service.restrict_entities_to_user_org(db, user_email) + n_collections = collection_repo.count(db, org_id) + n_families = family_repo.count(db, org_id) + n_documents = document_repo.count(db, org_id) + n_events = event_repo.count(db, org_id) + + return SummaryDTO( + n_documents=n_documents, + n_families=n_families, + n_collections=n_collections, + n_events=n_events, + ) except exc.SQLAlchemyError as e: _LOGGER.error(e) diff --git a/app/service/app_user.py b/app/service/app_user.py index 0252e176..e5607278 100644 --- a/app/service/app_user.py +++ b/app/service/app_user.py @@ -1,3 +1,5 @@ +from typing import Optional + from sqlalchemy.orm import Session from app.errors import ValidationError @@ -15,3 +17,11 @@ def get_organisation(db: Session, user_email: str) -> int: def is_superuser(db: Session, user_email: str) -> bool: """Determine a user's superuser status""" return app_user_repo.is_superuser(db, user_email) + + +def restrict_entities_to_user_org(db: Session, user_email: str) -> Optional[int]: + org_id = get_organisation(db, user_email) + superuser: bool = is_superuser(db, user_email) + if superuser: + return None + return org_id diff --git a/app/service/collection.py b/app/service/collection.py index 8c82b006..6c38d047 100644 --- a/app/service/collection.py +++ b/app/service/collection.py @@ -184,21 +184,6 @@ def delete(import_id: str, context=None, db: Session = db_session.get_db()) -> b return collection_repo.delete(db, import_id) -@validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def count() -> Optional[int]: - """ - Gets a count of collections from the repository. - - :return Optional[int]: The number of collections in the repository or none. - """ - try: - with db_session.get_db() as db: - return collection_repo.count(db) - except exc.SQLAlchemyError as e: - _LOGGER.error(e) - raise RepositoryError(str(e)) - - def get_org_from_id(db: Session, collection_import_id: str) -> Optional[int]: org = collection_repo.get_org_from_collection_id(db, collection_import_id) if org is None: diff --git a/app/service/document.py b/app/service/document.py index c5528db9..cf4151d6 100644 --- a/app/service/document.py +++ b/app/service/document.py @@ -155,18 +155,3 @@ def delete(import_id: str, context=None, db: Session = db_session.get_db()) -> b if context is not None: context.error = f"Could not delete document {import_id}" return document_repo.delete(db, import_id) - - -@validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def count() -> Optional[int]: - """ - Gets a count of documents from the repository. - - :return Optional[int]: The number of documents in the repository or none. - """ - try: - with db_session.get_db() as db: - return document_repo.count(db) - except exc.SQLAlchemyError as e: - _LOGGER.error(e) - raise RepositoryError(str(e)) diff --git a/app/service/event.py b/app/service/event.py index e6494309..3e72668c 100644 --- a/app/service/event.py +++ b/app/service/event.py @@ -141,18 +141,3 @@ def delete(import_id: str, context=None, db: Session = db_session.get_db()) -> b if context is not None: context.error = f"Could not delete event {import_id}" return event_repo.delete(db, import_id) - - -@validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def count() -> Optional[int]: - """ - Gets a count of events from the repository. - - :return Optional[int]: A count of events in the repository or none. - """ - try: - with db_session.get_db() as db: - return event_repo.count(db) - except exc.SQLAlchemyError as e: - _LOGGER.error(e) - raise RepositoryError(str(e)) diff --git a/app/service/family.py b/app/service/family.py index e36480dd..29d6af9d 100644 --- a/app/service/family.py +++ b/app/service/family.py @@ -56,9 +56,8 @@ def all(user_email: str) -> list[FamilyReadDTO]: :return list[FamilyDTO]: The list of families. """ with db_session.get_db() as db: - org_id = app_user.get_organisation(db, user_email) - is_superuser: bool = app_user.is_superuser(db, user_email) - return family_repo.all(db, org_id, is_superuser) + org_id = app_user.restrict_entities_to_user_org(db, user_email) + return family_repo.all(db, org_id) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) @@ -235,18 +234,3 @@ def delete( return None return family_repo.delete(db, import_id) - - -@validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def count() -> Optional[int]: - """ - Gets a count of families from the repository. - - :return Optional[int]: The number of families in the repository or none. - """ - try: - with db_session.get_db() as db: - return family_repo.count(db) - except exc.SQLAlchemyError as e: - _LOGGER.error(e) - raise RepositoryError(str(e)) diff --git a/tests/integration_tests/analytics/test_get.py b/tests/integration_tests/analytics/test_get.py index 9811b81c..09895a9b 100644 --- a/tests/integration_tests/analytics/test_get.py +++ b/tests/integration_tests/analytics/test_get.py @@ -31,11 +31,13 @@ def test_get_all_analytics_when_not_authenticated(client: TestClient, data_db: S # --- GET SUMMARY -def test_get_analytics_summary(client: TestClient, data_db: Session, user_header_token): +def test_get_analytics_summary_super( + client: TestClient, data_db: Session, superuser_header_token +): setup_db(data_db) response = client.get( "/api/v1/analytics/summary", - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK @@ -48,6 +50,50 @@ def test_get_analytics_summary(client: TestClient, data_db: Session, user_header ) +def test_get_analytics_summary_cclw( + client: TestClient, data_db: Session, user_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/analytics/summary", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + + data = response.json() + assert isinstance(data, dict) + assert list(data.keys()) == EXPECTED_ANALYTICS_SUMMARY_KEYS + + assert dict(sorted(data.items())) == { + "n_collections": 2, + "n_documents": 2, + "n_events": 3, + "n_families": 2, + } + + +def test_get_analytics_summary_unfccc( + client: TestClient, data_db: Session, non_cclw_user_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/analytics/summary", + headers=non_cclw_user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + + data = response.json() + assert isinstance(data, dict) + assert list(data.keys()) == EXPECTED_ANALYTICS_SUMMARY_KEYS + + assert dict(sorted(data.items())) == { + "n_collections": 0, + "n_documents": 2, + "n_events": 3, + "n_families": 1, + } + + def test_get_analytics_summary_when_not_authenticated( client: TestClient, data_db: Session ): diff --git a/tests/mocks/repos/app_user_repo.py b/tests/mocks/repos/app_user_repo.py index c0c92e9e..f14493aa 100644 --- a/tests/mocks/repos/app_user_repo.py +++ b/tests/mocks/repos/app_user_repo.py @@ -10,13 +10,14 @@ HASH_PASSWORD = auth_service.get_password_hash(PLAIN_PASSWORD) VALID_USERNAME = "bob@here.com" -ORG_ID = 1234 +INVALID_ORG_ID = 1234 def mock_app_user_repo(app_user_repo, monkeypatch: MonkeyPatch, mocker): app_user_repo.user_active = True app_user_repo.error = False app_user_repo.invalid_org = False + app_user_repo.is_superuser = False def mock_get_app_user_authorisation( _, __ @@ -34,19 +35,28 @@ def mock_get_user_by_email(_, __) -> MaybeAppUser: def mock_get_org_id(_, user_email: str) -> int: if app_user_repo.invalid_org is True: - return ORG_ID + return INVALID_ORG_ID return 1 def mock_is_active(_, email: str) -> bool: return app_user_repo.user_active + def mock_is_superuser(_, email: str) -> bool: + return app_user_repo.is_superuser + monkeypatch.setattr(app_user_repo, "get_user_by_email", mock_get_user_by_email) + mocker.spy(app_user_repo, "get_user_by_email") + monkeypatch.setattr(app_user_repo, "get_org_id", mock_get_org_id) + mocker.spy(app_user_repo, "get_org_id") + monkeypatch.setattr(app_user_repo, "is_active", mock_is_active) + mocker.spy(app_user_repo, "is_active") + monkeypatch.setattr( app_user_repo, "get_app_user_authorisation", mock_get_app_user_authorisation ) - mocker.spy(app_user_repo, "get_user_by_email") - mocker.spy(app_user_repo, "get_org_id") - mocker.spy(app_user_repo, "is_active") mocker.spy(app_user_repo, "get_app_user_authorisation") + + monkeypatch.setattr(app_user_repo, "is_superuser", mock_is_superuser) + mocker.spy(app_user_repo, "is_superuser") diff --git a/tests/mocks/repos/bad_collection_repo.py b/tests/mocks/repos/bad_collection_repo.py index 5972b4c7..6ecacd89 100644 --- a/tests/mocks/repos/bad_collection_repo.py +++ b/tests/mocks/repos/bad_collection_repo.py @@ -27,7 +27,7 @@ def mock_create(_, data: CollectionReadDTO, __) -> Optional[CollectionReadDTO]: def mock_delete(_, import_id: str) -> bool: raise RepositoryError("Bad Repo") - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: raise RepositoryError("Bad Repo") monkeypatch.setattr(repo, "get", mock_get) @@ -53,7 +53,8 @@ def mock_get_count(_) -> Optional[int]: def mock_collection_count_none(repo, monkeypatch: MonkeyPatch, mocker): - def mock_get_count(_) -> Optional[int]: + + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: return None monkeypatch.setattr(repo, "count", mock_get_count) diff --git a/tests/mocks/repos/bad_document_repo.py b/tests/mocks/repos/bad_document_repo.py index c4519eed..27c4b458 100644 --- a/tests/mocks/repos/bad_document_repo.py +++ b/tests/mocks/repos/bad_document_repo.py @@ -25,7 +25,7 @@ def mock_create(_, data: DocumentCreateDTO) -> Optional[DocumentReadDTO]: def mock_delete(_, import_id: str) -> bool: raise RepositoryError("Bad Repo") - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: raise RepositoryError("Bad Repo") monkeypatch.setattr(repo, "get", mock_get) @@ -51,7 +51,8 @@ def mock_get_count(_) -> Optional[int]: def mock_document_count_none(repo, monkeypatch: MonkeyPatch, mocker): - def mock_get_count(_) -> Optional[int]: + + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: return None monkeypatch.setattr(repo, "count", mock_get_count) diff --git a/tests/mocks/repos/bad_event_repo.py b/tests/mocks/repos/bad_event_repo.py index 1829637d..6fd9729d 100644 --- a/tests/mocks/repos/bad_event_repo.py +++ b/tests/mocks/repos/bad_event_repo.py @@ -25,7 +25,7 @@ def mock_update(_, import_id, data: EventReadDTO) -> Optional[EventReadDTO]: def mock_delete(_, import_id: str) -> bool: raise RepositoryError("Bad Repo") - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: raise RepositoryError("Bad Repo") monkeypatch.setattr(repo, "get", mock_get) @@ -51,7 +51,8 @@ def mock_get_count(_) -> Optional[int]: def mock_event_count_none(repo, monkeypatch: MonkeyPatch, mocker): - def mock_get_count(_) -> Optional[int]: + + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: return None monkeypatch.setattr(repo, "count", mock_get_count) diff --git a/tests/mocks/repos/bad_family_repo.py b/tests/mocks/repos/bad_family_repo.py index 562680e4..674fe53a 100644 --- a/tests/mocks/repos/bad_family_repo.py +++ b/tests/mocks/repos/bad_family_repo.py @@ -27,7 +27,7 @@ def mock_create(_, data: FamilyReadDTO, __, ___) -> Optional[FamilyReadDTO]: def mock_delete(_, import_id: str) -> bool: raise RepositoryError("Bad Repo") - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: raise RepositoryError("Bad Repo") monkeypatch.setattr(repo, "get", mock_get) @@ -53,7 +53,8 @@ def mock_get_count(_) -> Optional[int]: def mock_family_count_none(repo, monkeypatch: MonkeyPatch, mocker): - def mock_get_count(_) -> Optional[int]: + + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: return None monkeypatch.setattr(repo, "count", mock_get_count) diff --git a/tests/mocks/repos/collection_repo.py b/tests/mocks/repos/collection_repo.py index 34fab149..fa887bca 100644 --- a/tests/mocks/repos/collection_repo.py +++ b/tests/mocks/repos/collection_repo.py @@ -18,16 +18,18 @@ def mock_collection_repo(collection_repo, monkeypatch: MonkeyPatch, mocker): collection_repo.throw_repository_error = False collection_repo.throw_timeout_error = False collection_repo.alternative_org = False + collection_repo.is_superuser = False def maybe_throw(): if collection_repo.throw_repository_error: - raise RepositoryError("bad repo") + raise RepositoryError("bad collection repo") def maybe_timeout(): if collection_repo.throw_timeout_error: raise TimeoutError - def mock_get_all(_) -> list[CollectionReadDTO]: + def mock_get_all(_, org_id: Optional[int]) -> list[CollectionReadDTO]: + maybe_throw() return [ create_collection_read_dto(import_id="id1"), create_collection_read_dto(import_id="id2"), @@ -60,10 +62,12 @@ def mock_delete(_, import_id: str) -> bool: maybe_throw() return not collection_repo.return_empty - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: maybe_throw() if collection_repo.return_empty is False: - return 11 + if collection_repo.is_superuser: + return 11 + return 5 return def mock_validate(_, __) -> bool: diff --git a/tests/mocks/repos/config_repo.py b/tests/mocks/repos/config_repo.py index 80c5e071..af0a0920 100644 --- a/tests/mocks/repos/config_repo.py +++ b/tests/mocks/repos/config_repo.py @@ -18,7 +18,6 @@ def mock_get(_) -> Optional[ConfigReadDTO]: maybe_throw() return ConfigReadDTO( geographies=[], - taxonomies={}, corpora=[], languages={}, document=DocumentConfig(roles=[], types=[], variants=[]), diff --git a/tests/mocks/repos/document_repo.py b/tests/mocks/repos/document_repo.py index 973ea9da..b8d3a1e4 100644 --- a/tests/mocks/repos/document_repo.py +++ b/tests/mocks/repos/document_repo.py @@ -12,16 +12,18 @@ def mock_document_repo(document_repo, monkeypatch: MonkeyPatch, mocker): document_repo.return_empty = False document_repo.throw_repository_error = False document_repo.throw_timeout_error = False + document_repo.is_superuser = False def maybe_throw(): if document_repo.throw_repository_error: - raise RepositoryError("bad repo") + raise RepositoryError("bad document repo") def maybe_timeout(): if document_repo.throw_timeout_error: raise TimeoutError - def mock_get_all(_) -> list[DocumentReadDTO]: + def mock_get_all(_, org_id: Optional[int]) -> list[DocumentReadDTO]: + maybe_throw() values = [] for x in range(3): dto = create_document_read_dto(import_id=f"id{x}") @@ -55,10 +57,12 @@ def mock_delete(_, import_id: str) -> bool: maybe_throw() return not document_repo.return_empty - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: maybe_throw() if not document_repo.return_empty: - return 33 + if document_repo.is_superuser: + return 33 + return 11 return monkeypatch.setattr(document_repo, "get", mock_get) diff --git a/tests/mocks/repos/event_repo.py b/tests/mocks/repos/event_repo.py index a2a5922b..3e856d68 100644 --- a/tests/mocks/repos/event_repo.py +++ b/tests/mocks/repos/event_repo.py @@ -12,16 +12,18 @@ def mock_event_repo(event_repo, monkeypatch: MonkeyPatch, mocker): event_repo.return_empty = False event_repo.throw_repository_error = False event_repo.throw_timeout_error = False + event_repo.is_superuser = False def maybe_throw(): if event_repo.throw_repository_error: - raise RepositoryError("bad repo") + raise RepositoryError("bad event repo") def maybe_timeout(): if event_repo.throw_timeout_error: raise TimeoutError def mock_get_all(_) -> list[EventReadDTO]: + maybe_throw() values = [] for x in range(3): dto = create_event_read_dto(import_id=f"id{x}") @@ -55,10 +57,12 @@ def mock_delete(_, import_id: str) -> bool: maybe_throw() return not event_repo.return_empty - def mock_get_count(_) -> Optional[int]: + def mock_get_count(_, org_id: Optional[int]) -> Optional[int]: maybe_throw() if not event_repo.return_empty: - return 5 + if event_repo.is_superuser: + return 5 + return 2 return monkeypatch.setattr(event_repo, "get", mock_get) diff --git a/tests/mocks/repos/family_repo.py b/tests/mocks/repos/family_repo.py index 52511feb..a98bf840 100644 --- a/tests/mocks/repos/family_repo.py +++ b/tests/mocks/repos/family_repo.py @@ -10,7 +10,7 @@ def _maybe_throw(): if family_repo.throw_repository_error is True: - raise RepositoryError("bad repo") + raise RepositoryError("bad family repo") def _maybe_timeout(): @@ -18,7 +18,8 @@ def _maybe_timeout(): raise TimeoutError -def all(db: Session): +def all(db: Session, org_id: Optional[int]): + _maybe_throw() return [create_family_read_dto("test", collections=["x.y.z.1", "x.y.z.2"])] @@ -58,8 +59,10 @@ def delete(db: Session, import_id: str) -> bool: return family_repo.return_empty is False -def count(db: Session) -> Optional[int]: +def count(db: Session, org_id: Optional[int]) -> Optional[int]: _maybe_throw() - if family_repo.return_empty is False: + if family_repo.return_empty: + return + if family_repo.is_superuser: return 22 - return + return 11 diff --git a/tests/mocks/services/analytics_service.py b/tests/mocks/services/analytics_service.py index 0f275695..18d54219 100644 --- a/tests/mocks/services/analytics_service.py +++ b/tests/mocks/services/analytics_service.py @@ -13,7 +13,7 @@ def maybe_throw(): if analytics_service.throw_repository_error: raise RepositoryError("bad repo") - def mock_get_summary() -> SummaryDTO: + def mock_get_summary(user_email: str) -> SummaryDTO: maybe_throw() if analytics_service.return_empty is True: return create_summary_dto( diff --git a/tests/mocks/services/app_user_service.py b/tests/mocks/services/app_user_service.py index 87d1f3cd..f6968af5 100644 --- a/tests/mocks/services/app_user_service.py +++ b/tests/mocks/services/app_user_service.py @@ -1,3 +1,5 @@ +from typing import Optional + from pytest import MonkeyPatch ORG_ID = 1234 @@ -5,11 +7,24 @@ def mock_app_user_service(app_user_service, monkeypatch: MonkeyPatch, mocker): app_user_service.invalid_org = False + app_user_service.is_superuser = False def mock_get_organisation(_, user_email: str) -> int: if app_user_service.invalid_org is True: return ORG_ID return 1 + def mock_restrict_entities_to_user_org(_, user_email: str) -> Optional[int]: + if app_user_service.is_superuser is True: + return None + return 1 + + monkeypatch.setattr( + app_user_service, + "restrict_entities_to_user_org", + mock_restrict_entities_to_user_org, + ) + mocker.spy(app_user_service, "restrict_entities_to_user_org") + monkeypatch.setattr(app_user_service, "get_organisation", mock_get_organisation) mocker.spy(app_user_service, "get_organisation") diff --git a/tests/mocks/services/collection_service.py b/tests/mocks/services/collection_service.py index a6945131..62e31fea 100644 --- a/tests/mocks/services/collection_service.py +++ b/tests/mocks/services/collection_service.py @@ -58,12 +58,6 @@ def mock_delete_collection(import_id: str) -> bool: maybe_throw() return not collection_service.missing - def mock_count_collection() -> Optional[int]: - maybe_throw() - if collection_service.missing: - return None - return 11 - def mock_validate() -> Optional[int]: maybe_throw() if collection_service.missing is False: @@ -98,9 +92,6 @@ def mock_get_org_from_id() -> Optional[int]: monkeypatch.setattr(collection_service, "delete", mock_delete_collection) mocker.spy(collection_service, "delete") - monkeypatch.setattr(collection_service, "count", mock_count_collection) - mocker.spy(collection_service, "count") - monkeypatch.setattr(collection_service, "validate", mock_validate) mocker.spy(collection_service, "validate") diff --git a/tests/mocks/services/config_service.py b/tests/mocks/services/config_service.py index 5a7e5a2d..606925ed 100644 --- a/tests/mocks/services/config_service.py +++ b/tests/mocks/services/config_service.py @@ -15,7 +15,6 @@ def mock_get_config(_) -> ConfigReadDTO: maybe_throw() return ConfigReadDTO( geographies=[], - taxonomies={}, corpora=[], languages={}, document=DocumentConfig(roles=[], types=[], variants=[]), diff --git a/tests/mocks/services/document_service.py b/tests/mocks/services/document_service.py index 09df1557..8724f4f4 100644 --- a/tests/mocks/services/document_service.py +++ b/tests/mocks/services/document_service.py @@ -63,12 +63,6 @@ def mock_delete_document(_) -> bool: maybe_throw() return not document_service.missing - def mock_count_collection() -> Optional[int]: - maybe_throw() - if document_service.missing: - return None - return 33 - monkeypatch.setattr(document_service, "get", mock_get_document) mocker.spy(document_service, "get") @@ -86,6 +80,3 @@ def mock_count_collection() -> Optional[int]: monkeypatch.setattr(document_service, "delete", mock_delete_document) mocker.spy(document_service, "delete") - - monkeypatch.setattr(document_service, "count", mock_count_collection) - mocker.spy(document_service, "count") diff --git a/tests/mocks/services/event_service.py b/tests/mocks/services/event_service.py index e2e8616f..2e8b1ab1 100644 --- a/tests/mocks/services/event_service.py +++ b/tests/mocks/services/event_service.py @@ -57,12 +57,6 @@ def mock_delete_event(_) -> bool: maybe_throw() return not event_service.missing - def mock_count_event() -> Optional[int]: - maybe_throw() - if event_service.missing: - return None - return 5 - monkeypatch.setattr(event_service, "get", mock_get_event) mocker.spy(event_service, "get") @@ -80,6 +74,3 @@ def mock_count_event() -> Optional[int]: monkeypatch.setattr(event_service, "delete", mock_delete_event) mocker.spy(event_service, "delete") - - monkeypatch.setattr(event_service, "count", mock_count_event) - mocker.spy(event_service, "count") diff --git a/tests/mocks/services/family_service.py b/tests/mocks/services/family_service.py index ec80a2e4..0a78bd2e 100644 --- a/tests/mocks/services/family_service.py +++ b/tests/mocks/services/family_service.py @@ -65,12 +65,6 @@ def mock_create_family(data: FamilyCreateDTO, user_email: str) -> str: def mock_delete_family(import_id: str) -> bool: return not family_service.missing - def mock_count_collection() -> Optional[int]: - maybe_throw() - if family_service.missing: - return None - return 22 - monkeypatch.setattr(family_service, "get", mock_get_family) mocker.spy(family_service, "get") @@ -88,6 +82,3 @@ def mock_count_collection() -> Optional[int]: monkeypatch.setattr(family_service, "delete", mock_delete_family) mocker.spy(family_service, "delete") - - monkeypatch.setattr(family_service, "count", mock_count_collection) - mocker.spy(family_service, "count") diff --git a/tests/unit_tests/service/analytics/test_analytics_service.py b/tests/unit_tests/service/analytics/test_analytics_service.py index dad984ff..a1a99342 100644 --- a/tests/unit_tests/service/analytics/test_analytics_service.py +++ b/tests/unit_tests/service/analytics/test_analytics_service.py @@ -13,6 +13,8 @@ create_summary_dto, ) +USER_EMAIL = "test@cpr.org" + def expected_analytics_summary( expected_docs: Optional[int] = EXPECTED_NUM_DOCUMENTS, @@ -31,15 +33,52 @@ def expected_analytics_summary( # --- GET SUMMARY -def test_summary( - collection_repo_mock, document_repo_mock, family_repo_mock, event_repo_mock +def test_summary_superuser( + collection_repo_mock, + document_repo_mock, + family_repo_mock, + event_repo_mock, + app_user_repo_mock, ): - result = analytics_service.summary() + collection_repo_mock.is_superuser = True + document_repo_mock.is_superuser = True + family_repo_mock.is_superuser = True + event_repo_mock.is_superuser = True + + result = analytics_service.summary("superuser@cpr.org") assert result == expected_analytics_summary() assert result is not None # Ensure the analytics service uses the other services to validate. + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 + assert collection_repo_mock.count.call_count == 1 + assert document_repo_mock.count.call_count == 1 + assert family_repo_mock.count.call_count == 1 + assert event_repo_mock.count.call_count == 1 + + +def test_summary_non_superuser( + collection_repo_mock, + document_repo_mock, + family_repo_mock, + event_repo_mock, + app_user_repo_mock, +): + result = analytics_service.summary("non-superuser@cpr.org") + assert result == expected_analytics_summary( + expected_docs=11, + expected_families=22, + expected_collections=5, + expected_events=2, + ) + + assert result is not None + + # Ensure the analytics service uses the other services to validate. + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.count.call_count == 1 assert document_repo_mock.count.call_count == 1 assert family_repo_mock.count.call_count == 1 @@ -47,13 +86,24 @@ def test_summary( def test_summary_returns_none( - collection_repo_mock, document_repo_mock, family_repo_mock, event_repo_mock + collection_repo_mock, + document_repo_mock, + family_repo_mock, + event_repo_mock, + app_user_repo_mock, ): collection_repo_mock.return_empty = True - result = analytics_service.summary() - assert result == expected_analytics_summary(expected_collections=None) + result = analytics_service.summary(USER_EMAIL) + assert result == expected_analytics_summary( + expected_docs=11, + expected_families=22, + expected_collections=None, + expected_events=2, + ) # Ensure the analytics service uses the other services to validate. + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.count.call_count == 1 assert document_repo_mock.count.call_count == 1 assert family_repo_mock.count.call_count == 1 @@ -61,13 +111,19 @@ def test_summary_returns_none( def test_summary_raises_if_db_error( - collection_repo_mock, document_repo_mock, family_repo_mock, event_repo_mock + collection_repo_mock, + document_repo_mock, + family_repo_mock, + event_repo_mock, + app_user_repo_mock, ): collection_repo_mock.throw_repository_error = True with pytest.raises(RepositoryError): - analytics_service.summary() + analytics_service.summary(USER_EMAIL) # Ensure the analytics service uses the other services to validate. + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.count.call_count == 1 assert document_repo_mock.count.call_count == 0 assert family_repo_mock.count.call_count == 0 diff --git a/tests/unit_tests/service/collection/test_count_collection_service.py b/tests/unit_tests/service/collection/test_count_collection_service.py deleted file mode 100644 index 34758162..00000000 --- a/tests/unit_tests/service/collection/test_count_collection_service.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - -import app.service.collection as collection_service -from app.errors import RepositoryError - -# --- COUNT - - -def test_count(collection_repo_mock): - result = collection_service.count() - assert result is not None - assert collection_repo_mock.count.call_count == 1 - - -def test_count_returns_none(collection_repo_mock): - collection_repo_mock.return_empty = True - result = collection_service.count() - assert result is None - assert collection_repo_mock.count.call_count == 1 - - -def test_count_raises_if_db_error(collection_repo_mock): - collection_repo_mock.throw_repository_error = True - with pytest.raises(RepositoryError) as e: - collection_service.count() - - expected_msg = "bad repo" - assert e.value.message == expected_msg - assert collection_repo_mock.count.call_count == 1 diff --git a/tests/unit_tests/service/document/test_count_document_service.py b/tests/unit_tests/service/document/test_count_document_service.py deleted file mode 100644 index 0946c077..00000000 --- a/tests/unit_tests/service/document/test_count_document_service.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - -import app.service.document as doc_service -from app.errors import RepositoryError - -# --- COUNT - - -def test_count(document_repo_mock): - result = doc_service.count() - assert result is not None - assert document_repo_mock.count.call_count == 1 - - -def test_count_returns_none(document_repo_mock): - document_repo_mock.return_empty = True - result = doc_service.count() - assert result is None - assert document_repo_mock.count.call_count == 1 - - -def test_count_raises_if_db_error(document_repo_mock): - document_repo_mock.throw_repository_error = True - with pytest.raises(RepositoryError) as e: - doc_service.count() - - expected_msg = "bad repo" - assert e.value.message == expected_msg - assert document_repo_mock.count.call_count == 1 diff --git a/tests/unit_tests/service/event/test_count_event_service.py b/tests/unit_tests/service/event/test_count_event_service.py deleted file mode 100644 index 0752d196..00000000 --- a/tests/unit_tests/service/event/test_count_event_service.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - -import app.service.event as event_service -from app.errors import RepositoryError - -# --- COUNT - - -def test_count(event_repo_mock): - result = event_service.count() - assert result is not None - assert event_repo_mock.count.call_count == 1 - - -def test_count_returns_none(event_repo_mock): - event_repo_mock.return_empty = True - result = event_service.count() - assert result is None - assert event_repo_mock.count.call_count == 1 - - -def test_count_raises_if_db_error(event_repo_mock): - event_repo_mock.throw_repository_error = True - with pytest.raises(RepositoryError) as e: - event_service.count() - - expected_msg = "bad repo" - assert e.value.message == expected_msg - assert event_repo_mock.count.call_count == 1 diff --git a/tests/unit_tests/service/family/test_count_family_service.py b/tests/unit_tests/service/family/test_count_family_service.py deleted file mode 100644 index ed65ac8e..00000000 --- a/tests/unit_tests/service/family/test_count_family_service.py +++ /dev/null @@ -1,38 +0,0 @@ -""" -Tests the family service. - -Uses a family repo mock and ensures that the repo is called. -""" - -import pytest - -import app.service.family as family_service -from app.errors import RepositoryError - -USER_EMAIL = "cclw@cpr.org" -ORG_ID = 1 - -# --- COUNT - - -def test_count(family_repo_mock): - result = family_service.count() - assert result is not None - assert family_repo_mock.count.call_count == 1 - - -def test_count_returns_none(family_repo_mock): - family_repo_mock.return_empty = True - result = family_service.count() - assert result is None - assert family_repo_mock.count.call_count == 1 - - -def test_count_raises_if_db_error(family_repo_mock): - family_repo_mock.throw_repository_error = True - with pytest.raises(RepositoryError) as e: - family_service.count() - - expected_msg = "bad repo" - assert e.value.message == expected_msg - assert family_repo_mock.count.call_count == 1 diff --git a/tests/unit_tests/service/family/test_get_family_service.py b/tests/unit_tests/service/family/test_get_family_service.py index 28b1c390..cfb28f0e 100644 --- a/tests/unit_tests/service/family/test_get_family_service.py +++ b/tests/unit_tests/service/family/test_get_family_service.py @@ -9,10 +9,6 @@ import app.service.family as family_service from app.errors import ValidationError -USER_EMAIL = "cclw@cpr.org" -ORG_ID = 1 - - # --- GET