Skip to content

Commit

Permalink
Pdct 1110 only show entity count per org (#138)
Browse files Browse the repository at this point in the history
* Update service count unit tests to only count entities in user org

* Only show count of entities in user org unless superuser

* Move is_superuser out of repo layer

* Fix ordering of function calls to fix tests

* Only call restrict_entities_to_user_org once & call repo directly

* Remove incorrect param from docstring
  • Loading branch information
katybaulch authored Jun 3, 2024
1 parent 7103c9d commit 6c99026
Show file tree
Hide file tree
Showing 37 changed files with 262 additions and 324 deletions.
15 changes: 7 additions & 8 deletions app/api/api_v1/routers/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging

from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, HTTPException, Request, status

import app.service.analytics as analytics_service
from app.errors import RepositoryError
Expand All @@ -13,28 +13,27 @@
_LOGGER = logging.getLogger(__name__)


@r.get(
"/analytics/summary",
response_model=SummaryDTO,
)
async def get_analytics_summary() -> SummaryDTO:
@r.get("/analytics/summary", response_model=SummaryDTO)
async def get_analytics_summary(request: Request) -> SummaryDTO:
"""
Returns an analytics summary.
:return dict[str, int]: returns a dictionary of the summarised analytics
data in key (str): value (int) form.
"""
try:
summary_dto = analytics_service.summary()
summary_dto = analytics_service.summary(request.state.user.email)
except RepositoryError as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message
)

if any(summary_value is None for _, summary_value in summary_dto):
msg = "Analytics summary not found"
_LOGGER.error(msg)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Analytics summary not found",
detail=msg,
)

return summary_dto
23 changes: 5 additions & 18 deletions app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@
"/families/{import_id}",
response_model=FamilyReadDTO,
)
async def get_family(
import_id: str,
) -> FamilyReadDTO:
async def get_family(import_id: str) -> FamilyReadDTO:
"""
Returns a specific family given the import id.
Expand All @@ -57,10 +55,7 @@ async def get_family(
return family


@r.get(
"/families",
response_model=list[FamilyReadDTO],
)
@r.get("/families", response_model=list[FamilyReadDTO])
async def get_all_families(request: Request) -> list[FamilyReadDTO]:
"""
Returns all families
Expand All @@ -75,10 +70,7 @@ async def get_all_families(request: Request) -> list[FamilyReadDTO]:
)


