Skip to content

Commit

Permalink
Feature/pdct 1055 frontend create family enable the user to select th…
Browse files Browse the repository at this point in the history
…e corpus (#129)

* Enable the user to select the corpus on create

* Add docstrings and add function to validate corpus exists in DB.

* Update existing family tests and helpers to include corpus_import_id field

* Added test to check family create fails when org mismatch between user and corpus

* Raise auth error instead of validation error when org mismatch & fix service family layer

* Add new tests for validating corpus on family create

* Generalise error text and check error text in test

* Add default value for family service update context

* Bump package to 2.5.1

* Bump to 2.5.0
  • Loading branch information
katybaulch authored May 2, 2024
1 parent f2ceb5f commit a33f00e
Show file tree
Hide file tree
Showing 19 changed files with 319 additions and 30 deletions.
4 changes: 3 additions & 1 deletion app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
set_default_query_params,
validate_query_params,
)
from app.errors import RepositoryError, ValidationError
from app.errors import AuthorisationError, RepositoryError, ValidationError
from app.model.family import FamilyCreateDTO, FamilyReadDTO, FamilyWriteDTO

families_router = r = APIRouter()
Expand Down Expand Up @@ -162,6 +162,8 @@ async def create_family(
"""
try:
family = family_service.create(new_family, request.state.user.email)
except AuthorisationError as e:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=e.message)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
3 changes: 2 additions & 1 deletion app/model/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class FamilyReadDTO(BaseModel):
documents: list[str]
collections: list[str]
organisation: str
corpus_id: str
corpus_import_id: str
corpus_title: str
corpus_type: str
created: datetime
Expand Down Expand Up @@ -63,3 +63,4 @@ class FamilyCreateDTO(BaseModel):
category: str
metadata: Json
collections: list[str]
corpus_import_id: str
17 changes: 16 additions & 1 deletion app/repository/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,27 @@
import app.repository.app_user as app_user_repo
import app.repository.collection as collection_repo
import app.repository.config as config_repo
import app.repository.corpus as corpus_repo
import app.repository.document as document_repo
import app.repository.event as event_repo
import app.repository.family as family_repo
import app.repository.family as family_repo # type: ignore
import app.repository.geography as geography_repo
import app.repository.metadata as metadata_repo
import app.repository.organisation as organisation_repo
from app.repository.protocols import FamilyRepo

family_repo: FamilyRepo

__all__ = (
"s3bucket_repo",
"app_user_repo",
"collection_repo",
"config_repo",
"corpus_repo",
"document_repo",
"event_repo",
"family_repo",
"geography_repo",
"metadata_repo",
"organisation_repo",
)
32 changes: 32 additions & 0 deletions app/repository/corpus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging
from typing import Optional

from db_client.models.organisation.corpus import Corpus
from sqlalchemy.orm import Session

_LOGGER = logging.getLogger(__name__)


def get_corpus_org_id(db: Session, corpus_id: str) -> Optional[int]:
"""Get the organisation ID a corpus belongs to.
TODO: Will need to review as part of PDCT-1011.
:param Session db: The DB session to connect to.
:param str corpus_id: The corpus import ID we want to get the org
for.
:return Optional[int]: Return the organisation ID the given corpus
belongs to or None.
"""
return db.query(Corpus.organisation_id).filter_by(import_id=corpus_id).scalar()


def validate(db: Session, corpus_id: str) -> bool:
"""Validate whether a corpus with the given ID exists in the DB.
:param Session db: The DB session to connect to.
:param str corpus_id: The corpus import ID we want to validate.
:return bool: Return whether or not the corpus exists in the DB.
"""
corpora = [corpus[0] for corpus in db.query(Corpus.import_id).distinct().all()]
return bool(corpus_id in corpora)
7 changes: 3 additions & 4 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _family_to_dto(
)
],
organisation=org,
corpus_id=cast(str, corpus.import_id),
corpus_import_id=cast(str, corpus.import_id),
corpus_title=cast(str, corpus.title),
corpus_type=cast(str, corpus.corpus_type_name),
created=cast(datetime, fam.created),
Expand Down Expand Up @@ -343,12 +343,11 @@ def create(db: Session, family: FamilyCreateDTO, geo_id: int, org_id: int) -> st
)
db.add(new_family)

# New schema.
new_fam_corpus = db.query(Corpus).filter(Corpus.organisation_id == org_id).one()
# Add corpus - family link.
db.add(
FamilyCorpus(
family_import_id=new_family.import_id,
corpus_import_id=new_fam_corpus.import_id,
corpus_import_id=family.corpus_import_id,
)
)

Expand Down
11 changes: 11 additions & 0 deletions app/service/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ def validate_import_id(import_id: str) -> None:
id.validate(import_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def validate_multiple_ids(import_ids: set[str]) -> None:
"""
Validates a set of import ids for a collection.
:param set[str] import_ids: A set of import ids to check.
:raises ValidationError: raised if any of the import_ids are invalid.
"""
id.validate_multiple_ids(import_ids)


@db_session.with_transaction(__name__)
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def update(
Expand Down
47 changes: 47 additions & 0 deletions app/service/corpus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from typing import Optional

from sqlalchemy.orm import Session

from app.errors import RepositoryError, ValidationError
from app.repository import corpus_repo

_LOGGER = logging.getLogger(__name__)


def get_corpus_org_id(db: Session, corpus_import_id: str) -> Optional[int]:
"""Get the organisation ID(s) a corpus belongs to.
TODO: Will need to review as part of PDCT-1011.
:param Session db: The DB session to connect to.
:param str corpus_id: The corpus import ID we want to get the org
for.
:return Optional[int]: Return the organisation ID the given corpus
belongs to or None.
"""
org_id = corpus_repo.get_corpus_org_id(db, corpus_import_id)
return org_id


def validate(db: Session, corpus_import_id: str) -> bool:
"""Validate whether a corpus with the given ID exists in the DB.
:param Session db: The DB session to connect to.
:param str corpus_id: The corpus import ID we want to validate.
:raises ValidationError: When the corpus ID is not found in the DB.
:raises RepositoryError: When an error occurs.
:return bool: Return whether or not the corpus exists in the DB.
"""
try:
is_valid = corpus_repo.validate(db, corpus_import_id)
if is_valid:
return is_valid

except Exception as e:
_LOGGER.error(e)
raise RepositoryError(e)

msg = f"Corpus '{corpus_import_id}' not found"
_LOGGER.error(msg)
raise ValidationError(msg)
17 changes: 13 additions & 4 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from sqlalchemy.orm import Session

import app.clients.db.session as db_session
from app.errors import RepositoryError, ValidationError
from app.errors import AuthorisationError, RepositoryError, ValidationError
from app.model.family import FamilyCreateDTO, FamilyReadDTO, FamilyWriteDTO
from app.repository import family_repo
from app.service import (
app_user,
category,
collection,
corpus,
geography,
id,
metadata,
Expand Down Expand Up @@ -93,7 +94,7 @@ def update(
import_id: str,
user_email: str,
family_dto: FamilyWriteDTO,
context,
context=None,
db: Session = db_session.get_db(),
) -> Optional[FamilyReadDTO]:
"""
Expand Down Expand Up @@ -185,7 +186,7 @@ def create(

# Validate collection ids.
collections = set(family.collections)
id.validate_multiple_ids(collections)
collection.validate_multiple_ids(collections)

# Validate that the collections we want to update are from the same organisation as
# the current user.
Expand All @@ -195,7 +196,15 @@ def create(
if len(collections_not_in_user_org) > 0 and any(collections_not_in_user_org):
msg = "Organisation mismatch between some collections and the current user"
_LOGGER.error(msg)
raise ValidationError(msg)
raise AuthorisationError(msg)

# Validate that the corpus we want to add the new family to exists and is from the
# same organisation as the user.
corpus.validate(db, family.corpus_import_id)
if corpus.get_corpus_org_id(db, family.corpus_import_id) != org_id:
msg = "Organisation mismatch between selected corpus and the current user"
_LOGGER.error(msg)
raise AuthorisationError(msg)

return family_repo.create(db, family, geo_id, org_id)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "admin_backend"
version = "2.4.1"
version = "2.5.0"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
4 changes: 3 additions & 1 deletion tests/helpers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def create_family_read_dto(
documents=["doc1", "doc2"],
collections=collections,
organisation="CCLW",
corpus_id="CCLW.corpus.i00000001.n0000",
corpus_import_id="CCLW.corpus.i00000001.n0000",
corpus_title="CCLW national policies",
corpus_type="Laws and Policies",
created=datetime.now(),
Expand All @@ -52,6 +52,7 @@ def create_family_create_dto(
category: str = FamilyCategory.LEGISLATIVE.value,
metadata: Optional[dict] = None,
collections: Optional[list[str]] = None,
corpus_import_id: str = "CCLW.corpus.i00000001.n0000",
) -> FamilyCreateDTO:
if metadata is None:
metadata = {}
Expand All @@ -64,6 +65,7 @@ def create_family_create_dto(
category=category,
metadata=metadata,
collections=collections,
corpus_import_id=corpus_import_id,
)


Expand Down
31 changes: 25 additions & 6 deletions tests/integration_tests/family/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_create_family_when_invalid_geo(
response = client.post(
"/api/v1/families", json=new_family.model_dump(), headers=user_header_token
)
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = response.json()
assert data["detail"] == "The geography value UK is invalid!"

Expand All @@ -152,7 +152,7 @@ def test_create_family_when_invalid_category(
response = client.post(
"/api/v1/families", json=new_family.model_dump(), headers=user_header_token
)
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = response.json()
assert data["detail"] == "Invalid is not a valid FamilyCategory"

Expand All @@ -167,11 +167,10 @@ def test_create_family_when_invalid_collection_id(
metadata={"color": ["pink"], "size": [0]},
collections=["col1"],
)
# new_family.category = "invalid"
response = client.post(
"/api/v1/families", json=new_family.model_dump(), headers=user_header_token
)
assert response.status_code == 400
assert response.status_code == status.HTTP_400_BAD_REQUEST
data = response.json()
assert data["detail"] == "The import ids are invalid: ['col1']"

Expand All @@ -186,13 +185,33 @@ def test_create_family_when_invalid_collection_org(
metadata={"color": ["pink"], "size": [0]},
collections=["C.0.0.1"],
)
# new_family.category = "invalid"
response = client.post(
"/api/v1/families", json=new_family.model_dump(), headers=user_header_token
)
assert response.status_code == 400
assert response.status_code == status.HTTP_403_FORBIDDEN
data = response.json()
assert (
data["detail"]
== "Organisation mismatch between some collections and the current user"
)


def test_create_family_when_invalid_corpus_org(
client: TestClient, data_db: Session, user_header_token
):
setup_db(data_db)
new_family = create_family_create_dto(
title="Title",
summary="test test test",
metadata={"color": ["pink"], "size": [0]},
corpus_import_id="UNFCCC.corpus.i00000001.n0000",
)
response = client.post(
"/api/v1/families", json=new_family.model_dump(), headers=user_header_token
)
assert response.status_code == status.HTTP_403_FORBIDDEN
data = response.json()
assert (
data["detail"]
== "Organisation mismatch between selected corpus and the current user"
)
6 changes: 3 additions & 3 deletions tests/integration_tests/setup_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
"status": "Created",
"metadata": {"size": [3], "color": ["red"]},
"organisation": "CCLW",
"corpus_id": "CCLW.corpus.i00000001.n0000",
"corpus_import_id": "CCLW.corpus.i00000001.n0000",
"corpus_title": "CCLW national policies",
"corpus_type": "Laws and Policies",
"slug": "Slug1",
Expand All @@ -54,7 +54,7 @@
"status": "Created",
"metadata": {"size": [4], "color": ["green"]},
"organisation": "CCLW",
"corpus_id": "CCLW.corpus.i00000001.n0000",
"corpus_import_id": "CCLW.corpus.i00000001.n0000",
"corpus_title": "CCLW national policies",
"corpus_type": "Laws and Policies",
"slug": "Slug2",
Expand All @@ -73,7 +73,7 @@
"status": "Created",
"metadata": {"size": [100], "color": ["blue"]},
"organisation": "CCLW",
"corpus_id": "CCLW.corpus.i00000001.n0000",
"corpus_import_id": "CCLW.corpus.i00000001.n0000",
"corpus_title": "CCLW national policies",
"corpus_type": "Laws and Policies",
"slug": "Slug3",
Expand Down
Loading

0 comments on commit a33f00e

Please sign in to comment.