Skip to content

Commit

Permalink
Feature/pdct 1110 only show entities belonging to user orgcorpus (#135)
Browse files Browse the repository at this point in the history
* Only show families belonging to user org unless super

* Remove debug

* Placate pyright

* Bump to 2.6.1
  • Loading branch information
katybaulch authored May 21, 2024
1 parent 7959e55 commit ad1a161
Show file tree
Hide file tree
Showing 12 changed files with 132 additions and 27 deletions.
9 changes: 7 additions & 2 deletions app/api/api_v1/routers/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,18 @@ async def get_family(
"/families",
response_model=list[FamilyReadDTO],
)
async def get_all_families() -> list[FamilyReadDTO]:
async def get_all_families(request: Request) -> list[FamilyReadDTO]:
"""
Returns all families
:return FamilyDTO: returns a FamilyDTO of the family found.
"""
return family_service.all()
try:
return family_service.all(request.state.user.email)
except RepositoryError as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message
)


@r.get(
Expand Down
16 changes: 16 additions & 0 deletions app/repository/app_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,22 @@ def get_user_by_email(db: Session, email: str) -> MaybeAppUser:
return db.query(AppUser).filter(AppUser.email == email).one()


def is_superuser(db: Session, email: str) -> bool:
"""Check whether user with email address is a superuser.
:param db Session: DB session to connect use.
:param email str: User email.
:return bool: Whether the user is a superuser or not.
"""
return (
db.query(AppUser)
.filter(AppUser.email == email)
.filter(AppUser.is_superuser == True) # noqa: E712
.count()
> 0
)


def is_active(db: Session, email: str) -> bool:
# NOTE: DO NOT be tempted to fix the below to "is True" - this breaks things
return (
Expand Down
12 changes: 10 additions & 2 deletions app/repository/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,22 @@ def _update_intention(
return update_title, update_basics, update_metadata, update_collections


def all(db: Session) -> list[FamilyReadDTO]:
def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]:
"""
Returns all the families.
:param db Session: the database connection
:return Optional[FamilyResponse]: All of things
"""
family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all()
if is_superuser:
family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all()
else:
family_geo_metas = (
_get_query(db)
.filter(Organisation.id == org_id)
.order_by(desc(Family.last_modified))
.all()
)

if not family_geo_metas:
return []
Expand Down
2 changes: 1 addition & 1 deletion app/repository/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class FamilyRepo(Protocol):
throw_timeout_error: bool = False

@staticmethod
def all(db: Session) -> list[FamilyReadDTO]:
def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]:
"""Returns all the families"""
...

Expand Down
5 changes: 5 additions & 0 deletions app/service/app_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ def get_organisation(db: Session, user_email: str) -> int:
if org_id is None:
raise ValidationError(f"Could not get the organisation for user {user_email}")
return org_id


def is_superuser(db: Session, user_email: str) -> bool:
"""Determine a user's superuser status"""
return app_user_repo.is_superuser(db, user_email)
6 changes: 4 additions & 2 deletions app/service/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@ def get(import_id: str) -> Optional[FamilyReadDTO]:


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def all() -> list[FamilyReadDTO]:
def all(user_email: str) -> list[FamilyReadDTO]:
"""
Gets the entire list of families from the repository.
:return list[FamilyDTO]: The list of families.
"""
with db_session.get_db() as db:
return family_repo.all(db)
org_id = app_user.get_organisation(db, user_email)
is_superuser: bool = app_user.is_superuser(db, user_email)
return family_repo.all(db, org_id, is_superuser)


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "admin_backend"
version = "2.6.0"
version = "2.6.1"
description = ""
authors = ["CPR-dev-team <[email protected]>"]
packages = [{ include = "app" }, { include = "tests" }]
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_get_config_unfccc_corpora_correct(
# Now sanity check the new corpora data
unfccc_corporas = data["corpora"]

assert unfccc_corporas[0]["corpus_import_id"] == "UNFCCC.corpus.1.0"
assert unfccc_corporas[0]["corpus_import_id"] == "UNFCCC.corpus.i00000001.n0000"
assert unfccc_corporas[0]["corpus_type"] == "Intl. agreements"
assert unfccc_corporas[0]["corpus_type_description"] == "Intl. agreements"
assert unfccc_corporas[0]["description"] == "UNFCCC Submissions"
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def rollback_event_repo(monkeypatch, mocker):

@pytest.fixture
def superuser_header_token() -> Dict[str, str]:
a_token = token_service.encode("test@cpr.org", True, {})
a_token = token_service.encode("super@cpr.org", True, {})
headers = {"Authorization": f"Bearer {a_token}"}
return headers

Expand All @@ -232,6 +232,6 @@ def non_cclw_user_header_token() -> Dict[str, str]:

@pytest.fixture
def admin_user_header_token() -> Dict[str, str]:
a_token = token_service.encode("test@cpr.org", False, {"is_admin": True})
a_token = token_service.encode("admin@cpr.org", False, {"is_admin": True})
headers = {"Authorization": f"Bearer {a_token}"}
return headers
65 changes: 63 additions & 2 deletions tests/integration_tests/family/test_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
# --- GET ALL


def test_get_all_families(client: TestClient, data_db: Session, user_header_token):
def test_get_all_families_superuser(
client: TestClient, data_db: Session, superuser_header_token
):
setup_db(data_db)
response = client.get(
"/api/v1/families",
headers=user_header_token,
headers=superuser_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
Expand All @@ -37,6 +39,65 @@ def test_get_all_families(client: TestClient, data_db: Session, user_header_toke
assert response_data[2] == EXPECTED_FAMILIES[2]


def test_get_all_families_cclw(client: TestClient, data_db: Session, user_header_token):
setup_db(data_db)
response = client.get(
"/api/v1/families",
headers=user_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) == 2
ids_found = set([f["import_id"] for f in data])
expected_ids = set(["A.0.0.1", "A.0.0.2"])

assert ids_found.symmetric_difference(expected_ids) == set([])

assert all(field in fam for fam in data for field in ("created", "last_modified"))
sorted_data = sorted(data, key=lambda d: d["import_id"])
response_data = [
{
k: v if not isinstance(v, list) else sorted(v)
for k, v in fam.items()
if k not in ("created", "last_modified")
}
for fam in sorted_data
]
assert response_data[0] == EXPECTED_FAMILIES[0]
assert response_data[1] == EXPECTED_FAMILIES[1]


def test_get_all_families_unfccc(
client: TestClient, data_db: Session, non_cclw_user_header_token
):
setup_db(data_db)
response = client.get(
"/api/v1/families",
headers=non_cclw_user_header_token,
)
assert response.status_code == status.HTTP_200_OK
data = response.json()
assert isinstance(data, list)
assert len(data) == 1
ids_found = set([f["import_id"] for f in data])
expected_ids = set(["A.0.0.3"])

assert ids_found.symmetric_difference(expected_ids) == set([])

assert all(field in fam for fam in data for field in ("created", "last_modified"))
sorted_data = sorted(data, key=lambda d: d["import_id"])
response_data = [
{
k: v if not isinstance(v, list) else sorted(v)
for k, v in fam.items()
if k not in ("created", "last_modified")
}
for fam in sorted_data
]
assert response_data[0] == EXPECTED_FAMILIES[2]


def test_get_all_families_when_not_authenticated(client: TestClient, data_db: Session):
setup_db(data_db)
response = client.get(
Expand Down
34 changes: 21 additions & 13 deletions tests/integration_tests/setup_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,10 @@
"category": "UNFCCC",
"status": "Created",
"metadata": {"size": [100], "color": ["blue"]},
"organisation": "CCLW",
"corpus_import_id": "CCLW.corpus.i00000001.n0000",
"corpus_title": "CCLW national policies",
"corpus_type": "Laws and Policies",
"organisation": "UNFCCC",
"corpus_import_id": "UNFCCC.corpus.i00000001.n0000",
"corpus_title": "UNFCCC Submissions",
"corpus_type": "Intl. agreements",
"slug": "Slug3",
"events": ["E.0.0.3"],
"published_date": "2018-12-24T04:59:33Z",
Expand Down Expand Up @@ -246,7 +246,8 @@ def _get_org_id_from_name(test_db: Session, name: str) -> int:

def _setup_organisation(test_db: Session) -> tuple[int, int]:
# Now an organisation
org = test_db.query(Organisation).filter(Organisation.name == "CCLW").one()
cclw = test_db.query(Organisation).filter(Organisation.name == "CCLW").one()
unfccc = test_db.query(Organisation).filter(Organisation.name == "UNFCCC").one()

another_org = Organisation(
name="Another org",
Expand All @@ -261,52 +262,59 @@ def _setup_organisation(test_db: Session) -> tuple[int, int]:
test_db,
"[email protected]",
"CCLWTestUser",
org.id,
cclw.id,
"$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC",
)
_add_app_user(
test_db,
"[email protected]",
"NonCCLWTestUser",
"UNFCCCTestUser",
unfccc.id,
"$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC",
)
_add_app_user(
test_db,
"[email protected]",
"AnotherTestUser",
another_org.id,
"$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC",
)
_add_app_user(
test_db,
"[email protected]",
"TestInactive",
org.id,
cclw.id,
hashed_pass="$2b$12$q.UbWEdeibUuApI2QDbmQeG5WmAPfNmooG1cAoCWjyJXvgiAVVdlK",
is_active=False,
)
_add_app_user(
test_db, "[email protected]", "TestHashedPassEmpty", org.id, hashed_pass=""
test_db, "[email protected]", "TestHashedPassEmpty", cclw.id, hashed_pass=""
)
_add_app_user(
test_db,
"[email protected]",
"TestPassMismatch",
org.id,
cclw.id,
hashed_pass="$2b$12$WZq1rRMvU.Tv1VutLw.rju/Ez5ETkYqP3KufdcSFJm3GTRZP8E52C",
)
_add_app_user(
test_db,
"[email protected]",
"Admin",
org.id,
cclw.id,
hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC",
is_admin=True,
)
_add_app_user(
test_db,
"[email protected]",
"Super",
org.id,
cclw.id,
hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC",
is_super=True,
)

return cast(int, org.id), cast(int, another_org.id)
return cast(int, cclw.id), cast(int, another_org.id)


def _setup_collection_data(
Expand Down
2 changes: 1 addition & 1 deletion tests/mocks/services/family_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def maybe_timeout():
if family_service.throw_timeout_error:
raise TimeoutError

def mock_get_all_families():
def mock_get_all_families(_):
return [create_family_read_dto("test", collections=["x.y.z.1", "x.y.z.2"])]

def mock_get_family(import_id: str) -> Optional[FamilyReadDTO]:
Expand Down

0 comments on commit ad1a161

Please sign in to comment.