Skip to content

Commit

Permalink
Pdct 1110 update search to only search entities in user org (#139)
Browse files Browse the repository at this point in the history
* Add app user mocking for email address to search unit tests

* Update service count unit tests to only count entities in user org

* Only show count of entities in user org unless superuser

* Move is_superuser out of repo layer

* Fix ordering of function calls to fix tests

* Only show entities on search belonging to same org as user unless super

* Add filter by org logic to entity repo

* Add integration tests for super and non super search calls

* Bump to 2.6.4
  • Loading branch information
katybaulch authored Jun 3, 2024
1 parent 6c99026 commit ee51631
Show file tree
Hide file tree
Showing 34 changed files with 259 additions and 114 deletions.
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def search_collection(request: Request) -> list[CollectionReadDTO]:
validate_query_params(query_params, VALID_PARAMS)

try:
collections = collection_service.search(query_params)
collections = collection_service.search(query_params, request.state.user.email)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def search_document(request: Request) -> list[DocumentReadDTO]:
validate_query_params(query_params, VALID_PARAMS)

try:
documents = document_service.search(query_params)
documents = document_service.search(query_params, request.state.user.email)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def search_event(request: Request) -> list[EventReadDTO]:
validate_query_params(query_params, VALID_PARAMS)

try:
events_found = event_service.search(query_params)
events_found = event_service.search(query_params, request.state.user.email)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
2 changes: 1 addition & 1 deletion app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ async def search_family(request: Request) -> list[FamilyReadDTO]:
validate_query_params(query_params, VALID_PARAMS)

try:
families = family_service.search(query_params)
families = family_service.search(query_params, request.state.user.email)
except ValidationError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.message)
except RepositoryError as e:
Expand Down
18 changes: 10 additions & 8 deletions app/repository/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,15 @@ def get(db: Session, import_id: str) -> Optional[CollectionReadDTO]:


def search(
db: Session, query_params: dict[str, Union[str, int]]
db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int]
) -> list[CollectionReadDTO]:
"""
Gets a list of collections from the repo searching given fields.
:param db Session: the database connection
:param dict query_params: Any search terms to filter on specified
fields (title & summary by default if 'q' specified).
:param org_id Optional[int]: the ID of the organisation the user belongs to
:raises HTTPException: If a DB error occurs a 503 is returned.
:raises HTTPException: If the search request times out a 408 is
returned.
Expand All @@ -156,10 +157,11 @@ def search(

condition = and_(*search) if len(search) > 1 else search[0]
try:
query = _get_query(db).filter(condition)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
found = (
_get_query(db)
.filter(condition)
.order_by(desc(Collection.last_modified))
query.order_by(desc(Collection.last_modified))
.limit(query_params["max_results"])
.all()
)
Expand Down Expand Up @@ -268,10 +270,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]:
:return Optional[int]: The number of collections in the repository or none.
"""
try:
if org_id is None:
n_collections = _get_query(db).count()
else:
n_collections = _get_query(db).filter(Organisation.id == org_id).count()
query = _get_query(db)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
n_collections = query.count()
except Exception as e:
_LOGGER.error(e)
return
Expand Down
18 changes: 10 additions & 8 deletions app/repository/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,15 @@ def get(db: Session, import_id: str) -> Optional[DocumentReadDTO]:


def search(
db: Session, query_params: dict[str, Union[str, int]]
db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int]
) -> list[DocumentReadDTO]:
"""
Gets a list of documents from the repository searching the title.
:param db Session: the database connection
:param dict query_params: Any search terms to filter on specified
fields (title by default if 'q' specified).
:param org_id Optional[int]: the ID of the organisation the user belongs to
:raises HTTPException: If a DB error occurs a 503 is returned.
:raises HTTPException: If the search request times out a 408 is
returned.
Expand All @@ -183,9 +184,10 @@ def search(
condition = and_(*search) if len(search) > 1 else search[0]
try:
# TODO: Fix order by on search PDCT-672
result = (
_get_query(db).filter(condition).limit(query_params["max_results"]).all()
)
query = _get_query(db).filter(condition)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
result = query.limit(query_params["max_results"]).all()
except OperationalError as e:
if "canceling statement due to statement timeout" in str(e):
raise TimeoutError
Expand Down Expand Up @@ -444,10 +446,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]:
:return Optional[int]: The number of documents in the repository or none.
"""
try:
if org_id is None:
n_documents = _get_query(db).count()
else:
n_documents = _get_query(db).filter(Organisation.id == org_id).count()
query = _get_query(db)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
n_documents = query.count()
except NoResultFound as e:
_LOGGER.error(e)
return
Expand Down
20 changes: 12 additions & 8 deletions app/repository/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,16 @@ def get(db: Session, import_id: str) -> Optional[EventReadDTO]:
return _event_to_dto(family_event_meta)


