diff --git a/app/api/api_v1/routers/ingest.py b/app/api/api_v1/routers/ingest.py index 6ca3d943..f6ae7106 100644 --- a/app/api/api_v1/routers/ingest.py +++ b/app/api/api_v1/routers/ingest.py @@ -1,115 +1,31 @@ import json import logging -from enum import Enum -from typing import Any, Optional -from db_client.models.dfce.taxonomy_entry import EntitySpecificTaxonomyKeys -from db_client.models.organisation.counters import CountedEntity -from fastapi import APIRouter, BackgroundTasks, HTTPException, UploadFile, status +from fastapi import ( + APIRouter, + BackgroundTasks, + HTTPException, + Request, + UploadFile, + status, +) -import app.service.taxonomy as taxonomy from app.errors import ValidationError from app.model.general import Json -from app.model.ingest import ( - IngestCollectionDTO, - IngestDocumentDTO, - IngestEventDTO, - IngestFamilyDTO, +from app.service.ingest import ( + get_collection_template, + get_document_template, + get_event_template, + get_family_template, + import_data, + validate_ingest_data, ) -from app.service.ingest import import_data ingest_router = r = APIRouter() _LOGGER = logging.getLogger(__name__) -def _get_collection_template() -> dict: - """ - Gets a collection template. - - :return dict: The collection template. - """ - collection_schema = IngestCollectionDTO.model_json_schema(mode="serialization") - collection_template = collection_schema["properties"] - - return collection_template - - -def _get_event_template(corpus_type: str) -> dict: - """ - Gets an event template. - - :return dict: The event template. - """ - event_schema = IngestEventDTO.model_json_schema(mode="serialization") - event_template = event_schema["properties"] - - event_meta = _get_metadata_template(corpus_type, CountedEntity.Event) - - # TODO: Replace with event_template["metadata"] in PDCT-1622 - if "event_type" not in event_meta: - raise ValidationError("Bad taxonomy in database") - event_template["event_type_value"] = event_meta["event_type"] - - return event_template - - -def _get_document_template(corpus_type: str) -> dict: - """ - Gets a document template for a given corpus type. - - :param str corpus_type: The corpus_type to use to get the document template. - :return dict: The document template. - """ - document_schema = IngestDocumentDTO.model_json_schema(mode="serialization") - document_template = document_schema["properties"] - document_template["metadata"] = _get_metadata_template( - corpus_type, CountedEntity.Document - ) - - return document_template - - -def _get_metadata_template(corpus_type: str, metadata_type: CountedEntity) -> dict: - """ - Gets a metadata template for a given corpus type and entity. - - :param str corpus_type: The corpus_type to use to get the metadata template. - :param str metadata_type: The metadata_type to use to get the metadata template. - :return dict: The metadata template. - """ - metadata = taxonomy.get(corpus_type) - if not metadata: - return {} - if metadata_type == CountedEntity.Document: - return metadata.pop(EntitySpecificTaxonomyKeys.DOCUMENT.value) - elif metadata_type == CountedEntity.Event: - return metadata.pop(EntitySpecificTaxonomyKeys.EVENT.value) - elif metadata_type == CountedEntity.Family: - metadata.pop(EntitySpecificTaxonomyKeys.DOCUMENT.value) - metadata.pop(EntitySpecificTaxonomyKeys.EVENT.value) - metadata.pop("event_type") # TODO: Remove as part of PDCT-1622 - return metadata - - -def _get_family_template(corpus_type: str) -> dict: - """ - Gets a family template for a given corpus type. - - :param str corpus_type: The corpus_type to use to get the family template. - :return dict: The family template. - """ - family_schema = IngestFamilyDTO.model_json_schema(mode="serialization") - family_template = family_schema["properties"] - - del family_template["corpus_import_id"] - - family_metadata = _get_metadata_template(corpus_type, CountedEntity.Family) - family_template["metadata"] = family_metadata - - return family_template - - @r.get( "/ingest/template/{corpus_type}", response_model=Json, @@ -127,132 +43,26 @@ async def get_ingest_template(corpus_type: str) -> Json: try: return { - "collections": [_get_collection_template()], - "families": [_get_family_template(corpus_type)], - "documents": [_get_document_template(corpus_type)], - "events": [_get_event_template(corpus_type)], + "collections": [get_collection_template()], + "families": [get_family_template(corpus_type)], + "documents": [get_document_template(corpus_type)], + "events": [get_event_template(corpus_type)], } except ValidationError as e: _LOGGER.error(e) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) -class IngestEntityList(str, Enum): - """Name of the list of entities that can be ingested.""" - - Collections = "collections" - Families = "families" - Documents = "documents" - Events = "events" - - -def _collect_import_ids( - entity_list_name: IngestEntityList, - data: dict[str, Any], - import_id_type_name: Optional[str] = None, -) -> list[str]: - """ - Extracts a list of import_ids (or family_import_ids if specified) for the specified entity list in data. - - :param IngestEntityList entity_list_name: The name of the entity list from which the import_ids are to be extracted. - :param dict[str, Any] data: The data structure containing the entity lists used for extraction. - :param Optional[str] import_id_type_name: the name of the type of import_id to be extracted or None. - :return list[str]: A list of extracted import_ids for the specified entity list. - """ - import_id_key = import_id_type_name or "import_id" - import_ids = [] - if entity_list_name.value in data: - for entity in data[entity_list_name.value]: - import_ids.append(entity[import_id_key]) - return import_ids - - -def _match_import_ids( - parent_references: list[str], parent_import_ids: set[str] -) -> None: - """ - Validates that all the references to parent entities exist in the set of parent import_ids passed in - - :param list[str] parent_references: List of import_ids referencing parent entities to be validated. - :param set[str] parent_import_ids: Set of parent import_ids to validate against. - :raises ValidationError: raised if a parent reference is not found in the parent_import_ids. - """ - for id in parent_references: - if id not in parent_import_ids: - raise ValidationError(f"No entity with id {id} found") - - -def _validate_collections_exist_for_families(data: dict[str, Any]) -> None: - """ - Validates that collections the families are linked to exist based on import_id links in data. - - :param dict[str, Any] data: The data object containing entities to be validated. - """ - collections = _collect_import_ids(IngestEntityList.Collections, data) - collections_set = set(collections) - - family_collection_import_ids = [] - if "families" in data: - for fam in data["families"]: - family_collection_import_ids.extend(fam["collections"]) - - _match_import_ids(family_collection_import_ids, collections_set) - - -def _validate_families_exist_for_events_and_documents(data: dict[str, Any]) -> None: - """ - Validates that families the documents and events are linked to exist - based on import_id links in data. - - :param dict[str, Any] data: The data object containing entities to be validated. - """ - families = _collect_import_ids(IngestEntityList.Families, data) - families_set = set(families) - - document_family_import_ids = _collect_import_ids( - IngestEntityList.Documents, data, "family_import_id" - ) - event_family_import_ids = _collect_import_ids( - IngestEntityList.Events, data, "family_import_id" - ) - - _match_import_ids(document_family_import_ids, families_set) - _match_import_ids(event_family_import_ids, families_set) - - -def validate_entity_relationships(data: dict[str, Any]) -> None: - """ - Validates relationships between entities contained in data. - For documents, it validates that the family the document is linked to exists. - - :param dict[str, Any] data: The data object containing entities to be validated. - """ - - _validate_collections_exist_for_families(data) - _validate_families_exist_for_events_and_documents(data) - - -def _validate_ingest_data(data: dict[str, Any]) -> None: - """ - Validates data to be ingested. - - :param dict[str, Any] data: The data object to be validated. - :raises HTTPException: raised if data is empty or None. - """ - - if not data: - raise HTTPException(status_code=status.HTTP_204_NO_CONTENT) - - validate_entity_relationships(data) - - @r.post( "/ingest/{corpus_import_id}", response_model=Json, status_code=status.HTTP_202_ACCEPTED, ) async def ingest( - new_data: UploadFile, corpus_import_id: str, background_tasks: BackgroundTasks + request: Request, + new_data: UploadFile, + corpus_import_id: str, + background_tasks: BackgroundTasks, ) -> Json: """ Bulk import endpoint. @@ -260,12 +70,14 @@ async def ingest( :param UploadFile new_data: file containing json representation of data to ingest. :return Json: json representation of the data to ingest. """ - _LOGGER.info(f"Received bulk import request for corpus: {corpus_import_id}") + _LOGGER.info( + f"User {request.state.user} triggered bulk import for corpus: {corpus_import_id}" + ) try: content = await new_data.read() data_dict = json.loads(content) - _validate_ingest_data(data_dict) + validate_ingest_data(data_dict) background_tasks.add_task(import_data, data_dict, corpus_import_id) diff --git a/app/model/authorisation.py b/app/model/authorisation.py index 01fcbe23..7c2d14f0 100644 --- a/app/model/authorisation.py +++ b/app/model/authorisation.py @@ -63,8 +63,8 @@ class AuthEndpoint(str, enum.Enum): }, # Ingest AuthEndpoint.INGEST: { - AuthOperation.CREATE: AuthAccess.USER, - AuthOperation.READ: AuthAccess.USER, + AuthOperation.CREATE: AuthAccess.SUPER, + AuthOperation.READ: AuthAccess.SUPER, }, # Corpus AuthEndpoint.CORPUS: { diff --git a/app/service/authorisation.py b/app/service/authorisation.py index 2838ec39..f8eef0b4 100644 --- a/app/service/authorisation.py +++ b/app/service/authorisation.py @@ -70,4 +70,13 @@ def is_authorised(user: UserContext, entity: AuthEndpoint, op: AuthOperation) -> if _has_access(required_access, _get_user_access(user)): return - raise AuthorisationError(f"User {user.email} is not authorised to {op} a {entity}") + raise AuthorisationError( + f"User {user.email} is not authorised to {op} {_get_article(entity.value)} {entity}" + ) + + +def _get_article(word: str) -> str: + vowels = ["a", "e", "i", "o", "u", "y"] + if word.lower()[0] in vowels: + return "an" + return "a" diff --git a/app/service/ingest.py b/app/service/ingest.py index 81bb4ff7..ff3ce4bc 100644 --- a/app/service/ingest.py +++ b/app/service/ingest.py @@ -11,6 +11,9 @@ from db_client.models.dfce.collection import Collection from db_client.models.dfce.family import Family, FamilyDocument, FamilyEvent +from db_client.models.dfce.taxonomy_entry import EntitySpecificTaxonomyKeys +from db_client.models.organisation.counters import CountedEntity +from fastapi import HTTPException, status from pydantic import ConfigDict, validate_call from sqlalchemy.ext.declarative import DeclarativeMeta from sqlalchemy.orm import Session @@ -23,7 +26,9 @@ import app.service.corpus as corpus import app.service.geography as geography import app.service.notification as notification_service +import app.service.taxonomy as taxonomy import app.service.validation as validation +from app.errors import ValidationError from app.model.ingest import ( IngestCollectionDTO, IngestDocumentDTO, @@ -70,6 +75,93 @@ def _exists_in_db(entity: Type[T], import_id: str, db: Session) -> bool: return entity_exists is not None +def get_collection_template() -> dict: + """ + Gets a collection template. + + :return dict: The collection template. + """ + collection_schema = IngestCollectionDTO.model_json_schema(mode="serialization") + collection_template = collection_schema["properties"] + + return collection_template + + +def get_event_template(corpus_type: str) -> dict: + """ + Gets an event template. + + :return dict: The event template. + """ + event_schema = IngestEventDTO.model_json_schema(mode="serialization") + event_template = event_schema["properties"] + + event_meta = get_metadata_template(corpus_type, CountedEntity.Event) + + # TODO: Replace with event_template["metadata"] in PDCT-1622 + if "event_type" not in event_meta: + raise ValidationError("Bad taxonomy in database") + event_template["event_type_value"] = event_meta["event_type"] + + return event_template + + +def get_document_template(corpus_type: str) -> dict: + """ + Gets a document template for a given corpus type. + + :param str corpus_type: The corpus_type to use to get the document template. + :return dict: The document template. + """ + document_schema = IngestDocumentDTO.model_json_schema(mode="serialization") + document_template = document_schema["properties"] + document_template["metadata"] = get_metadata_template( + corpus_type, CountedEntity.Document + ) + + return document_template + + +def get_metadata_template(corpus_type: str, metadata_type: CountedEntity) -> dict: + """ + Gets a metadata template for a given corpus type and entity. + + :param str corpus_type: The corpus_type to use to get the metadata template. + :param str metadata_type: The metadata_type to use to get the metadata template. + :return dict: The metadata template. + """ + metadata = taxonomy.get(corpus_type) + if not metadata: + return {} + if metadata_type == CountedEntity.Document: + return metadata.pop(EntitySpecificTaxonomyKeys.DOCUMENT.value) + elif metadata_type == CountedEntity.Event: + return metadata.pop(EntitySpecificTaxonomyKeys.EVENT.value) + elif metadata_type == CountedEntity.Family: + metadata.pop(EntitySpecificTaxonomyKeys.DOCUMENT.value) + metadata.pop(EntitySpecificTaxonomyKeys.EVENT.value) + metadata.pop("event_type") # TODO: Remove as part of PDCT-1622 + return metadata + + +def get_family_template(corpus_type: str) -> dict: + """ + Gets a family template for a given corpus type. + + :param str corpus_type: The corpus_type to use to get the family template. + :return dict: The family template. + """ + family_schema = IngestFamilyDTO.model_json_schema(mode="serialization") + family_template = family_schema["properties"] + + del family_template["corpus_import_id"] + + family_metadata = get_metadata_template(corpus_type, CountedEntity.Family) + family_template["metadata"] = family_metadata + + return family_template + + @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) def save_collections( collection_data: list[dict[str, Any]], @@ -284,3 +376,103 @@ def import_data(data: dict[str, Any], corpus_import_id: str) -> None: end_message = f"💥 Bulk import for corpus: {corpus_import_id} has failed." finally: notification_service.send_notification(end_message) + + +def _collect_import_ids( + entity_list_name: IngestEntityList, + data: dict[str, Any], + import_id_type_name: Optional[str] = None, +) -> list[str]: + """ + Extracts a list of import_ids (or family_import_ids if specified) for the specified entity list in data. + + :param IngestEntityList entity_list_name: The name of the entity list from which the import_ids are to be extracted. + :param dict[str, Any] data: The data structure containing the entity lists used for extraction. + :param Optional[str] import_id_type_name: the name of the type of import_id to be extracted or None. + :return list[str]: A list of extracted import_ids for the specified entity list. + """ + import_id_key = import_id_type_name or "import_id" + import_ids = [] + if entity_list_name.value in data: + for entity in data[entity_list_name.value]: + import_ids.append(entity[import_id_key]) + return import_ids + + +def _match_import_ids( + parent_references: list[str], parent_import_ids: set[str] +) -> None: + """ + Validates that all the references to parent entities exist in the set of parent import_ids passed in + + :param list[str] parent_references: List of import_ids referencing parent entities to be validated. + :param set[str] parent_import_ids: Set of parent import_ids to validate against. + :raises ValidationError: raised if a parent reference is not found in the parent_import_ids. + """ + for id in parent_references: + if id not in parent_import_ids: + raise ValidationError(f"No entity with id {id} found") + + +def _validate_collections_exist_for_families(data: dict[str, Any]) -> None: + """ + Validates that collections the families are linked to exist based on import_id links in data. + + :param dict[str, Any] data: The data object containing entities to be validated. + """ + collections = _collect_import_ids(IngestEntityList.Collections, data) + collections_set = set(collections) + + family_collection_import_ids = [] + if "families" in data: + for fam in data["families"]: + family_collection_import_ids.extend(fam["collections"]) + + _match_import_ids(family_collection_import_ids, collections_set) + + +def _validate_families_exist_for_events_and_documents(data: dict[str, Any]) -> None: + """ + Validates that families the documents and events are linked to exist + based on import_id links in data. + + :param dict[str, Any] data: The data object containing entities to be validated. + """ + families = _collect_import_ids(IngestEntityList.Families, data) + families_set = set(families) + + document_family_import_ids = _collect_import_ids( + IngestEntityList.Documents, data, "family_import_id" + ) + event_family_import_ids = _collect_import_ids( + IngestEntityList.Events, data, "family_import_id" + ) + + _match_import_ids(document_family_import_ids, families_set) + _match_import_ids(event_family_import_ids, families_set) + + +def validate_entity_relationships(data: dict[str, Any]) -> None: + """ + Validates relationships between entities contained in data. + For documents, it validates that the family the document is linked to exists. + + :param dict[str, Any] data: The data object containing entities to be validated. + """ + + _validate_collections_exist_for_families(data) + _validate_families_exist_for_events_and_documents(data) + + +def validate_ingest_data(data: dict[str, Any]) -> None: + """ + Validates data to be ingested. + + :param dict[str, Any] data: The data object to be validated. + :raises HTTPException: raised if data is empty or None. + """ + + if not data: + raise HTTPException(status_code=status.HTTP_204_NO_CONTENT) + + validate_entity_relationships(data) diff --git a/pyproject.toml b/pyproject.toml index 07935c2b..618e1324 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "admin_backend" -version = "2.17.8" +version = "2.17.9" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 6f574ffa..11fbb642 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -226,7 +226,7 @@ def user_header_token() -> Dict[str, str]: @pytest.fixture def admin_user_header_token() -> Dict[str, str]: a_token = token_service.encode( - "cclw@cpr.org", CCLW_ORG_ID, False, {"is_admin": False} + "admin@cpr.org", CCLW_ORG_ID, False, {"is_admin": True} ) headers = {"Authorization": f"Bearer {a_token}"} return headers @@ -242,7 +242,7 @@ def non_cclw_user_header_token() -> Dict[str, str]: @pytest.fixture -def non_admin_user_header_token() -> Dict[str, str]: +def invalid_user_header_token() -> Dict[str, str]: a_token = token_service.encode("non-admin@cpr.org", CCLW_ORG_ID, False, {}) headers = {"Authorization": f"Bearer {a_token}"} return headers diff --git a/tests/integration_tests/ingest/test_ingest.py b/tests/integration_tests/ingest/test_ingest.py index daeaa98a..738db92a 100644 --- a/tests/integration_tests/ingest/test_ingest.py +++ b/tests/integration_tests/ingest/test_ingest.py @@ -17,7 +17,7 @@ @patch.dict(os.environ, {"BULK_IMPORT_BUCKET": "test_bucket"}) def test_ingest_when_ok( - data_db: Session, client: TestClient, admin_user_header_token, basic_s3_client + data_db: Session, client: TestClient, superuser_header_token, basic_s3_client ): response = client.post( "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", @@ -29,7 +29,7 @@ def test_ingest_when_ok( "rb", ) }, - headers=admin_user_header_token, + headers=superuser_header_token, ) expected_collection_import_ids = ["test.new.collection.0", "test.new.collection.1"] @@ -90,7 +90,7 @@ def test_import_data_rollback( caplog, data_db: Session, client: TestClient, - admin_user_header_token, + superuser_header_token, rollback_collection_repo, basic_s3_client, ): @@ -107,7 +107,7 @@ def test_import_data_rollback( "rb", ) }, - headers=admin_user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_202_ACCEPTED @@ -129,7 +129,7 @@ def test_ingest_idempotency( caplog, data_db: Session, client: TestClient, - admin_user_header_token, + superuser_header_token, basic_s3_client, ): family_import_id = "test.new.family.0" @@ -182,7 +182,7 @@ def test_ingest_idempotency( first_response = client.post( "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", files={"new_data": test_data_file}, - headers=admin_user_header_token, + headers=superuser_header_token, ) assert first_response.status_code == status.HTTP_202_ACCEPTED @@ -215,7 +215,7 @@ def test_ingest_idempotency( second_response = client.post( "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", files={"new_data": test_json}, - headers=admin_user_header_token, + headers=superuser_header_token, ) assert second_response.status_code == status.HTTP_202_ACCEPTED @@ -246,7 +246,7 @@ def test_generates_unique_slugs_for_documents_with_identical_titles( caplog, data_db: Session, client: TestClient, - admin_user_header_token, + superuser_header_token, basic_s3_client, ): """ @@ -289,7 +289,7 @@ def test_generates_unique_slugs_for_documents_with_identical_titles( first_response = client.post( "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", files={"new_data": test_data_file}, - headers=admin_user_header_token, + headers=superuser_header_token, ) assert first_response.status_code == status.HTTP_202_ACCEPTED @@ -311,7 +311,7 @@ def test_ingest_when_corpus_import_id_invalid( caplog, data_db: Session, client: TestClient, - admin_user_header_token, + superuser_header_token, basic_s3_client, ): invalid_corpus = "test" @@ -327,7 +327,7 @@ def test_ingest_when_corpus_import_id_invalid( "rb", ) }, - headers=admin_user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_202_ACCEPTED @@ -343,7 +343,7 @@ def test_ingest_events_when_event_type_invalid( caplog, data_db: Session, client: TestClient, - admin_user_header_token, + superuser_header_token, basic_s3_client, ): with caplog.at_level(logging.ERROR): @@ -360,7 +360,7 @@ def test_ingest_events_when_event_type_invalid( "rb", ) }, - headers=admin_user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_202_ACCEPTED @@ -372,3 +372,34 @@ def test_ingest_events_when_event_type_invalid( "Metadata validation failed: Invalid value '['Invalid']' for metadata key 'event_type'" in caplog.text ) + + +def test_ingest_when_not_authorised(client: TestClient, data_db: Session): + response = client.post( + "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_ingest_admin_non_super( + client: TestClient, data_db: Session, admin_user_header_token +): + response = client.post( + "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", + headers=admin_user_header_token, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User admin@cpr.org is not authorised to CREATE an INGEST" + + +def test_ingest_non_super_non_admin( + client: TestClient, data_db: Session, user_header_token +): + response = client.post( + "/api/v1/ingest/UNFCCC.corpus.i00000001.n0000", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User cclw@cpr.org is not authorised to CREATE an INGEST" diff --git a/tests/integration_tests/ingest/test_ingest_template.py b/tests/integration_tests/ingest/test_ingest_template.py index 3fafcf66..8fc6f0b4 100644 --- a/tests/integration_tests/ingest/test_ingest_template.py +++ b/tests/integration_tests/ingest/test_ingest_template.py @@ -3,11 +3,11 @@ from sqlalchemy.orm import Session -def test_get_template_unfcc( - data_db: Session, client: TestClient, non_cclw_user_header_token +def test_get_template_unfccc( + data_db: Session, client: TestClient, superuser_header_token ): response = client.get( - "/api/v1/ingest/template/Intl. agreements", headers=non_cclw_user_header_token + "/api/v1/ingest/template/Intl. agreements", headers=superuser_header_token ) assert response.status_code == status.HTTP_200_OK @@ -218,3 +218,34 @@ def test_get_template_unfcc( } ], } + + +def test_get_template_when_not_authorised(client: TestClient, data_db: Session): + response = client.get( + "/api/v1/ingest/template/Intl. agreements", + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_get_template_admin_non_super( + data_db: Session, client: TestClient, admin_user_header_token +): + response = client.get( + "/api/v1/ingest/template/Intl. agreements", + headers=admin_user_header_token, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User admin@cpr.org is not authorised to READ an INGEST" + + +def test_get_template_non_admin_non_super( + data_db: Session, client: TestClient, user_header_token +): + response = client.get( + "/api/v1/ingest/template/Intl. agreements", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User cclw@cpr.org is not authorised to READ an INGEST" diff --git a/tests/integration_tests/setup_db.py b/tests/integration_tests/setup_db.py index 4ab31561..9c895951 100644 --- a/tests/integration_tests/setup_db.py +++ b/tests/integration_tests/setup_db.py @@ -398,6 +398,22 @@ def _setup_organisation(test_db: Session) -> tuple[int, int]: cclw.id, hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", ) + _add_app_user( + test_db, + "non-admin-super@cpr.org", + "Super", + cclw.id, + hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", + is_super=True, + ) + _add_app_user( + test_db, + "admin@cpr.org", + "Admin", + cclw.id, + hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", + is_admin=True, + ) _add_app_user( test_db, "super@cpr.org", diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 41a81785..16ccd375 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -252,20 +252,27 @@ def non_admin_superuser_header_token() -> Dict[str, str]: @pytest.fixture def user_header_token() -> Dict[str, str]: - a_token = token_service.encode("cclw@cpr.org", ORG_ID, False, {"is_admin": True}) + a_token = token_service.encode("cclw@cpr.org", ORG_ID, False, {"is_admin": False}) + headers = {"Authorization": f"Bearer {a_token}"} + return headers + + +@pytest.fixture +def admin_user_header_token() -> Dict[str, str]: + a_token = token_service.encode("admin@cpr.org", ORG_ID, False, {"is_admin": True}) headers = {"Authorization": f"Bearer {a_token}"} return headers @pytest.fixture def non_cclw_user_header_token() -> Dict[str, str]: - a_token = token_service.encode("unfccc@cpr.org", ORG_ID, False, {"is_admin": True}) + a_token = token_service.encode("unfccc@cpr.org", ORG_ID, False, {"is_admin": False}) headers = {"Authorization": f"Bearer {a_token}"} return headers @pytest.fixture -def non_admin_user_header_token() -> Dict[str, str]: +def invalid_user_header_token() -> Dict[str, str]: a_token = token_service.encode("non-admin@cpr.org", ORG_ID, False, {}) headers = {"Authorization": f"Bearer {a_token}"} return headers diff --git a/tests/unit_tests/routers/ingest/test_bulk_ingest.py b/tests/unit_tests/routers/ingest/test_bulk_ingest.py index fc294c58..e36d04f9 100644 --- a/tests/unit_tests/routers/ingest/test_bulk_ingest.py +++ b/tests/unit_tests/routers/ingest/test_bulk_ingest.py @@ -13,8 +13,8 @@ from fastapi import status from fastapi.testclient import TestClient -from app.api.api_v1.routers.ingest import validate_entity_relationships from app.errors import ValidationError +from app.service.ingest import validate_entity_relationships def test_ingest_when_not_authenticated(client: TestClient): @@ -24,7 +24,17 @@ def test_ingest_when_not_authenticated(client: TestClient): assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_ingest_data_when_ok(client: TestClient, user_header_token): +def test_ingest_when_non_admin_non_super(client: TestClient, user_header_token): + response = client.post("/api/v1/ingest/test", headers=user_header_token) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_ingest_when_admin_non_super(client: TestClient, admin_user_header_token): + response = client.post("/api/v1/ingest/test", headers=admin_user_header_token) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_ingest_data_when_ok(client: TestClient, superuser_header_token): corpus_import_id = "test" with patch("fastapi.BackgroundTasks.add_task") as background_task_mock: @@ -42,7 +52,7 @@ def test_ingest_data_when_ok(client: TestClient, user_header_token): "rb", ) }, - headers=user_header_token, + headers=superuser_header_token, ) background_task_mock.assert_called_once() @@ -55,7 +65,7 @@ def test_ingest_data_when_ok(client: TestClient, user_header_token): def test_ingest_when_no_data( client: TestClient, - user_header_token, + superuser_header_token, collection_repo_mock, corpus_service_mock, basic_s3_client, @@ -65,7 +75,7 @@ def test_ingest_when_no_data( response = client.post( "/api/v1/ingest/test", files={"new_data": test_data_file}, - headers=user_header_token, + headers=superuser_header_token, ) assert collection_repo_mock.create.call_count == 0 @@ -73,7 +83,7 @@ def test_ingest_when_no_data( assert response.status_code == status.HTTP_204_NO_CONTENT -def test_ingest_documents_when_no_family(client: TestClient, user_header_token): +def test_ingest_documents_when_no_family(client: TestClient, superuser_header_token): fam_import_id = "test.new.family.0" test_data = json.dumps( { @@ -87,7 +97,7 @@ def test_ingest_documents_when_no_family(client: TestClient, user_header_token): response = client.post( "/api/v1/ingest/test", files={"new_data": test_data_file}, - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/tests/unit_tests/routers/ingest/test_get_ingest_template.py b/tests/unit_tests/routers/ingest/test_get_ingest_template.py index 5bb13774..ea3b7394 100644 --- a/tests/unit_tests/routers/ingest/test_get_ingest_template.py +++ b/tests/unit_tests/routers/ingest/test_get_ingest_template.py @@ -15,12 +15,30 @@ def test_ingest_template_when_not_authenticated(client: TestClient): assert response.status_code == status.HTTP_401_UNAUTHORIZED +def test_ingest_template_when_non_admin_non_super( + client: TestClient, user_header_token +): + response = client.get( + "/api/v1/ingest/template/test_corpus_type", headers=user_header_token + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + +def test_ingest_template_when_admin_non_super( + client: TestClient, admin_user_header_token +): + response = client.get( + "/api/v1/ingest/template/test_corpus_type", headers=admin_user_header_token + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + + def test_ingest_template_when_ok( - client: TestClient, user_header_token, db_client_corpus_helpers_mock + client: TestClient, superuser_header_token, db_client_corpus_helpers_mock ): response = client.get( "/api/v1/ingest/template/test_corpus_type", - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK