diff --git a/app/api/api_v1/routers/collection.py b/app/api/api_v1/routers/collection.py index 6fddcadc..8e8ba84e 100644 --- a/app/api/api_v1/routers/collection.py +++ b/app/api/api_v1/routers/collection.py @@ -99,7 +99,7 @@ async def search_collection(request: Request) -> list[CollectionReadDTO]: validate_query_params(query_params, VALID_PARAMS) try: - collections = collection_service.search(query_params) + collections = collection_service.search(query_params, request.state.user.email) except ValidationError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: diff --git a/app/api/api_v1/routers/document.py b/app/api/api_v1/routers/document.py index c699fe78..cc3fec5d 100644 --- a/app/api/api_v1/routers/document.py +++ b/app/api/api_v1/routers/document.py @@ -93,7 +93,7 @@ async def search_document(request: Request) -> list[DocumentReadDTO]: validate_query_params(query_params, VALID_PARAMS) try: - documents = document_service.search(query_params) + documents = document_service.search(query_params, request.state.user.email) except ValidationError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: diff --git a/app/api/api_v1/routers/event.py b/app/api/api_v1/routers/event.py index 2abeb5e5..462ae65f 100644 --- a/app/api/api_v1/routers/event.py +++ b/app/api/api_v1/routers/event.py @@ -65,7 +65,7 @@ async def search_event(request: Request) -> list[EventReadDTO]: validate_query_params(query_params, VALID_PARAMS) try: - events_found = event_service.search(query_params) + events_found = event_service.search(query_params, request.state.user.email) except ValidationError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index 8606086a..42898a35 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -93,7 +93,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]: validate_query_params(query_params, VALID_PARAMS) try: - families = family_service.search(query_params) + families = family_service.search(query_params, request.state.user.email) except ValidationError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) except RepositoryError as e: diff --git a/app/repository/collection.py b/app/repository/collection.py index 57acfa15..4cdd3401 100644 --- a/app/repository/collection.py +++ b/app/repository/collection.py @@ -134,7 +134,7 @@ def get(db: Session, import_id: str) -> Optional[CollectionReadDTO]: def search( - db: Session, query_params: dict[str, Union[str, int]] + db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int] ) -> list[CollectionReadDTO]: """ Gets a list of collections from the repo searching given fields. @@ -142,6 +142,7 @@ def search( :param db Session: the database connection :param dict query_params: Any search terms to filter on specified fields (title & summary by default if 'q' specified). + :param org_id Optional[int]: the ID of the organisation the user belongs to :raises HTTPException: If a DB error occurs a 503 is returned. :raises HTTPException: If the search request times out a 408 is returned. @@ -156,10 +157,11 @@ def search( condition = and_(*search) if len(search) > 1 else search[0] try: + query = _get_query(db).filter(condition) + if org_id is not None: + query = query.filter(Organisation.id == org_id) found = ( - _get_query(db) - .filter(condition) - .order_by(desc(Collection.last_modified)) + query.order_by(desc(Collection.last_modified)) .limit(query_params["max_results"]) .all() ) @@ -268,10 +270,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]: :return Optional[int]: The number of collections in the repository or none. """ try: - if org_id is None: - n_collections = _get_query(db).count() - else: - n_collections = _get_query(db).filter(Organisation.id == org_id).count() + query = _get_query(db) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + n_collections = query.count() except Exception as e: _LOGGER.error(e) return diff --git a/app/repository/document.py b/app/repository/document.py index 7cfd9e77..3d4094e7 100644 --- a/app/repository/document.py +++ b/app/repository/document.py @@ -162,7 +162,7 @@ def get(db: Session, import_id: str) -> Optional[DocumentReadDTO]: def search( - db: Session, query_params: dict[str, Union[str, int]] + db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int] ) -> list[DocumentReadDTO]: """ Gets a list of documents from the repository searching the title. @@ -170,6 +170,7 @@ def search( :param db Session: the database connection :param dict query_params: Any search terms to filter on specified fields (title by default if 'q' specified). + :param org_id Optional[int]: the ID of the organisation the user belongs to :raises HTTPException: If a DB error occurs a 503 is returned. :raises HTTPException: If the search request times out a 408 is returned. @@ -183,9 +184,10 @@ def search( condition = and_(*search) if len(search) > 1 else search[0] try: # TODO: Fix order by on search PDCT-672 - result = ( - _get_query(db).filter(condition).limit(query_params["max_results"]).all() - ) + query = _get_query(db).filter(condition) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + result = query.limit(query_params["max_results"]).all() except OperationalError as e: if "canceling statement due to statement timeout" in str(e): raise TimeoutError @@ -444,10 +446,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]: :return Optional[int]: The number of documents in the repository or none. """ try: - if org_id is None: - n_documents = _get_query(db).count() - else: - n_documents = _get_query(db).filter(Organisation.id == org_id).count() + query = _get_query(db) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + n_documents = query.count() except NoResultFound as e: _LOGGER.error(e) return diff --git a/app/repository/event.py b/app/repository/event.py index d26c90c1..e154b378 100644 --- a/app/repository/event.py +++ b/app/repository/event.py @@ -111,13 +111,16 @@ def get(db: Session, import_id: str) -> Optional[EventReadDTO]: return _event_to_dto(family_event_meta) -def search(db: Session, query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]: +def search( + db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int] +) -> list[EventReadDTO]: """ Get family events matching a search term on the event title or type. :param db Session: The database connection. :param dict query_params: Any search terms to filter on specified fields (title & event type name by default if 'q' specified). + :param org_id Optional[int]: the ID of the organisation the user belongs to :raises HTTPException: If a DB error occurs a 503 is returned. :raises HTTPException: If the search request times out a 408 is returned. @@ -132,9 +135,10 @@ def search(db: Session, query_params: dict[str, Union[str, int]]) -> list[EventR condition = and_(*search) if len(search) > 1 else search[0] try: - found = ( - _get_query(db).filter(condition).limit(query_params["max_results"]).all() - ) + query = _get_query(db).filter(condition) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + found = query.limit(query_params["max_results"]).all() except OperationalError as e: if "canceling statement due to statement timeout" in str(e): raise TimeoutError @@ -252,10 +256,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]: or nothing. """ try: - if org_id is None: - n_events = _get_query(db).count() - else: - n_events = _get_query(db).filter(Organisation.id == org_id).count() + query = _get_query(db) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + n_events = query.count() except NoResultFound as e: _LOGGER.error(e) return diff --git a/app/repository/family.py b/app/repository/family.py index 05a81687..0a56fd17 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -121,17 +121,13 @@ def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]: Returns all the families. :param db Session: the database connection + :param org_id int: the ID of the organisation the user belongs to :return Optional[FamilyResponse]: All of things """ - if org_id is None: - family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all() - else: - family_geo_metas = ( - _get_query(db) - .filter(Organisation.id == org_id) - .order_by(desc(Family.last_modified)) - .all() - ) + query = _get_query(db) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + family_geo_metas = query.order_by(desc(Family.last_modified)).all() if not family_geo_metas: return [] @@ -159,7 +155,7 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]: def search( - db: Session, query_params: dict[str, Union[str, int]] + db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int] ) -> list[FamilyReadDTO]: """ Gets a list of families from the repository searching given fields. @@ -167,6 +163,7 @@ def search( :param db Session: the database connection :param dict query_params: Any search terms to filter on specified fields (title & summary by default if 'q' specified). + :param org_id Optional[int]: the ID of the organisation the user belongs to :raises HTTPException: If a DB error occurs a 503 is returned. :raises HTTPException: If the search request times out a 408 is returned. @@ -200,10 +197,11 @@ def search( condition = and_(*search) if len(search) > 1 else search[0] try: + query = _get_query(db).filter(condition) + if org_id is not None: + query = query.filter(Organisation.id == org_id) found = ( - _get_query(db) - .filter(condition) - .order_by(desc(Family.last_modified)) + query.order_by(desc(Family.last_modified)) .limit(query_params["max_results"]) .all() ) @@ -508,10 +506,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]: :return Optional[int]: The number of families in the repository or none. """ try: - if org_id is None: - n_families = _get_query(db).count() - else: - n_families = _get_query(db).filter(Organisation.id == org_id).count() + query = _get_query(db) + if org_id is not None: + query = query.filter(Organisation.id == org_id) + n_families = query.count() except NoResultFound as e: _LOGGER.error(e) return diff --git a/app/repository/protocols.py b/app/repository/protocols.py index f6122107..7e091040 100644 --- a/app/repository/protocols.py +++ b/app/repository/protocols.py @@ -26,7 +26,7 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]: @staticmethod def search( - db: Session, query_params: dict[str, Union[str, int]] + db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int] ) -> list[FamilyReadDTO]: """Searches the families""" ... diff --git a/app/service/collection.py b/app/service/collection.py index 6c38d047..6d4ab22b 100644 --- a/app/service/collection.py +++ b/app/service/collection.py @@ -60,7 +60,9 @@ def all() -> list[CollectionReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(query_params: dict[str, Union[str, int]]) -> list[CollectionReadDTO]: +def search( + query_params: dict[str, Union[str, int]], user_email: str +) -> list[CollectionReadDTO]: """ Searches for the search term against collections on specified fields. @@ -70,11 +72,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[CollectionReadDTO]: :param dict query_params: Search patterns to match against specified fields, given as key value pairs in a dictionary. + :param str user_email: The email address of the current user. :return list[CollectionReadDTO]: The list of collections matching the given search terms. """ with db_session.get_db() as db: - return collection_repo.search(db, query_params) + org_id = app_user.restrict_entities_to_user_org(db, user_email) + return collection_repo.search(db, query_params, org_id) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/app/service/document.py b/app/service/document.py index cf4151d6..60899b51 100644 --- a/app/service/document.py +++ b/app/service/document.py @@ -12,7 +12,7 @@ from app.clients.aws.client import get_s3_client from app.errors import RepositoryError, ValidationError from app.model.document import DocumentCreateDTO, DocumentReadDTO, DocumentWriteDTO -from app.service import id +from app.service import app_user, id _LOGGER = logging.getLogger(__name__) @@ -55,7 +55,9 @@ def all() -> list[DocumentReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(query_params: dict[str, Union[str, int]]) -> list[DocumentReadDTO]: +def search( + query_params: dict[str, Union[str, int]], user_email: str +) -> list[DocumentReadDTO]: """ Searches for the search term against documents on specified fields. @@ -64,11 +66,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[DocumentReadDTO]: :param dict query_params: Search patterns to match against specified fields, given as key value pairs in a dictionary. + :param str user_email: The email address of the current user. :return list[DocumentReadDTO]: The list of documents matching the given search terms. """ with db_session.get_db() as db: - return document_repo.search(db, query_params) + org_id = app_user.restrict_entities_to_user_org(db, user_email) + return document_repo.search(db, query_params, org_id) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/app/service/event.py b/app/service/event.py index 3e72668c..1b1898f6 100644 --- a/app/service/event.py +++ b/app/service/event.py @@ -10,7 +10,7 @@ import app.service.family as family_service from app.errors import RepositoryError, ValidationError from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO -from app.service import id +from app.service import app_user, id _LOGGER = logging.getLogger(__name__) @@ -46,7 +46,9 @@ def all() -> list[EventReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]: +def search( + query_params: dict[str, Union[str, int]], user_email: str +) -> list[EventReadDTO]: """ Searches for the search term against events on specified fields. @@ -56,11 +58,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]: :param dict query_params: Search patterns to match against specified fields, given as key value pairs in a dictionary. + :param str user_email: The email address of the current user. :return list[EventReadDTO]: The list of events matching the given search terms. """ with db_session.get_db() as db: - return event_repo.search(db, query_params) + org_id = app_user.restrict_entities_to_user_org(db, user_email) + return event_repo.search(db, query_params, org_id) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/app/service/family.py b/app/service/family.py index 29d6af9d..8cca73d0 100644 --- a/app/service/family.py +++ b/app/service/family.py @@ -61,7 +61,9 @@ def all(user_email: str) -> list[FamilyReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def search(query_params: dict[str, Union[str, int]]) -> list[FamilyReadDTO]: +def search( + query_params: dict[str, Union[str, int]], user_email: str +) -> list[FamilyReadDTO]: """ Searches for the search term against families on specified fields. @@ -71,11 +73,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[FamilyReadDTO]: :param dict query_params: Search patterns to match against specified fields, given as key value pairs in a dictionary. + :param str user_email: The email address of the current user. :return list[FamilyDTO]: The list of families matching the given search terms. """ with db_session.get_db() as db: - return family_repo.search(db, query_params) + org_id = app_user.restrict_entities_to_user_org(db, user_email) + return family_repo.search(db, query_params, org_id) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/pyproject.toml b/pyproject.toml index 8dcee850..def40b26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "admin_backend" -version = "2.6.3" +version = "2.6.4" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/integration_tests/collection/test_search.py b/tests/integration_tests/collection/test_search.py index a5fb1cec..1d21886f 100644 --- a/tests/integration_tests/collection/test_search.py +++ b/tests/integration_tests/collection/test_search.py @@ -7,11 +7,13 @@ from tests.integration_tests.setup_db import setup_db -def test_search_collection(client: TestClient, data_db: Session, user_header_token): +def test_search_collection_super( + client: TestClient, data_db: Session, superuser_header_token +): setup_db(data_db) response = client.get( "/api/v1/collections/?q=description", - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -24,6 +26,25 @@ def test_search_collection(client: TestClient, data_db: Session, user_header_tok assert ids_found.symmetric_difference(expected_ids) == set([]) +def test_search_collection_non_super( + client: TestClient, data_db: Session, user_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/collections/?q=description", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 2 + + expected_ids = set(["C.0.0.3", "C.0.0.2"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + def test_search_collection_when_not_authorised(client: TestClient, data_db: Session): setup_db(data_db) response = client.get( @@ -63,12 +84,12 @@ def test_search_collection_when_db_error( def test_search_collections_with_max_results( - client: TestClient, data_db: Session, user_header_token + client: TestClient, data_db: Session, superuser_header_token ): setup_db(data_db) response = client.get( "/api/v1/collections/?q=description&max_results=1", - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -77,7 +98,7 @@ def test_search_collections_with_max_results( ids_found = set([f["import_id"] for f in data]) assert len(ids_found) == 1 - expected_ids = set(["C.0.0.1"]) + expected_ids = set(["C.0.0.2"]) assert ids_found.symmetric_difference(expected_ids) == set([]) diff --git a/tests/integration_tests/document/test_search.py b/tests/integration_tests/document/test_search.py index 55811f30..7639f6b5 100644 --- a/tests/integration_tests/document/test_search.py +++ b/tests/integration_tests/document/test_search.py @@ -7,7 +7,28 @@ from tests.integration_tests.setup_db import setup_db -def test_search_document(client: TestClient, data_db: Session, user_header_token): +def test_search_document_super( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/documents/?q=title", + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 2 + + expected_ids = set(["D.0.0.1", "D.0.0.2"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + +def test_search_document_non_super( + client: TestClient, data_db: Session, user_header_token +): setup_db(data_db) response = client.get( "/api/v1/documents/?q=title", diff --git a/tests/integration_tests/event/test_search.py b/tests/integration_tests/event/test_search.py index 9f5a1a55..54fc2784 100644 --- a/tests/integration_tests/event/test_search.py +++ b/tests/integration_tests/event/test_search.py @@ -7,7 +7,28 @@ from tests.integration_tests.setup_db import setup_db -def test_search_event(client: TestClient, data_db: Session, user_header_token): +def test_search_event_super( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/events/?q=Amended", + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 2 + + expected_ids = set(["E.0.0.2", "E.0.0.3"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + +def test_search_event_non_super( + client: TestClient, data_db: Session, user_header_token +): setup_db(data_db) response = client.get( "/api/v1/events/?q=Amended", diff --git a/tests/integration_tests/family/test_search.py b/tests/integration_tests/family/test_search.py index 055d49d5..c3e7d361 100644 --- a/tests/integration_tests/family/test_search.py +++ b/tests/integration_tests/family/test_search.py @@ -7,11 +7,13 @@ from tests.integration_tests.setup_db import setup_db -def test_search_family_using_q(client: TestClient, data_db: Session, user_header_token): +def test_search_family_super( + client: TestClient, data_db: Session, superuser_header_token +): setup_db(data_db) response = client.get( "/api/v1/families/?q=orange", - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -24,6 +26,25 @@ def test_search_family_using_q(client: TestClient, data_db: Session, user_header assert ids_found.symmetric_difference(expected_ids) == set([]) +def test_search_family_non_super( + client: TestClient, data_db: Session, user_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/families/?q=orange", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + + ids_found = set([f["import_id"] for f in data]) + assert len(ids_found) == 1 + + expected_ids = set(["A.0.0.2"]) + assert ids_found.symmetric_difference(expected_ids) == set([]) + + def test_search_family_with_specific_param( client: TestClient, data_db: Session, user_header_token ): diff --git a/tests/mocks/repos/bad_collection_repo.py b/tests/mocks/repos/bad_collection_repo.py index 6ecacd89..84febbeb 100644 --- a/tests/mocks/repos/bad_collection_repo.py +++ b/tests/mocks/repos/bad_collection_repo.py @@ -13,7 +13,7 @@ def mock_get_all(_): def mock_get(_, import_id: str) -> Optional[CollectionReadDTO]: raise RepositoryError("Bad Repo") - def mock_search(_, q: str) -> list[CollectionReadDTO]: + def mock_search(_, q: str, org_id: Optional[int]) -> list[CollectionReadDTO]: raise RepositoryError("Bad Repo") def mock_update( diff --git a/tests/mocks/repos/bad_document_repo.py b/tests/mocks/repos/bad_document_repo.py index 27c4b458..9b5e0736 100644 --- a/tests/mocks/repos/bad_document_repo.py +++ b/tests/mocks/repos/bad_document_repo.py @@ -13,7 +13,7 @@ def mock_get_all(_): def mock_get(_, import_id: str) -> Optional[DocumentReadDTO]: raise RepositoryError("Bad Repo") - def mock_search(_, q: str) -> list[DocumentReadDTO]: + def mock_search(_, q: str, org_id: Optional[int]) -> list[DocumentReadDTO]: raise RepositoryError("Bad Repo") def mock_update(_, import_id, data: DocumentReadDTO) -> Optional[DocumentReadDTO]: diff --git a/tests/mocks/repos/bad_event_repo.py b/tests/mocks/repos/bad_event_repo.py index 6fd9729d..5c3f9416 100644 --- a/tests/mocks/repos/bad_event_repo.py +++ b/tests/mocks/repos/bad_event_repo.py @@ -13,7 +13,7 @@ def mock_get_all(_): def mock_get(_, import_id: str) -> Optional[EventReadDTO]: raise RepositoryError("Bad Repo") - def mock_search(_, q: str) -> list[EventReadDTO]: + def mock_search(_, q: str, org_id: Optional[int]) -> list[EventReadDTO]: raise RepositoryError("Bad Repo") def mock_create(_, data: EventCreateDTO) -> Optional[EventReadDTO]: diff --git a/tests/mocks/repos/bad_family_repo.py b/tests/mocks/repos/bad_family_repo.py index 674fe53a..376e7f2c 100644 --- a/tests/mocks/repos/bad_family_repo.py +++ b/tests/mocks/repos/bad_family_repo.py @@ -13,7 +13,7 @@ def mock_get_all(_): def mock_get(_, import_id: str) -> Optional[FamilyReadDTO]: raise RepositoryError("Bad Repo") - def mock_search(_, q: str) -> list[FamilyReadDTO]: + def mock_search(_, q: str, org_id: Optional[int]) -> list[FamilyReadDTO]: raise RepositoryError("Bad Repo") def mock_update( diff --git a/tests/mocks/repos/collection_repo.py b/tests/mocks/repos/collection_repo.py index fa887bca..2408bc4a 100644 --- a/tests/mocks/repos/collection_repo.py +++ b/tests/mocks/repos/collection_repo.py @@ -39,7 +39,7 @@ def mock_get_all(_, org_id: Optional[int]) -> list[CollectionReadDTO]: def mock_get(_, import_id: str) -> Optional[CollectionReadDTO]: return create_collection_read_dto(import_id=import_id) - def mock_search(_, q: str) -> list[CollectionReadDTO]: + def mock_search(_, q: str, org_id: Optional[int]) -> list[CollectionReadDTO]: maybe_throw() maybe_timeout() if not collection_repo.return_empty: diff --git a/tests/mocks/repos/document_repo.py b/tests/mocks/repos/document_repo.py index b8d3a1e4..b1302fe3 100644 --- a/tests/mocks/repos/document_repo.py +++ b/tests/mocks/repos/document_repo.py @@ -34,7 +34,7 @@ def mock_get(_, import_id: str) -> Optional[DocumentReadDTO]: dto = create_document_read_dto(import_id) return dto - def mock_search(_, q: str) -> list[DocumentReadDTO]: + def mock_search(_, q: str, org_id: Optional[int]) -> list[DocumentReadDTO]: maybe_throw() maybe_timeout() if not document_repo.return_empty: diff --git a/tests/mocks/repos/event_repo.py b/tests/mocks/repos/event_repo.py index 3e856d68..009d59d5 100644 --- a/tests/mocks/repos/event_repo.py +++ b/tests/mocks/repos/event_repo.py @@ -34,7 +34,7 @@ def mock_get(_, import_id: str) -> Optional[EventReadDTO]: dto = create_event_read_dto(import_id) return dto - def mock_search(_, q: dict) -> list[EventReadDTO]: + def mock_search(_, q: dict, org_id: Optional[int]) -> list[EventReadDTO]: maybe_throw() maybe_timeout() if not event_repo.return_empty: diff --git a/tests/mocks/repos/family_repo.py b/tests/mocks/repos/family_repo.py index a98bf840..e2b2a268 100644 --- a/tests/mocks/repos/family_repo.py +++ b/tests/mocks/repos/family_repo.py @@ -30,7 +30,7 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]: def search( - db: Session, query_params: dict[str, Union[str, int]] + db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int] ) -> list[FamilyReadDTO]: _maybe_throw() _maybe_timeout() diff --git a/tests/mocks/services/collection_service.py b/tests/mocks/services/collection_service.py index 62e31fea..c8d193d9 100644 --- a/tests/mocks/services/collection_service.py +++ b/tests/mocks/services/collection_service.py @@ -34,7 +34,9 @@ def mock_get_collection(import_id: str) -> Optional[CollectionReadDTO]: if not collection_service.missing: return create_collection_read_dto(import_id) - def mock_search_collections(q_params: dict) -> list[CollectionReadDTO]: + def mock_search_collections( + q_params: dict, user_email: str + ) -> list[CollectionReadDTO]: maybe_throw() maybe_timeout() if collection_service.missing: diff --git a/tests/mocks/services/document_service.py b/tests/mocks/services/document_service.py index 8724f4f4..3a49a5bc 100644 --- a/tests/mocks/services/document_service.py +++ b/tests/mocks/services/document_service.py @@ -30,7 +30,7 @@ def mock_get_document(import_id: str) -> Optional[DocumentReadDTO]: if not document_service.missing: return create_document_read_dto(import_id) - def mock_search_documents(q_params: dict) -> list[DocumentReadDTO]: + def mock_search_documents(q_params: dict, user_email: str) -> list[DocumentReadDTO]: if document_service.missing: return [] diff --git a/tests/mocks/services/event_service.py b/tests/mocks/services/event_service.py index 2e8b1ab1..8787d579 100644 --- a/tests/mocks/services/event_service.py +++ b/tests/mocks/services/event_service.py @@ -29,7 +29,7 @@ def mock_get_event(import_id: str) -> Optional[EventReadDTO]: if not event_service.missing: return create_event_read_dto(import_id) - def mock_search_events(q: dict) -> list[EventReadDTO]: + def mock_search_events(q: dict, user_email: str) -> list[EventReadDTO]: maybe_throw() maybe_timeout() if event_service.missing: diff --git a/tests/mocks/services/family_service.py b/tests/mocks/services/family_service.py index 0a78bd2e..fade37b0 100644 --- a/tests/mocks/services/family_service.py +++ b/tests/mocks/services/family_service.py @@ -23,14 +23,14 @@ def maybe_timeout(): if family_service.throw_timeout_error: raise TimeoutError - def mock_get_all_families(_): + def mock_get_all_families(user_email: str): return [create_family_read_dto("test", collections=["x.y.z.1", "x.y.z.2"])] def mock_get_family(import_id: str) -> Optional[FamilyReadDTO]: if not family_service.missing: return create_family_read_dto(import_id, collections=["x.y.z.1", "x.y.z.2"]) - def mock_search_families(q_params: dict) -> list[FamilyReadDTO]: + def mock_search_families(q_params: dict, user_email: str) -> list[FamilyReadDTO]: if q_params["q"] == "empty": return [] diff --git a/tests/unit_tests/service/collection/test_search_collection_service.py b/tests/unit_tests/service/collection/test_search_collection_service.py index 05114415..d1231118 100644 --- a/tests/unit_tests/service/collection/test_search_collection_service.py +++ b/tests/unit_tests/service/collection/test_search_collection_service.py @@ -3,33 +3,42 @@ import app.service.collection as collection_service from app.errors import RepositoryError +USER_EMAIL = "test@cpr.org" # --- SEARCH -def test_search(collection_repo_mock): - result = collection_service.search({"q": "two"}) +def test_search(collection_repo_mock, app_user_repo_mock): + result = collection_service.search({"q": "two"}, USER_EMAIL) assert result is not None assert len(result) == 1 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.search.call_count == 1 -def test_search_db_error(collection_repo_mock): +def test_search_db_error(collection_repo_mock, app_user_repo_mock): collection_repo_mock.throw_repository_error = True with pytest.raises(RepositoryError): - collection_service.search({"q": "error"}) + collection_service.search({"q": "error"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.search.call_count == 1 -def test_search_request_timeout(collection_repo_mock): +def test_search_request_timeout(collection_repo_mock, app_user_repo_mock): collection_repo_mock.throw_timeout_error = True with pytest.raises(TimeoutError): - collection_service.search({"q": "timeout"}) + collection_service.search({"q": "timeout"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.search.call_count == 1 -def test_search_missing(collection_repo_mock): +def test_search_missing(collection_repo_mock, app_user_repo_mock): collection_repo_mock.return_empty = True - result = collection_service.search({"q": "empty"}) + result = collection_service.search({"q": "empty"}, USER_EMAIL) assert result is not None assert len(result) == 0 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert collection_repo_mock.search.call_count == 1 diff --git a/tests/unit_tests/service/document/test_search_document_service.py b/tests/unit_tests/service/document/test_search_document_service.py index eb65808a..1fb24b86 100644 --- a/tests/unit_tests/service/document/test_search_document_service.py +++ b/tests/unit_tests/service/document/test_search_document_service.py @@ -3,33 +3,42 @@ import app.service.document as doc_service from app.errors import RepositoryError +USER_EMAIL = "test@cpr.org" # --- SEARCH -def test_search(document_repo_mock): - result = doc_service.search({"q": "two"}) +def test_search(document_repo_mock, app_user_repo_mock): + result = doc_service.search({"q": "two"}, USER_EMAIL) assert result is not None assert len(result) == 1 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert document_repo_mock.search.call_count == 1 -def test_search_db_error(document_repo_mock): +def test_search_db_error(document_repo_mock, app_user_repo_mock): document_repo_mock.throw_repository_error = True with pytest.raises(RepositoryError): - doc_service.search({"q": "error"}) + doc_service.search({"q": "error"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert document_repo_mock.search.call_count == 1 -def test_search_request_timeout(document_repo_mock): +def test_search_request_timeout(document_repo_mock, app_user_repo_mock): document_repo_mock.throw_timeout_error = True with pytest.raises(TimeoutError): - doc_service.search({"q": "timeout"}) + doc_service.search({"q": "timeout"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert document_repo_mock.search.call_count == 1 -def test_search_missing(document_repo_mock): +def test_search_missing(document_repo_mock, app_user_repo_mock): document_repo_mock.return_empty = True - result = doc_service.search({"q": "empty"}) + result = doc_service.search({"q": "empty"}, USER_EMAIL) assert result is not None assert len(result) == 0 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert document_repo_mock.search.call_count == 1 diff --git a/tests/unit_tests/service/event/test_search_event_service.py b/tests/unit_tests/service/event/test_search_event_service.py index 6cf714c5..ebef74a3 100644 --- a/tests/unit_tests/service/event/test_search_event_service.py +++ b/tests/unit_tests/service/event/test_search_event_service.py @@ -3,33 +3,42 @@ import app.service.event as event_service from app.errors import RepositoryError +USER_EMAIL = "test@cpr.org" # --- SEARCH -def test_search(event_repo_mock): - result = event_service.search({"q": "two"}) +def test_search(event_repo_mock, app_user_repo_mock): + result = event_service.search({"q": "two"}, USER_EMAIL) assert result is not None assert len(result) == 1 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert event_repo_mock.search.call_count == 1 -def test_search_db_error(event_repo_mock): +def test_search_db_error(event_repo_mock, app_user_repo_mock): event_repo_mock.throw_repository_error = True with pytest.raises(RepositoryError): - event_service.search({"q": "error"}) + event_service.search({"q": "error"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert event_repo_mock.search.call_count == 1 -def test_search_request_timeout(event_repo_mock): +def test_search_request_timeout(event_repo_mock, app_user_repo_mock): event_repo_mock.throw_timeout_error = True with pytest.raises(TimeoutError): - event_service.search({"q": "timeout"}) + event_service.search({"q": "timeout"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert event_repo_mock.search.call_count == 1 -def test_search_missing(event_repo_mock): +def test_search_missing(event_repo_mock, app_user_repo_mock): event_repo_mock.return_empty = True - result = event_service.search({"q": "empty"}) + result = event_service.search({"q": "empty"}, USER_EMAIL) assert result is not None assert len(result) == 0 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert event_repo_mock.search.call_count == 1 diff --git a/tests/unit_tests/service/family/test_search_family_service.py b/tests/unit_tests/service/family/test_search_family_service.py index be644785..772d3e85 100644 --- a/tests/unit_tests/service/family/test_search_family_service.py +++ b/tests/unit_tests/service/family/test_search_family_service.py @@ -9,44 +9,54 @@ import app.service.family as family_service from app.errors import RepositoryError -USER_EMAIL = "cclw@cpr.org" +USER_EMAIL = "test@cpr.org" ORG_ID = 1 # --- SEARCH -def test_search(family_repo_mock): - result = family_service.search({"q": "two"}) +def test_search(family_repo_mock, app_user_repo_mock): + result = family_service.search({"q": "two"}, USER_EMAIL) assert result is not None assert len(result) == 2 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert family_repo_mock.search.call_count == 1 -def test_search_on_specific_field(family_repo_mock): - result = family_service.search({"title": "one"}) +def test_search_on_specific_field(family_repo_mock, app_user_repo_mock): + result = family_service.search({"title": "one"}, USER_EMAIL) assert result is not None assert len(result) == 1 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert family_repo_mock.search.call_count == 1 -def test_search_db_error(family_repo_mock): +def test_search_db_error(family_repo_mock, app_user_repo_mock): family_repo_mock.throw_repository_error = True with pytest.raises(RepositoryError): - family_service.search({"q": "error"}) + family_service.search({"q": "error"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert family_repo_mock.search.call_count == 1 -def test_search_request_timeout(family_repo_mock): +def test_search_request_timeout(family_repo_mock, app_user_repo_mock): family_repo_mock.throw_timeout_error = True with pytest.raises(TimeoutError): - family_service.search({"q": "timeout"}) + family_service.search({"q": "timeout"}, USER_EMAIL) + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert family_repo_mock.search.call_count == 1 -def test_search_missing(family_repo_mock): +def test_search_missing(family_repo_mock, app_user_repo_mock): family_repo_mock.return_empty = True - result = family_service.search({"q": "empty"}) + result = family_service.search({"q": "empty"}, USER_EMAIL) assert result is not None assert len(result) == 0 + assert app_user_repo_mock.get_org_id.call_count == 1 + assert app_user_repo_mock.is_superuser.call_count == 1 assert family_repo_mock.search.call_count == 1