@r.get(
"/families/",
response_model=list[FamilyReadDTO],
)
@r.get("/families/", response_model=list[FamilyReadDTO])
async def search_family(request: Request) -> list[FamilyReadDTO]:
"""
Searches for families matching URL parameters ("q" by default).
Expand Down Expand Up @@ -122,10 +114,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]:
return families


@r.put(
"/families/{import_id}",
response_model=FamilyReadDTO,
)
@r.put("/families/{import_id}", response_model=FamilyReadDTO)
async def update_family(
request: Request,
import_id: str,
Expand Down Expand Up @@ -179,9 +168,7 @@ async def create_family(
return family


@r.delete(
"/families/{import_id}",
)
@r.delete("/families/{import_id}")
async def delete_family(
import_id: str,
) -> None:
Expand Down
8 changes: 6 additions & 2 deletions app/repository/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,19 @@ def delete(db: Session, import_id: str) -> bool:
return result.rowcount > 0 # type: ignore


def count(db: Session) -> Optional[int]:
def count(db: Session, org_id: Optional[int]) -> Optional[int]:
"""
Counts the number of collections in the repository.
:param db Session: the database connection
:param org_id Optional[int]: the ID of the organisation the user belongs to
:return Optional[int]: The number of collections in the repository or none.
"""
try:
n_collections = _get_query(db).count()
if org_id is None:
n_collections = _get_query(db).count()
else:
n_collections = _get_query(db).filter(Organisation.id == org_id).count()
except Exception as e:
_LOGGER.error(e)
return
Expand Down
9 changes: 7 additions & 2 deletions app/repository/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PhysicalDocument,
PhysicalDocumentLanguage,
)
from db_client.models.organisation import Organisation
from db_client.models.organisation.counters import CountedEntity
from sqlalchemy import Column, and_
from sqlalchemy import delete as db_delete
Expand Down Expand Up @@ -434,15 +435,19 @@ def delete(db: Session, import_id: str) -> bool:
return True


def count(db: Session) -> Optional[int]:
def count(db: Session, org_id: Optional[int]) -> Optional[int]:
"""
Counts the number of documents in the repository.
:param db Session: the database connection
:param org_id Optional[int]: the ID of the organisation the user belongs to
:return Optional[int]: The number of documents in the repository or none.
"""
try:
n_documents = db.query(FamilyDocument).count()
if org_id is None:
n_documents = _get_query(db).count()
else:
n_documents = _get_query(db).filter(Organisation.id == org_id).count()
except NoResultFound as e:
_LOGGER.error(e)
return
Expand Down
9 changes: 7 additions & 2 deletions app/repository/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Tuple, Union, cast

from db_client.models.dfce import EventStatus, Family, FamilyDocument, FamilyEvent
from db_client.models.organisation import Organisation
from db_client.models.organisation.counters import CountedEntity
from sqlalchemy import Column, and_
from sqlalchemy import delete as db_delete
Expand Down Expand Up @@ -241,16 +242,20 @@ def delete(db: Session, import_id: str) -> bool:
return True


def count(db: Session) -> Optional[int]:
def count(db: Session, org_id: Optional[int]) -> Optional[int]:
"""
Counts the number of family events in the repository.
:param db Session: The database connection.
:param org_id Optional[int]: the ID of the organisation the user belongs to
:return Optional[int]: The number of family events in the repository
or nothing.
"""
try:
n_events = _get_query(db).count()
if org_id is None:
n_events = _get_query(db).count()
else:
n_events = _get_query(db).filter(Organisation.id == org_id).count()
except NoResultFound as e:
_LOGGER.error(e)
return
Expand Down
12 changes: 8 additions & 4 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ def _update_intention(
return update_title, update_basics, update_metadata, update_collections


def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]:
def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]:
"""
Returns all the families.
:param db Session: the database connection
:return Optional[FamilyResponse]: All of things
"""
if is_superuser:
if org_id is None:
family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all()
else:
family_geo_metas = (
Expand Down Expand Up @@ -499,15 +499,19 @@ def get_organisation(db: Session, family_import_id: str) -> Optional[Organisatio
)


def count(db: Session) -> Optional[int]:
def count(db: Session, org_id: Optional[int]) -> Optional[int]:
"""
Counts the number of families in the repository.
:param db Session: the database connection
:param org_id int: the ID of the organisation the user belongs to
:return Optional[int]: The number of families in the repository or none.
"""
try:
n_families = _get_query(db).count()
if org_id is None:
n_families = _get_query(db).count()
else:
n_families = _get_query(db).filter(Organisation.id == org_id).count()
except NoResultFound as e:
_LOGGER.error(e)
return
Expand Down
5 changes: 3 additions & 2 deletions app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ class FamilyRepo(Protocol):
return_empty: bool = False
throw_repository_error: bool = False
throw_timeout_error: bool = False
is_superuser: bool = False

@staticmethod
def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]:
def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]:
"""Returns all the families"""
...

Expand Down Expand Up @@ -48,6 +49,6 @@ def delete(db: Session, import_id: str) -> bool:
...

@staticmethod
def count(db: Session) -> Optional[int]:
def count(db: Session, org_id: Optional[int]) -> Optional[int]:
"""Counts all the families"""
...
37 changes: 21 additions & 16 deletions app/service/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,40 @@
from pydantic import ConfigDict, validate_call
from sqlalchemy import exc

import app.service.collection as collection_service
import app.service.document as document_service
import app.service.event as event_service
import app.service.family as family_service
import app.clients.db.session as db_session
import app.repository.collection as collection_repo
import app.repository.document as document_repo
import app.repository.event as event_repo
import app.repository.family as family_repo
import app.service.app_user as app_user_service
from app.errors import RepositoryError
from app.model.analytics import SummaryDTO

_LOGGER = logging.getLogger(__name__)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def summary() -> SummaryDTO:
def summary(user_email: str) -> SummaryDTO:
"""
Gets an analytics summary from the repository.
:param user_email str: The email address of the current user.
:return SummaryDTO: The analytics summary found.
"""
try:
n_collections = collection_service.count()
n_families = family_service.count()
n_documents = document_service.count()
n_events = event_service.count()

return SummaryDTO(
n_documents=n_documents,
n_families=n_families,
n_collections=n_collections,
n_events=n_events,
)
with db_session.get_db() as db:
org_id = app_user_service.restrict_entities_to_user_org(db, user_email)
n_collections = collection_repo.count(db, org_id)
n_families = family_repo.count(db, org_id)
n_documents = document_repo.count(db, org_id)
n_events = event_repo.count(db, org_id)

return SummaryDTO(
n_documents=n_documents,
n_families=n_families,
n_collections=n_collections,
n_events=n_events,
)

except exc.SQLAlchemyError as e:
_LOGGER.error(e)
Expand Down
10 changes: 10 additions & 0 deletions app/service/app_user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from sqlalchemy.orm import Session

from app.errors import ValidationError
Expand All @@ -15,3 +17,11 @@ def get_organisation(db: Session, user_email: str) -> int:
def is_superuser(db: Session, user_email: str) -> bool:
"""Determine a user's superuser status"""
return app_user_repo.is_superuser(db, user_email)


