diff --git a/app/api/api_v1/routers/corpus.py b/app/api/api_v1/routers/corpus.py index f3a2e640..fb9d97d1 100644 --- a/app/api/api_v1/routers/corpus.py +++ b/app/api/api_v1/routers/corpus.py @@ -8,7 +8,7 @@ validate_query_params, ) from app.errors import AuthorisationError, RepositoryError, ValidationError -from app.model.corpus import CorpusReadDTO, CorpusWriteDTO +from app.model.corpus import CorpusCreateDTO, CorpusReadDTO, CorpusWriteDTO from app.service import corpus as corpus_service corpora_router = r = APIRouter() @@ -139,3 +139,27 @@ async def update_corpus( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=detail) return corpus + + +@r.post("/corpora", response_model=str, status_code=status.HTTP_201_CREATED) +async def create_corpus(request: Request, new_corpus: CorpusCreateDTO) -> str: + """Create a specific corpus given the import id. + + :param Request request: Request object. + :param CorpusCreateDTO new_corpus: New corpus data object. + :raises HTTPException: If there is an error raised during validation + or during adding the corpus to the db. + :return str: returns the import id of the newly created corpus. + """ + try: + corpus_id = corpus_service.create(new_corpus, request.state.user) + except AuthorisationError as e: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.message) + except ValidationError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message) + except RepositoryError as e: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message + ) + + return corpus_id diff --git a/app/model/authorisation.py b/app/model/authorisation.py index 7c2d14f0..0db7a1d4 100644 --- a/app/model/authorisation.py +++ b/app/model/authorisation.py @@ -68,8 +68,8 @@ class AuthEndpoint(str, enum.Enum): }, # Corpus AuthEndpoint.CORPUS: { - AuthOperation.CREATE: AuthAccess.ADMIN, - AuthOperation.READ: AuthAccess.ADMIN, - AuthOperation.UPDATE: AuthAccess.ADMIN, + AuthOperation.CREATE: AuthAccess.SUPER, + AuthOperation.READ: AuthAccess.SUPER, + AuthOperation.UPDATE: AuthAccess.SUPER, }, } diff --git a/app/model/corpus.py b/app/model/corpus.py index 0632c586..6c02a3dd 100644 --- a/app/model/corpus.py +++ b/app/model/corpus.py @@ -31,5 +31,15 @@ class CorpusWriteDTO(BaseModel): corpus_text: Optional[str] corpus_image_url: Optional[str] - corpus_type_name: str corpus_type_description: str + + +class CorpusCreateDTO(BaseModel): + """Representation of a Corpus.""" + + title: str + description: str + corpus_text: Optional[str] + corpus_image_url: Optional[str] + organisation_id: int + corpus_type_name: str diff --git a/app/repository/corpus.py b/app/repository/corpus.py index a259a470..20e2d616 100644 --- a/app/repository/corpus.py +++ b/app/repository/corpus.py @@ -2,6 +2,7 @@ from typing import Optional, Union, cast from db_client.models.organisation import Corpus, CorpusType, Organisation +from db_client.models.organisation.counters import CountedEntity from sqlalchemy import and_, asc, or_ from sqlalchemy import update as db_update from sqlalchemy.exc import NoResultFound, OperationalError @@ -9,7 +10,8 @@ from sqlalchemy_utils import escape_like from app.errors import RepositoryError -from app.model.corpus import CorpusReadDTO, CorpusWriteDTO +from app.model.corpus import CorpusCreateDTO, CorpusReadDTO, CorpusWriteDTO +from app.repository.helpers import generate_import_id _LOGGER = logging.getLogger(__name__) @@ -64,6 +66,19 @@ def get_corpus_org_id(db: Session, corpus_id: str) -> Optional[int]: return db.query(Corpus.organisation_id).filter_by(import_id=corpus_id).scalar() +def is_corpus_type_name_valid(db: Session, corpus_type_name: str) -> bool: + """Check whether a corpus type name exists in the DB. + + :param Session db: The DB session to connect to. + :param str corpus_type_name: The corpus type name we want to search + for. + :return bool: Whether the given corpus type exists in the db. + """ + return bool( + db.query(CorpusType.name).filter_by(name=corpus_type_name).scalar() is not None + ) + + def verify_corpus_exists(db: Session, corpus_id: str) -> bool: """Validate whether a corpus with the given ID exists in the DB. @@ -72,6 +87,7 @@ def verify_corpus_exists(db: Session, corpus_id: str) -> bool: :return bool: Return whether or not the corpus exists in the DB. """ corpora = [corpus[0] for corpus in db.query(Corpus.import_id).distinct().all()] + print(corpora) return bool(corpus_id in corpora) @@ -189,7 +205,6 @@ def update(db: Session, import_id: str, corpus: CorpusWriteDTO) -> bool: return False # Check what has changed. - ct_name_has_changed = original_corpus_type.name != new_values["corpus_type_name"] ct_description_has_changed = ( original_corpus_type.name != new_values["corpus_type_description"] ) @@ -202,7 +217,6 @@ def update(db: Session, import_id: str, corpus: CorpusWriteDTO) -> bool: if not any( [ - ct_name_has_changed, ct_description_has_changed, title_has_changed, description_has_changed, @@ -216,12 +230,11 @@ def update(db: Session, import_id: str, corpus: CorpusWriteDTO) -> bool: commands = [] # Update logic to only perform update if not idempotent. - if ct_name_has_changed or ct_description_has_changed: + if ct_description_has_changed: commands.append( db_update(CorpusType) .where(CorpusType.name == original_corpus_type.name) .values( - name=new_values["corpus_type_name"], description=new_values["corpus_type_description"], ) ) @@ -258,3 +271,30 @@ def update(db: Session, import_id: str, corpus: CorpusWriteDTO) -> bool: raise RepositoryError(msg) return True + + +def create(db: Session, corpus: CorpusCreateDTO) -> str: + """Create a new corpus. + + :param db Session: the database connection + :param CorpusCreateDTO corpus: the values for the new corpus + :return str: The ID of the created corpus. + """ + try: + import_id = generate_import_id(db, CountedEntity.Corpus, corpus.organisation_id) + new_corpus = Corpus( + import_id=import_id, + title=corpus.title, + description=corpus.description or "TBD", + corpus_text=corpus.corpus_text, + corpus_image_url=corpus.corpus_image_url, + organisation_id=corpus.organisation_id, + corpus_type_name=corpus.corpus_type_name, + ) + db.add(new_corpus) + db.flush() + except Exception as e: + _LOGGER.exception("Error trying to create Corpus") + raise RepositoryError(e) + + return cast(str, new_corpus.import_id) diff --git a/app/service/corpus.py b/app/service/corpus.py index db001456..52ec3add 100644 --- a/app/service/corpus.py +++ b/app/service/corpus.py @@ -7,8 +7,9 @@ import app.clients.db.session as db_session import app.repository.corpus as corpus_repo +import app.repository.organisation as org_repo from app.errors import RepositoryError, ValidationError -from app.model.corpus import CorpusReadDTO, CorpusWriteDTO +from app.model.corpus import CorpusCreateDTO, CorpusReadDTO, CorpusWriteDTO from app.model.user import UserContext from app.service import app_user, id @@ -149,9 +150,6 @@ def update( if original_corpus is None: return None - entity_org_id: int = get_corpus_org_id(import_id, db) - app_user.raise_if_unauthorised_to_make_changes(user, entity_org_id, import_id) - try: if corpus_repo.update(db, import_id, corpus): db.commit() @@ -161,3 +159,41 @@ def update( db.rollback() raise e return get(import_id) + + +@db_session.with_database() +@validate_call(config=ConfigDict(arbitrary_types_allowed=True)) +def create( + corpus: CorpusCreateDTO, + user: UserContext, + db: Optional[Session] = None, +) -> str: + """Create a new Corpus from the values passed. + + :param CorpusCreateDTO corpus: The values for the new Family. + :param UserContext user: The current user context. + :raises RepositoryError: raised on a database error + :raises ValidationError: raised should the import_id be invalid. + :return str: The new created Corpus or None if unsuccessful. + """ + if db is None: + db = db_session.get_db() + + # Check the corpus type name exists in the database already. + if not corpus_repo.is_corpus_type_name_valid(db, corpus.corpus_type_name): + raise ValidationError("Invalid corpus type name") + + # Check that the organisation ID exists in the database. + if org_repo.get_name_from_id(db, corpus.organisation_id) is None: + raise ValidationError("Invalid organisation") + + try: + import_id = corpus_repo.create(db, corpus) + if len(import_id) == 0: + db.rollback() + return import_id + except Exception as e: + db.rollback() + raise e + finally: + db.commit() diff --git a/poetry.lock b/poetry.lock index ea1f104f..9572437f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "alembic" @@ -170,17 +170,17 @@ uvloop = ["uvloop (>=0.15.2)"] [[package]] name = "boto3" -version = "1.35.59" +version = "1.35.63" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.59-py3-none-any.whl", hash = "sha256:8f8ff97cb9cb2e1ec7374209d0c09c1926b75604d6464c34bafaffd6d6cf0529"}, - {file = "boto3-1.35.59.tar.gz", hash = "sha256:81f4d8d6eff3e26b82cabd42eda816cfac9482821fdef353f18d2ba2f6e75f2d"}, + {file = "boto3-1.35.63-py3-none-any.whl", hash = "sha256:d0f938d4f6f392b6ffc5e75fff14a42e5bbb5228675a0367c8af55398abadbec"}, + {file = "boto3-1.35.63.tar.gz", hash = "sha256:deb593d9a0fb240deb4c43e4da8e6626d7c36be7b2fd2fe28f49d44d395b7de0"}, ] [package.dependencies] -botocore = ">=1.35.59,<1.36.0" +botocore = ">=1.35.63,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -189,13 +189,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.59" +version = "1.35.63" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.59-py3-none-any.whl", hash = "sha256:bcd66d7f55c8d1b6020eb86f2d87893fe591fb4be6a7d2a689c18be586452334"}, - {file = "botocore-1.35.59.tar.gz", hash = "sha256:de0ce655fedfc02c87869dfaa3b622488a17ff37da316ef8106cbe1573b83c98"}, + {file = "botocore-1.35.63-py3-none-any.whl", hash = "sha256:0ca1200694a4c0a3fa846795d8e8a08404c214e21195eb9e010c4b8a4ca78a4a"}, + {file = "botocore-1.35.63.tar.gz", hash = "sha256:2b8196bab0a997d206c3d490b52e779ef47dffb68c57c685443f77293aca1589"}, ] [package.dependencies] @@ -569,7 +569,7 @@ test-randomorder = ["pytest-randomly"] [[package]] name = "db-client" -version = "3.8.25" +version = "3.8.26" description = "All things to do with the datamodel and its storage. Including alembic migrations and datamodel code." optional = false python-versions = "^3.9" @@ -589,8 +589,8 @@ SQLAlchemy-Utils = "^0.38.2" [package.source] type = "git" url = "https://github.com/climatepolicyradar/navigator-db-client.git" -reference = "v3.8.25" -resolved_reference = "654ce6ab5e76e6cfeb71d0a57376bc084d14c694" +reference = "v3.8.26" +resolved_reference = "500b4f81bb3eeee4111a80e2e33b1e6068eb28fe" [[package]] name = "dill" @@ -754,13 +754,13 @@ fastapi = ">=0.63.0" [[package]] name = "fastapi-pagination" -version = "0.12.31" +version = "0.12.32" description = "FastAPI pagination" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "fastapi_pagination-0.12.31-py3-none-any.whl", hash = "sha256:a985d1f1baca1c42e7bfa51e9c7d6433e30bdca2ae236c3f3fc01bd9e62dbda6"}, - {file = "fastapi_pagination-0.12.31.tar.gz", hash = "sha256:224d6dc2671f95f1d5e467e42898809438570cd662e1008c7b6b91889211d780"}, + {file = "fastapi_pagination-0.12.32-py3-none-any.whl", hash = "sha256:38e7e72abf252cbebbc1beff9081e4929762756c04959c471b2a5866bb7f0aaf"}, + {file = "fastapi_pagination-0.12.32.tar.gz", hash = "sha256:b808b5b8af493c51d96ae0091b60532b25688cbca1350f39cb72f10d4d69a6ab"}, ] [package.dependencies] @@ -1901,13 +1901,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pyjwt" -version = "2.9.0" +version = "2.10.0" description = "JSON Web Token implementation in Python" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, - {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, + {file = "PyJWT-2.10.0-py3-none-any.whl", hash = "sha256:543b77207db656de204372350926bed5a86201c4cbff159f623f79c7bb487a15"}, + {file = "pyjwt-2.10.0.tar.gz", hash = "sha256:7628a7eb7938959ac1b26e819a1df0fd3259505627b575e4bad6d08f76db695c"}, ] [package.extras] @@ -2334,29 +2334,29 @@ tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asy [[package]] name = "ruff" -version = "0.7.3" +version = "0.7.4" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.7.3-py3-none-linux_armv6l.whl", hash = "sha256:34f2339dc22687ec7e7002792d1f50712bf84a13d5152e75712ac08be565d344"}, - {file = "ruff-0.7.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:fb397332a1879b9764a3455a0bb1087bda876c2db8aca3a3cbb67b3dbce8cda0"}, - {file = "ruff-0.7.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:37d0b619546103274e7f62643d14e1adcbccb242efda4e4bdb9544d7764782e9"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d59f0c3ee4d1a6787614e7135b72e21024875266101142a09a61439cb6e38a5"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:44eb93c2499a169d49fafd07bc62ac89b1bc800b197e50ff4633aed212569299"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6d0242ce53f3a576c35ee32d907475a8d569944c0407f91d207c8af5be5dae4e"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6b6224af8b5e09772c2ecb8dc9f3f344c1aa48201c7f07e7315367f6dd90ac29"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c50f95a82b94421c964fae4c27c0242890a20fe67d203d127e84fbb8013855f5"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f3eff9961b5d2644bcf1616c606e93baa2d6b349e8aa8b035f654df252c8c67"}, - {file = "ruff-0.7.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8963cab06d130c4df2fd52c84e9f10d297826d2e8169ae0c798b6221be1d1d2"}, - {file = "ruff-0.7.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:61b46049d6edc0e4317fb14b33bd693245281a3007288b68a3f5b74a22a0746d"}, - {file = "ruff-0.7.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:10ebce7696afe4644e8c1a23b3cf8c0f2193a310c18387c06e583ae9ef284de2"}, - {file = "ruff-0.7.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3f36d56326b3aef8eeee150b700e519880d1aab92f471eefdef656fd57492aa2"}, - {file = "ruff-0.7.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5d024301109a0007b78d57ab0ba190087b43dce852e552734ebf0b0b85e4fb16"}, - {file = "ruff-0.7.3-py3-none-win32.whl", hash = "sha256:4ba81a5f0c5478aa61674c5a2194de8b02652f17addf8dfc40c8937e6e7d79fc"}, - {file = "ruff-0.7.3-py3-none-win_amd64.whl", hash = "sha256:588a9ff2fecf01025ed065fe28809cd5a53b43505f48b69a1ac7707b1b7e4088"}, - {file = "ruff-0.7.3-py3-none-win_arm64.whl", hash = "sha256:1713e2c5545863cdbfe2cbce21f69ffaf37b813bfd1fb3b90dc9a6f1963f5a8c"}, - {file = "ruff-0.7.3.tar.gz", hash = "sha256:e1d1ba2e40b6e71a61b063354d04be669ab0d39c352461f3d789cac68b54a313"}, + {file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"}, + {file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"}, + {file = "ruff-0.7.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:63a569b36bc66fbadec5beaa539dd81e0527cb258b94e29e0531ce41bacc1f20"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d06218747d361d06fd2fdac734e7fa92df36df93035db3dc2ad7aa9852cb109"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0cea28d0944f74ebc33e9f934238f15c758841f9f5edd180b5315c203293452"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80094ecd4793c68b2571b128f91754d60f692d64bc0d7272ec9197fdd09bf9ea"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:997512325c6620d1c4c2b15db49ef59543ef9cd0f4aa8065ec2ae5103cedc7e7"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00b4cf3a6b5fad6d1a66e7574d78956bbd09abfd6c8a997798f01f5da3d46a05"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7dbdc7d8274e1422722933d1edddfdc65b4336abf0b16dfcb9dedd6e6a517d06"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e92dfb5f00eaedb1501b2f906ccabfd67b2355bdf117fea9719fc99ac2145bc"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3bd726099f277d735dc38900b6a8d6cf070f80828877941983a57bca1cd92172"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2e32829c429dd081ee5ba39aef436603e5b22335c3d3fff013cd585806a6486a"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:662a63b4971807623f6f90c1fb664613f67cc182dc4d991471c23c541fee62dd"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:876f5e09eaae3eb76814c1d3b68879891d6fde4824c015d48e7a7da4cf066a3a"}, + {file = "ruff-0.7.4-py3-none-win32.whl", hash = "sha256:75c53f54904be42dd52a548728a5b572344b50d9b2873d13a3f8c5e3b91f5cac"}, + {file = "ruff-0.7.4-py3-none-win_amd64.whl", hash = "sha256:745775c7b39f914238ed1f1b0bebed0b9155a17cd8bc0b08d3c87e4703b990d6"}, + {file = "ruff-0.7.4-py3-none-win_arm64.whl", hash = "sha256:11bff065102c3ae9d3ea4dc9ecdfe5a5171349cdd0787c1fc64761212fc9cf1f"}, + {file = "ruff-0.7.4.tar.gz", hash = "sha256:cd12e35031f5af6b9b93715d8c4f40360070b2041f81273d0527683d5708fce2"}, ] [[package]] @@ -3046,4 +3046,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "025e56816bd1ef6e016bfb3d2f6a4b32be305281c2ca3c4530d8cdda297d3dc2" +content-hash = "0fa8d1bb12d2d1b8e3ae9582afee475cf8e095957153d24cad49c1b29d539a7d" diff --git a/pyproject.toml b/pyproject.toml index 6ffd8a79..d591d95a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "admin_backend" -version = "2.17.13" +version = "2.17.14" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] @@ -29,7 +29,7 @@ boto3 = "^1.34.151" moto = "^4.2.2" types-sqlalchemy = "^1.4.53.38" urllib3 = "^1.26.17" -db-client = { git = "https://github.com/climatepolicyradar/navigator-db-client.git", tag = "v3.8.25" } +db-client = { git = "https://github.com/climatepolicyradar/navigator-db-client.git", rev = "v3.8.26" } navigator-notify = { git = "https://github.com/climatepolicyradar/navigator-notify.git", tag = "v0.0.2-beta" } bcrypt = "4.0.1" diff --git a/tests/helpers/corpus.py b/tests/helpers/corpus.py new file mode 100644 index 00000000..b50fceeb --- /dev/null +++ b/tests/helpers/corpus.py @@ -0,0 +1,37 @@ +from typing import Optional + +from app.model.corpus import CorpusCreateDTO, CorpusWriteDTO + + +def create_corpus_write_dto( + title: str = "title", + description: str = "description", + corpus_text: Optional[str] = "corpus_text", + image_url: Optional[str] = "some-picture.png", + corpus_type_description: str = "some description", +) -> CorpusWriteDTO: + return CorpusWriteDTO( + title=title, + description=description, + corpus_text=corpus_text, + corpus_image_url=image_url, + corpus_type_description=corpus_type_description, + ) + + +def create_corpus_create_dto( + corpus_type: str, + title: str = "title", + description: str = "description", + corpus_text: Optional[str] = "corpus_text", + image_url: Optional[str] = "some-picture.png", + org_id: int = 1, +) -> CorpusCreateDTO: + return CorpusCreateDTO( + title=title, + description=description, + corpus_text=corpus_text, + corpus_image_url=image_url, + organisation_id=org_id, + corpus_type_name=corpus_type, + ) diff --git a/tests/integration_tests/collection/test_create.py b/tests/integration_tests/collection/test_create.py index f98c1b2c..97192909 100644 --- a/tests/integration_tests/collection/test_create.py +++ b/tests/integration_tests/collection/test_create.py @@ -20,7 +20,7 @@ def test_create_collection(client: TestClient, data_db: Session, user_header_tok ) assert response.status_code == status.HTTP_201_CREATED data = response.json() - assert data == "CCLW.collection.i00000001.n0000" + assert data == "CCLW.collection.i00000002.n0000" actual_collection = ( data_db.query(Collection).filter(Collection.import_id == data).one() ) diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 11fbb642..f904d31a 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -16,11 +16,18 @@ import app.service.token as token_service from app.config import SQLALCHEMY_DATABASE_URI from app.main import app -from app.repository import collection_repo, document_repo, event_repo, family_repo +from app.repository import ( + collection_repo, + corpus_repo, + document_repo, + event_repo, + family_repo, +) from tests.mocks.repos.bad_collection_repo import ( mock_bad_collection_repo, mock_collection_count_none, ) +from tests.mocks.repos.bad_corpus_repo import mock_bad_corpus_repo from tests.mocks.repos.bad_document_repo import ( mock_bad_document_repo, mock_document_count_none, @@ -31,6 +38,7 @@ mock_family_count_none, ) from tests.mocks.repos.rollback_collection_repo import mock_rollback_collection_repo +from tests.mocks.repos.rollback_corpus_repo import mock_rollback_corpus_repo from tests.mocks.repos.rollback_document_repo import mock_rollback_document_repo from tests.mocks.repos.rollback_event_repo import mock_rollback_event_repo from tests.mocks.repos.rollback_family_repo import mock_rollback_family_repo @@ -142,6 +150,13 @@ def bad_event_repo(monkeypatch, mocker): yield event_repo +@pytest.fixture +def bad_corpus_repo(monkeypatch, mocker): + """Mocks the repository for a single test.""" + mock_bad_corpus_repo(corpus_repo, monkeypatch, mocker) + yield corpus_repo + + @pytest.fixture def collection_count_none(monkeypatch, mocker): """Mocks the service for a single test.""" @@ -198,6 +213,13 @@ def rollback_event_repo(monkeypatch, mocker): yield event_repo +@pytest.fixture +def rollback_corpus_repo(monkeypatch, mocker): + """Mocks the repository for a single test.""" + mock_rollback_corpus_repo(corpus_repo, monkeypatch, mocker) + yield corpus_repo + + @pytest.fixture def superuser_header_token() -> Dict[str, str]: a_token = token_service.encode( diff --git a/tests/integration_tests/corpus/test_create.py b/tests/integration_tests/corpus/test_create.py new file mode 100644 index 00000000..92c0fc78 --- /dev/null +++ b/tests/integration_tests/corpus/test_create.py @@ -0,0 +1,211 @@ +from db_client.models.organisation.corpus import Corpus +from fastapi import status +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from tests.helpers.corpus import create_corpus_create_dto +from tests.integration_tests.setup_db import setup_db + + +def test_create_corpus(client: TestClient, data_db: Session, superuser_header_token): + setup_db(data_db) + new_corpus = create_corpus_create_dto("Laws and Policies") + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_201_CREATED + + data = response.json() + assert data == "CCLW.corpus.i00000002.n0000" + actual_corpus = data_db.query(Corpus).filter(Corpus.import_id == data).one() + + assert actual_corpus.import_id == "CCLW.corpus.i00000002.n0000" + assert actual_corpus.title == "title" + assert actual_corpus.description == "description" + assert actual_corpus.corpus_text == "corpus_text" + assert actual_corpus.corpus_type_name == "Laws and Policies" + assert actual_corpus.corpus_image_url == "some-picture.png" + assert actual_corpus.organisation_id == 1 + + ct: int = ( + data_db.query(Corpus) + .filter(Corpus.corpus_type_name == "Laws and Policies") + .count() + ) + assert ct > 1 + + +def test_create_corpus_allows_none_corpus_text( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_create_dto("Laws and Policies", corpus_text=None) + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_201_CREATED + + data = response.json() + assert data == "CCLW.corpus.i00000002.n0000" + actual_corpus = data_db.query(Corpus).filter(Corpus.import_id == data).one() + + assert actual_corpus.import_id == "CCLW.corpus.i00000002.n0000" + assert actual_corpus.title == "title" + assert actual_corpus.description == "description" + assert actual_corpus.corpus_text is None + assert actual_corpus.corpus_type_name == "Laws and Policies" + assert actual_corpus.corpus_image_url == "some-picture.png" + assert actual_corpus.organisation_id == 1 + + ct: int = ( + data_db.query(Corpus) + .filter(Corpus.corpus_type_name == "Laws and Policies") + .count() + ) + assert ct > 1 + + +def test_create_corpus_allows_none_corpus_image_url( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_create_dto("Laws and Policies", image_url=None) + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_201_CREATED + + data = response.json() + assert data == "CCLW.corpus.i00000002.n0000" + actual_corpus = data_db.query(Corpus).filter(Corpus.import_id == data).one() + + assert actual_corpus.import_id == "CCLW.corpus.i00000002.n0000" + assert actual_corpus.title == "title" + assert actual_corpus.description == "description" + assert actual_corpus.corpus_text == "corpus_text" + assert actual_corpus.corpus_type_name == "Laws and Policies" + assert actual_corpus.corpus_image_url is None + assert actual_corpus.organisation_id == 1 + + ct: int = ( + data_db.query(Corpus) + .filter(Corpus.corpus_type_name == "Laws and Policies") + .count() + ) + assert ct > 1 + + +def test_create_corpus_when_corpus_type_not_exist( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_create_dto( + "some corpus type", title="test corpus", description="test test test" + ) + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + data = response.json() + assert data["detail"] == "Invalid corpus type name" + + actual_corpus = ( + data_db.query(Corpus) + .filter(Corpus.corpus_type_name == "some corpus type") + .one_or_none() + ) + assert actual_corpus is None + + +def test_create_corpus_when_org_id_not_exist( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_create_dto("Laws and Policies", org_id=100) + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + data = response.json() + assert data["detail"] == "Invalid organisation" + + actual_corpus = ( + data_db.query(Corpus).filter(Corpus.organisation_id == 100).one_or_none() + ) + assert actual_corpus is None + + +def test_create_corpus_when_not_authenticated(client: TestClient, data_db: Session): + setup_db(data_db) + new_corpus = create_corpus_create_dto("some-corpus-type") + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + ) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_create_corpus_non_admin_non_super(client: TestClient, user_header_token): + new_corpus = create_corpus_create_dto("some-corpus-type") + response = client.post( + "/api/v1/corpora", json=new_corpus.model_dump(), headers=user_header_token + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User cclw@cpr.org is not authorised to CREATE a CORPORA" + + +def test_create_corpus_admin_non_super(client: TestClient, admin_user_header_token): + new_corpus = create_corpus_create_dto("some-corpus-type") + response = client.post( + "/api/v1/corpora", json=new_corpus.model_dump(), headers=admin_user_header_token + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User admin@cpr.org is not authorised to CREATE a CORPORA" + + +def test_create_corpus_rollback( + client: TestClient, data_db: Session, rollback_corpus_repo, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_create_dto("Laws and Policies") + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + actual_corpus = ( + data_db.query(Corpus).filter(Corpus.import_id == "A.0.0.9").one_or_none() + ) + assert actual_corpus is None + assert rollback_corpus_repo.create.call_count == 1 + + +def test_create_corpus_when_db_error( + client: TestClient, data_db: Session, bad_corpus_repo, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_create_dto("Laws and Policies") + response = client.post( + "/api/v1/corpora", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + data = response.json() + assert data["detail"] == "Bad Repo" + assert bad_corpus_repo.create.call_count == 1 diff --git a/tests/integration_tests/corpus/test_update.py b/tests/integration_tests/corpus/test_update.py new file mode 100644 index 00000000..d16e885f --- /dev/null +++ b/tests/integration_tests/corpus/test_update.py @@ -0,0 +1,269 @@ +from typing import cast + +from db_client.models.organisation.corpus import Corpus, CorpusType +from fastapi import status +from fastapi.testclient import TestClient +from sqlalchemy.orm import Session + +from tests.helpers.corpus import create_corpus_write_dto +from tests.integration_tests.setup_db import setup_db + + +def test_update_corpus(client: TestClient, data_db: Session, superuser_header_token): + setup_db(data_db) + old_ct = ( + data_db.query(CorpusType).filter(CorpusType.name == "Laws and Policies").one() + ) + new_corpus = create_corpus_write_dto() + response = client.put( + "/api/v1/corpora/CCLW.corpus.i00000001.n0000", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["title"] == "title" + assert data["description"] == "description" + assert data["corpus_text"] == "corpus_text" + assert data["corpus_image_url"] == "some-picture.png" + assert data["organisation_id"] == 1 + assert data["organisation_name"] == "CCLW" + assert data["corpus_type_name"] == old_ct.name + assert data["corpus_type_description"] == "some description" + assert isinstance(data["metadata"], dict) + assert data["metadata"] == old_ct.valid_metadata + + db_corpus: Corpus = ( + data_db.query(Corpus) + .filter(Corpus.import_id == "CCLW.corpus.i00000001.n0000") + .one() + ) + assert db_corpus.import_id == "CCLW.corpus.i00000001.n0000" + assert db_corpus.title == "title" + assert db_corpus.description == "description" + assert db_corpus.corpus_text == "corpus_text" + assert db_corpus.corpus_type_name == "Laws and Policies" + assert db_corpus.corpus_image_url == "some-picture.png" + assert db_corpus.organisation_id == 1 + + ct: CorpusType = ( + data_db.query(CorpusType).filter(CorpusType.name == "Laws and Policies").one() + ) + assert ct.name == old_ct.name + assert ct.description == "some description" + assert ct.valid_metadata == old_ct.valid_metadata + + +def test_update_corpus_allows_none_corpus_image_url( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + old_ct = ( + data_db.query(CorpusType).filter(CorpusType.name == "Laws and Policies").one() + ) + new_corpus = create_corpus_write_dto(image_url=None) + response = client.put( + "/api/v1/corpora/CCLW.corpus.i00000001.n0000", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["title"] == "title" + assert data["description"] == "description" + assert data["corpus_text"] == "corpus_text" + assert data["corpus_image_url"] is None + assert data["organisation_id"] == 1 + assert data["organisation_name"] == "CCLW" + assert data["corpus_type_name"] == old_ct.name + assert data["corpus_type_description"] == "some description" + assert isinstance(data["metadata"], dict) + assert data["metadata"] == old_ct.valid_metadata + + db_corpus: Corpus = ( + data_db.query(Corpus) + .filter(Corpus.import_id == "CCLW.corpus.i00000001.n0000") + .one() + ) + assert db_corpus.import_id == "CCLW.corpus.i00000001.n0000" + assert db_corpus.title == "title" + assert db_corpus.description == "description" + assert db_corpus.corpus_text == "corpus_text" + assert db_corpus.corpus_type_name == "Laws and Policies" + assert db_corpus.corpus_image_url is None + assert db_corpus.organisation_id == 1 + + ct: CorpusType = ( + data_db.query(CorpusType).filter(CorpusType.name == "Laws and Policies").one() + ) + assert ct.name == old_ct.name + assert ct.description == "some description" + assert ct.valid_metadata == old_ct.valid_metadata + + +def test_update_corpus_allows_none_corpus_text( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + old_ct = ( + data_db.query(CorpusType).filter(CorpusType.name == "Laws and Policies").one() + ) + new_corpus = create_corpus_write_dto(corpus_text=None) + response = client.put( + "/api/v1/corpora/CCLW.corpus.i00000001.n0000", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert data["import_id"] == "CCLW.corpus.i00000001.n0000" + assert data["title"] == "title" + assert data["description"] == "description" + assert data["corpus_text"] is None + assert data["corpus_image_url"] == "some-picture.png" + assert data["organisation_id"] == 1 + assert data["organisation_name"] == "CCLW" + assert data["corpus_type_name"] == old_ct.name + assert data["corpus_type_description"] == "some description" + assert isinstance(data["metadata"], dict) + assert data["metadata"] == old_ct.valid_metadata + + db_corpus: Corpus = ( + data_db.query(Corpus) + .filter(Corpus.import_id == "CCLW.corpus.i00000001.n0000") + .one() + ) + assert db_corpus.import_id == "CCLW.corpus.i00000001.n0000" + assert db_corpus.title == "title" + assert db_corpus.description == "description" + assert db_corpus.corpus_text is None + assert db_corpus.corpus_type_name == "Laws and Policies" + assert db_corpus.corpus_image_url == "some-picture.png" + assert db_corpus.organisation_id == 1 + + ct: CorpusType = ( + data_db.query(CorpusType).filter(CorpusType.name == "Laws and Policies").one() + ) + assert ct.name == old_ct.name + assert ct.description == "some description" + assert ct.valid_metadata == old_ct.valid_metadata + + +def test_update_corpus_when_not_authorised(client: TestClient, data_db: Session): + setup_db(data_db) + new_corpus = create_corpus_write_dto() + response = client.put("/api/v1/corpora/C.0.0.2", json=new_corpus.model_dump()) + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +def test_update_corpus_non_super_non_admin(client: TestClient, user_header_token): + new_corpus = create_corpus_write_dto() + response = client.put( + "/api/v1/corpora/C.0.0.2", + json=new_corpus.model_dump(), + headers=user_header_token, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User cclw@cpr.org is not authorised to UPDATE a CORPORA" + + +def test_update_corpus_non_super_admin(client: TestClient, admin_user_header_token): + new_corpus = create_corpus_write_dto() + response = client.put( + "/api/v1/corpora/C.0.0.2", + json=new_corpus.model_dump(), + headers=admin_user_header_token, + ) + assert response.status_code == status.HTTP_403_FORBIDDEN + data = response.json() + assert data["detail"] == "User admin@cpr.org is not authorised to UPDATE a CORPORA" + + +def test_update_corpus_idempotent( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + old_corpus, old_ct = ( + data_db.query(Corpus, CorpusType) + .join(Corpus, Corpus.corpus_type_name == CorpusType.name) + .filter(CorpusType.name == "Laws and Policies") + .one() + ) + + new_corpus = create_corpus_write_dto( + title=old_corpus.title, + description=old_corpus.description, + corpus_text=old_corpus.corpus_text, + image_url=old_corpus.corpus_image_url, + corpus_type_description="Laws and policies", + ) + old_corpus = cast(Corpus, old_corpus) + response = client.put( + "/api/v1/corpora/CCLW.corpus.i00000001.n0000", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_200_OK + + data = response.json() + assert data["import_id"] == old_corpus.import_id + assert data["title"] == old_corpus.title + assert data["description"] == old_corpus.description + assert data["corpus_text"] == old_corpus.corpus_text + assert data["corpus_image_url"] == old_corpus.corpus_image_url + assert data["organisation_id"] == old_corpus.organisation_id + assert data["corpus_type_name"] == old_ct.name + assert data["corpus_type_description"] == old_ct.description + assert isinstance(data["metadata"], dict) + assert data["metadata"] == old_ct.valid_metadata + + db_corpus: Corpus = ( + data_db.query(Corpus) + .filter(Corpus.import_id == "CCLW.corpus.i00000001.n0000") + .one() + ) + assert db_corpus.import_id == old_corpus.import_id + assert db_corpus.title == old_corpus.title + assert db_corpus.description == old_corpus.description + assert db_corpus.corpus_text == old_corpus.corpus_text + assert db_corpus.corpus_image_url == old_corpus.corpus_image_url + assert db_corpus.organisation_id == old_corpus.organisation_id + assert db_corpus.corpus_type_name == old_corpus.corpus_type_name + + +def test_update_corpus_rollback_when_db_error( + client: TestClient, data_db: Session, rollback_corpus_repo, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_write_dto(title="Updated Title") + response = client.put( + "/api/v1/corpora/CCLW.corpus.i00000001.n0000", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_503_SERVICE_UNAVAILABLE + + db_corpus: Corpus = ( + data_db.query(Corpus) + .filter(Corpus.import_id == "CCLW.corpus.i00000001.n0000") + .one() + ) + assert db_corpus.title != "Updated Title" + assert db_corpus.description != "description" + assert rollback_corpus_repo.update.call_count == 1 + + +def test_update_corpus_when_not_found( + client: TestClient, data_db: Session, superuser_header_token +): + setup_db(data_db) + new_corpus = create_corpus_write_dto(title="Updated Title") + response = client.put( + "/api/v1/corpora/C.0.0.22", + json=new_corpus.model_dump(), + headers=superuser_header_token, + ) + assert response.status_code == status.HTTP_404_NOT_FOUND + data = response.json() + assert data["detail"] == "Corpus not updated: C.0.0.22" diff --git a/tests/integration_tests/family/test_create.py b/tests/integration_tests/family/test_create.py index 60d84f73..3ff785ad 100644 --- a/tests/integration_tests/family/test_create.py +++ b/tests/integration_tests/family/test_create.py @@ -29,7 +29,7 @@ def test_create_family(client: TestClient, data_db: Session, user_header_token): "/api/v1/families", json=new_family.model_dump(), headers=user_header_token ) assert response.status_code == status.HTTP_201_CREATED - expected_import_id = "CCLW.family.i00000001.n0000" + expected_import_id = "CCLW.family.i00000002.n0000" assert response.json() == expected_import_id actual_family = ( data_db.query(Family).filter(Family.import_id == expected_import_id).one() diff --git a/tests/integration_tests/setup_db.py b/tests/integration_tests/setup_db.py index 9c895951..ad112ee2 100644 --- a/tests/integration_tests/setup_db.py +++ b/tests/integration_tests/setup_db.py @@ -21,8 +21,9 @@ PhysicalDocument, PhysicalDocumentLanguage, ) -from db_client.models.organisation import Corpus, CorpusType +from db_client.models.organisation import Corpus, EntityCounter from db_client.models.organisation.users import AppUser, Organisation, OrganisationUser +from sqlalchemy import update from sqlalchemy.orm import Session EXPECTED_FAMILIES = [ @@ -302,6 +303,8 @@ def setup_test_data(test_db: Session, configure_empty: bool = False): _setup_event_data(test_db) test_db.commit() + setup_corpus(test_db) + def _add_app_user( test_db: Session, @@ -427,32 +430,15 @@ def _setup_organisation(test_db: Session) -> tuple[int, int]: def setup_corpus(test_db: Session) -> None: - org_id = _setup_organisation(test_db) - test_db.add( - Corpus( - import_id="1", - title="Test Title", - description="Test Description", - corpus_text="Test Text", - corpus_image_url="Test Image Url", - organisation_id=org_id, - corpus_type_name="Test Corpus", + test_db.execute( + update(EntityCounter).values( + counter=1, ) ) + test_db.commit() - test_db.add( - CorpusType( - name="Test Corpus", - description="Test Description", - valid_metadata={ - "test": { - "allow_any": "true", - "allow_blanks": "false", - "allowed_values": [], - } - }, - ) - ) + for item in test_db.query(EntityCounter.counter).all(): + assert item[0] == 1 def _setup_collection_data( diff --git a/tests/mocks/repos/bad_corpus_repo.py b/tests/mocks/repos/bad_corpus_repo.py new file mode 100644 index 00000000..a34c4718 --- /dev/null +++ b/tests/mocks/repos/bad_corpus_repo.py @@ -0,0 +1,38 @@ +from typing import Optional + +from pytest import MonkeyPatch + +from app.errors import RepositoryError +from app.model.corpus import CorpusCreateDTO, CorpusReadDTO, CorpusWriteDTO + + +def mock_bad_corpus_repo(repo, monkeypatch: MonkeyPatch, mocker): + def mock_get_all(_): + raise RepositoryError("Bad Repo") + + def mock_get(_, import_id: str) -> Optional[CorpusReadDTO]: + raise RepositoryError("Bad Repo") + + def mock_search(_, q: str, org_id: Optional[int]) -> list[CorpusReadDTO]: + raise RepositoryError("Bad Repo") + + def mock_update(_, import_id: str, data: CorpusWriteDTO) -> Optional[CorpusReadDTO]: + raise RepositoryError("Bad Repo") + + def mock_create(_, data: CorpusCreateDTO) -> str: + raise RepositoryError("Bad Repo") + + monkeypatch.setattr(repo, "get", mock_get) + mocker.spy(repo, "get") + + monkeypatch.setattr(repo, "all", mock_get_all) + mocker.spy(repo, "all") + + monkeypatch.setattr(repo, "search", mock_search) + mocker.spy(repo, "search") + + monkeypatch.setattr(repo, "update", mock_update) + mocker.spy(repo, "update") + + monkeypatch.setattr(repo, "create", mock_create) + mocker.spy(repo, "create") diff --git a/tests/mocks/repos/rollback_corpus_repo.py b/tests/mocks/repos/rollback_corpus_repo.py new file mode 100644 index 00000000..91519677 --- /dev/null +++ b/tests/mocks/repos/rollback_corpus_repo.py @@ -0,0 +1,27 @@ +from typing import Optional + +from pytest import MonkeyPatch +from sqlalchemy.exc import NoResultFound + +from app.model.corpus import CorpusCreateDTO, CorpusReadDTO, CorpusWriteDTO + + +def mock_rollback_corpus_repo(corpus_repo, monkeypatch: MonkeyPatch, mocker): + actual_update = corpus_repo.update + actual_create = corpus_repo.create + + def mock_update_corpus( + db, import_id: str, data: CorpusWriteDTO + ) -> Optional[CorpusReadDTO]: + actual_update(db, import_id, data) + raise NoResultFound() + + def mock_create_corpus(db, data: CorpusCreateDTO) -> str: + actual_create(db, data) + raise NoResultFound() + + monkeypatch.setattr(corpus_repo, "update", mock_update_corpus) + mocker.spy(corpus_repo, "update") + + monkeypatch.setattr(corpus_repo, "create", mock_create_corpus) + mocker.spy(corpus_repo, "create")