def search(db: Session, query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]:
def search(
db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int]
) -> list[EventReadDTO]:
"""
Get family events matching a search term on the event title or type.
:param db Session: The database connection.
:param dict query_params: Any search terms to filter on specified
fields (title & event type name by default if 'q' specified).
:param org_id Optional[int]: the ID of the organisation the user belongs to
:raises HTTPException: If a DB error occurs a 503 is returned.
:raises HTTPException: If the search request times out a 408 is
returned.
Expand All @@ -132,9 +135,10 @@ def search(db: Session, query_params: dict[str, Union[str, int]]) -> list[EventR

condition = and_(*search) if len(search) > 1 else search[0]
try:
found = (
_get_query(db).filter(condition).limit(query_params["max_results"]).all()
)
query = _get_query(db).filter(condition)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
found = query.limit(query_params["max_results"]).all()
except OperationalError as e:
if "canceling statement due to statement timeout" in str(e):
raise TimeoutError
Expand Down Expand Up @@ -252,10 +256,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]:
or nothing.
"""
try:
if org_id is None:
n_events = _get_query(db).count()
else:
n_events = _get_query(db).filter(Organisation.id == org_id).count()
query = _get_query(db)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
n_events = query.count()
except NoResultFound as e:
_LOGGER.error(e)
return
Expand Down
32 changes: 15 additions & 17 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,13 @@ def all(db: Session, org_id: Optional[int]) -> list[FamilyReadDTO]:
Returns all the families.
:param db Session: the database connection
:param org_id int: the ID of the organisation the user belongs to
:return Optional[FamilyResponse]: All of things
"""
if org_id is None:
family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all()
else:
family_geo_metas = (
_get_query(db)
.filter(Organisation.id == org_id)
.order_by(desc(Family.last_modified))
.all()
)
query = _get_query(db)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
family_geo_metas = query.order_by(desc(Family.last_modified)).all()

if not family_geo_metas:
return []
Expand Down Expand Up @@ -159,14 +155,15 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]:


def search(
db: Session, query_params: dict[str, Union[str, int]]
db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int]
) -> list[FamilyReadDTO]:
"""
Gets a list of families from the repository searching given fields.
:param db Session: the database connection
:param dict query_params: Any search terms to filter on specified
fields (title & summary by default if 'q' specified).
:param org_id Optional[int]: the ID of the organisation the user belongs to
:raises HTTPException: If a DB error occurs a 503 is returned.
:raises HTTPException: If the search request times out a 408 is
returned.
Expand Down Expand Up @@ -200,10 +197,11 @@ def search(

condition = and_(*search) if len(search) > 1 else search[0]
try:
query = _get_query(db).filter(condition)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
found = (
_get_query(db)
.filter(condition)
.order_by(desc(Family.last_modified))
query.order_by(desc(Family.last_modified))
.limit(query_params["max_results"])
.all()
)
Expand Down Expand Up @@ -508,10 +506,10 @@ def count(db: Session, org_id: Optional[int]) -> Optional[int]:
:return Optional[int]: The number of families in the repository or none.
"""
try:
if org_id is None:
n_families = _get_query(db).count()
else:
n_families = _get_query(db).filter(Organisation.id == org_id).count()
query = _get_query(db)
if org_id is not None:
query = query.filter(Organisation.id == org_id)
n_families = query.count()
except NoResultFound as e:
_LOGGER.error(e)
return
Expand Down
2 changes: 1 addition & 1 deletion app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def get(db: Session, import_id: str) -> Optional[FamilyReadDTO]:

@staticmethod
def search(
db: Session, query_params: dict[str, Union[str, int]]
db: Session, query_params: dict[str, Union[str, int]], org_id: Optional[int]
) -> list[FamilyReadDTO]:
"""Searches the families"""
...
Expand Down
8 changes: 6 additions & 2 deletions app/service/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def all() -> list[CollectionReadDTO]:


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search(query_params: dict[str, Union[str, int]]) -> list[CollectionReadDTO]:
def search(
query_params: dict[str, Union[str, int]], user_email: str
) -> list[CollectionReadDTO]:
"""
Searches for the search term against collections on specified fields.
Expand All @@ -70,11 +72,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[CollectionReadDTO]:
:param dict query_params: Search patterns to match against specified
fields, given as key value pairs in a dictionary.
:param str user_email: The email address of the current user.
:return list[CollectionReadDTO]: The list of collections matching
the given search terms.
"""
with db_session.get_db() as db:
return collection_repo.search(db, query_params)
org_id = app_user.restrict_entities_to_user_org(db, user_email)
return collection_repo.search(db, query_params, org_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
10 changes: 7 additions & 3 deletions app/service/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from app.clients.aws.client import get_s3_client
from app.errors import RepositoryError, ValidationError
from app.model.document import DocumentCreateDTO, DocumentReadDTO, DocumentWriteDTO
from app.service import id
from app.service import app_user, id

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,7 +55,9 @@ def all() -> list[DocumentReadDTO]:


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search(query_params: dict[str, Union[str, int]]) -> list[DocumentReadDTO]:
def search(
query_params: dict[str, Union[str, int]], user_email: str
) -> list[DocumentReadDTO]:
"""
Searches for the search term against documents on specified fields.
Expand All @@ -64,11 +66,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[DocumentReadDTO]:
:param dict query_params: Search patterns to match against specified
fields, given as key value pairs in a dictionary.
:param str user_email: The email address of the current user.
:return list[DocumentReadDTO]: The list of documents matching the
given search terms.
"""
with db_session.get_db() as db:
return document_repo.search(db, query_params)
org_id = app_user.restrict_entities_to_user_org(db, user_email)
return document_repo.search(db, query_params, org_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
10 changes: 7 additions & 3 deletions app/service/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import app.service.family as family_service
from app.errors import RepositoryError, ValidationError
from app.model.event import EventCreateDTO, EventReadDTO, EventWriteDTO
from app.service import id
from app.service import app_user, id

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,7 +46,9 @@ def all() -> list[EventReadDTO]:


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search(query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]:
def search(
query_params: dict[str, Union[str, int]], user_email: str
) -> list[EventReadDTO]:
"""
Searches for the search term against events on specified fields.
Expand All @@ -56,11 +58,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[EventReadDTO]:
:param dict query_params: Search patterns to match against specified
fields, given as key value pairs in a dictionary.
:param str user_email: The email address of the current user.
:return list[EventReadDTO]: The list of events matching the given
search terms.
"""
with db_session.get_db() as db:
return event_repo.search(db, query_params)
org_id = app_user.restrict_entities_to_user_org(db, user_email)
return event_repo.search(db, query_params, org_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
8 changes: 6 additions & 2 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def all(user_email: str) -> list[FamilyReadDTO]:


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def search(query_params: dict[str, Union[str, int]]) -> list[FamilyReadDTO]:
def search(
query_params: dict[str, Union[str, int]], user_email: str
) -> list[FamilyReadDTO]:
"""
Searches for the search term against families on specified fields.
Expand All @@ -71,11 +73,13 @@ def search(query_params: dict[str, Union[str, int]]) -> list[FamilyReadDTO]:
:param dict query_params: Search patterns to match against specified
fields, given as key value pairs in a dictionary.
:param str user_email: The email address of the current user.
:return list[FamilyDTO]: The list of families matching the given
search terms.
"""
with db_session.get_db() as db:
return family_repo.search(db, query_params)
org_id = app_user.restrict_entities_to_user_org(db, user_email)
return family_repo.search(db, query_params, org_id)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "admin_backend"
version = "2.6.3"
version = "2.6.4"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
Loading

0 comments on commit ee51631

Please sign in to comment.