def restrict_entities_to_user_org(db: Session, user_email: str) -> Optional[int]:
org_id = get_organisation(db, user_email)
superuser: bool = is_superuser(db, user_email)
if superuser:
return None
return org_id
15 changes: 0 additions & 15 deletions app/service/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,6 @@ def delete(import_id: str, context=None, db: Session = db_session.get_db()) -> b
return collection_repo.delete(db, import_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def count() -> Optional[int]:
"""
Gets a count of collections from the repository.
:return Optional[int]: The number of collections in the repository or none.
"""
try:
with db_session.get_db() as db:
return collection_repo.count(db)
except exc.SQLAlchemyError as e:
_LOGGER.error(e)
raise RepositoryError(str(e))


def get_org_from_id(db: Session, collection_import_id: str) -> Optional[int]:
org = collection_repo.get_org_from_collection_id(db, collection_import_id)
if org is None:
Expand Down
15 changes: 0 additions & 15 deletions app/service/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,3 @@ def delete(import_id: str, context=None, db: Session = db_session.get_db()) -> b
if context is not None:
context.error = f"Could not delete document {import_id}"
return document_repo.delete(db, import_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def count() -> Optional[int]:
"""
Gets a count of documents from the repository.
:return Optional[int]: The number of documents in the repository or none.
"""
try:
with db_session.get_db() as db:
return document_repo.count(db)
except exc.SQLAlchemyError as e:
_LOGGER.error(e)
raise RepositoryError(str(e))
15 changes: 0 additions & 15 deletions app/service/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,3 @@ def delete(import_id: str, context=None, db: Session = db_session.get_db()) -> b
if context is not None:
context.error = f"Could not delete event {import_id}"
return event_repo.delete(db, import_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def count() -> Optional[int]:
"""
Gets a count of events from the repository.
:return Optional[int]: A count of events in the repository or none.
"""
try:
with db_session.get_db() as db:
return event_repo.count(db)
except exc.SQLAlchemyError as e:
_LOGGER.error(e)
raise RepositoryError(str(e))
Loading

0 comments on commit 6c99026

Please sign in to comment.