diff --git a/app/api/api_v1/routers/family.py b/app/api/api_v1/routers/family.py index ab5134f8..2644c8bc 100644 --- a/app/api/api_v1/routers/family.py +++ b/app/api/api_v1/routers/family.py @@ -61,13 +61,18 @@ async def get_family( "/families", response_model=list[FamilyReadDTO], ) -async def get_all_families() -> list[FamilyReadDTO]: +async def get_all_families(request: Request) -> list[FamilyReadDTO]: """ Returns all families :return FamilyDTO: returns a FamilyDTO of the family found. """ - return family_service.all() + try: + return family_service.all(request.state.user.email) + except RepositoryError as e: + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=e.message + ) @r.get( diff --git a/app/repository/app_user.py b/app/repository/app_user.py index d2e4f3f2..0d185a8f 100644 --- a/app/repository/app_user.py +++ b/app/repository/app_user.py @@ -10,6 +10,22 @@ def get_user_by_email(db: Session, email: str) -> MaybeAppUser: return db.query(AppUser).filter(AppUser.email == email).one() +def is_superuser(db: Session, email: str) -> bool: + """Check whether user with email address is a superuser. + + :param db Session: DB session to connect use. + :param email str: User email. + :return bool: Whether the user is a superuser or not. + """ + return ( + db.query(AppUser) + .filter(AppUser.email == email) + .filter(AppUser.is_superuser == True) # noqa: E712 + .count() + > 0 + ) + + def is_active(db: Session, email: str) -> bool: # NOTE: DO NOT be tempted to fix the below to "is True" - this breaks things return ( diff --git a/app/repository/family.py b/app/repository/family.py index fe396f36..00b8b673 100644 --- a/app/repository/family.py +++ b/app/repository/family.py @@ -116,14 +116,22 @@ def _update_intention( return update_title, update_basics, update_metadata, update_collections -def all(db: Session) -> list[FamilyReadDTO]: +def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]: """ Returns all the families. :param db Session: the database connection :return Optional[FamilyResponse]: All of things """ - family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all() + if is_superuser: + family_geo_metas = _get_query(db).order_by(desc(Family.last_modified)).all() + else: + family_geo_metas = ( + _get_query(db) + .filter(Organisation.id == org_id) + .order_by(desc(Family.last_modified)) + .all() + ) if not family_geo_metas: return [] diff --git a/app/repository/protocols.py b/app/repository/protocols.py index 0eef3f9d..904beb6e 100644 --- a/app/repository/protocols.py +++ b/app/repository/protocols.py @@ -14,7 +14,7 @@ class FamilyRepo(Protocol): throw_timeout_error: bool = False @staticmethod - def all(db: Session) -> list[FamilyReadDTO]: + def all(db: Session, org_id: int, is_superuser: bool) -> list[FamilyReadDTO]: """Returns all the families""" ... diff --git a/app/service/app_user.py b/app/service/app_user.py index c126ca49..0252e176 100644 --- a/app/service/app_user.py +++ b/app/service/app_user.py @@ -10,3 +10,8 @@ def get_organisation(db: Session, user_email: str) -> int: if org_id is None: raise ValidationError(f"Could not get the organisation for user {user_email}") return org_id + + +def is_superuser(db: Session, user_email: str) -> bool: + """Determine a user's superuser status""" + return app_user_repo.is_superuser(db, user_email) diff --git a/app/service/family.py b/app/service/family.py index 4449ce2b..e36480dd 100644 --- a/app/service/family.py +++ b/app/service/family.py @@ -49,14 +49,16 @@ def get(import_id: str) -> Optional[FamilyReadDTO]: @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) -def all() -> list[FamilyReadDTO]: +def all(user_email: str) -> list[FamilyReadDTO]: """ Gets the entire list of families from the repository. :return list[FamilyDTO]: The list of families. """ with db_session.get_db() as db: - return family_repo.all(db) + org_id = app_user.get_organisation(db, user_email) + is_superuser: bool = app_user.is_superuser(db, user_email) + return family_repo.all(db, org_id, is_superuser) @validate_call(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/pyproject.toml b/pyproject.toml index 5e6f3a1d..e9eef80a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "admin_backend" -version = "2.6.0" +version = "2.6.1" description = "" authors = ["CPR-dev-team "] packages = [{ include = "app" }, { include = "tests" }] diff --git a/tests/integration_tests/config/test_config.py b/tests/integration_tests/config/test_config.py index afce399f..476ae935 100644 --- a/tests/integration_tests/config/test_config.py +++ b/tests/integration_tests/config/test_config.py @@ -178,7 +178,7 @@ def test_get_config_unfccc_corpora_correct( # Now sanity check the new corpora data unfccc_corporas = data["corpora"] - assert unfccc_corporas[0]["corpus_import_id"] == "UNFCCC.corpus.1.0" + assert unfccc_corporas[0]["corpus_import_id"] == "UNFCCC.corpus.i00000001.n0000" assert unfccc_corporas[0]["corpus_type"] == "Intl. agreements" assert unfccc_corporas[0]["corpus_type_description"] == "Intl. agreements" assert unfccc_corporas[0]["description"] == "UNFCCC Submissions" diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 90dc9555..e24ad8a5 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -211,7 +211,7 @@ def rollback_event_repo(monkeypatch, mocker): @pytest.fixture def superuser_header_token() -> Dict[str, str]: - a_token = token_service.encode("test@cpr.org", True, {}) + a_token = token_service.encode("super@cpr.org", True, {}) headers = {"Authorization": f"Bearer {a_token}"} return headers @@ -232,6 +232,6 @@ def non_cclw_user_header_token() -> Dict[str, str]: @pytest.fixture def admin_user_header_token() -> Dict[str, str]: - a_token = token_service.encode("test@cpr.org", False, {"is_admin": True}) + a_token = token_service.encode("admin@cpr.org", False, {"is_admin": True}) headers = {"Authorization": f"Bearer {a_token}"} return headers diff --git a/tests/integration_tests/family/test_get.py b/tests/integration_tests/family/test_get.py index 6dcf4a14..166acca1 100644 --- a/tests/integration_tests/family/test_get.py +++ b/tests/integration_tests/family/test_get.py @@ -7,11 +7,13 @@ # --- GET ALL -def test_get_all_families(client: TestClient, data_db: Session, user_header_token): +def test_get_all_families_superuser( + client: TestClient, data_db: Session, superuser_header_token +): setup_db(data_db) response = client.get( "/api/v1/families", - headers=user_header_token, + headers=superuser_header_token, ) assert response.status_code == status.HTTP_200_OK data = response.json() @@ -37,6 +39,65 @@ def test_get_all_families(client: TestClient, data_db: Session, user_header_toke assert response_data[2] == EXPECTED_FAMILIES[2] +def test_get_all_families_cclw(client: TestClient, data_db: Session, user_header_token): + setup_db(data_db) + response = client.get( + "/api/v1/families", + headers=user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) == 2 + ids_found = set([f["import_id"] for f in data]) + expected_ids = set(["A.0.0.1", "A.0.0.2"]) + + assert ids_found.symmetric_difference(expected_ids) == set([]) + + assert all(field in fam for fam in data for field in ("created", "last_modified")) + sorted_data = sorted(data, key=lambda d: d["import_id"]) + response_data = [ + { + k: v if not isinstance(v, list) else sorted(v) + for k, v in fam.items() + if k not in ("created", "last_modified") + } + for fam in sorted_data + ] + assert response_data[0] == EXPECTED_FAMILIES[0] + assert response_data[1] == EXPECTED_FAMILIES[1] + + +def test_get_all_families_unfccc( + client: TestClient, data_db: Session, non_cclw_user_header_token +): + setup_db(data_db) + response = client.get( + "/api/v1/families", + headers=non_cclw_user_header_token, + ) + assert response.status_code == status.HTTP_200_OK + data = response.json() + assert isinstance(data, list) + assert len(data) == 1 + ids_found = set([f["import_id"] for f in data]) + expected_ids = set(["A.0.0.3"]) + + assert ids_found.symmetric_difference(expected_ids) == set([]) + + assert all(field in fam for fam in data for field in ("created", "last_modified")) + sorted_data = sorted(data, key=lambda d: d["import_id"]) + response_data = [ + { + k: v if not isinstance(v, list) else sorted(v) + for k, v in fam.items() + if k not in ("created", "last_modified") + } + for fam in sorted_data + ] + assert response_data[0] == EXPECTED_FAMILIES[2] + + def test_get_all_families_when_not_authenticated(client: TestClient, data_db: Session): setup_db(data_db) response = client.get( diff --git a/tests/integration_tests/setup_db.py b/tests/integration_tests/setup_db.py index 41e8a26e..d9603076 100644 --- a/tests/integration_tests/setup_db.py +++ b/tests/integration_tests/setup_db.py @@ -72,10 +72,10 @@ "category": "UNFCCC", "status": "Created", "metadata": {"size": [100], "color": ["blue"]}, - "organisation": "CCLW", - "corpus_import_id": "CCLW.corpus.i00000001.n0000", - "corpus_title": "CCLW national policies", - "corpus_type": "Laws and Policies", + "organisation": "UNFCCC", + "corpus_import_id": "UNFCCC.corpus.i00000001.n0000", + "corpus_title": "UNFCCC Submissions", + "corpus_type": "Intl. agreements", "slug": "Slug3", "events": ["E.0.0.3"], "published_date": "2018-12-24T04:59:33Z", @@ -246,7 +246,8 @@ def _get_org_id_from_name(test_db: Session, name: str) -> int: def _setup_organisation(test_db: Session) -> tuple[int, int]: # Now an organisation - org = test_db.query(Organisation).filter(Organisation.name == "CCLW").one() + cclw = test_db.query(Organisation).filter(Organisation.name == "CCLW").one() + unfccc = test_db.query(Organisation).filter(Organisation.name == "UNFCCC").one() another_org = Organisation( name="Another org", @@ -261,13 +262,20 @@ def _setup_organisation(test_db: Session) -> tuple[int, int]: test_db, "test@cpr.org", "CCLWTestUser", - org.id, + cclw.id, "$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", ) _add_app_user( test_db, "unfccc@cpr.org", - "NonCCLWTestUser", + "UNFCCCTestUser", + unfccc.id, + "$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", + ) + _add_app_user( + test_db, + "another@cpr.org", + "AnotherTestUser", another_org.id, "$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", ) @@ -275,25 +283,25 @@ def _setup_organisation(test_db: Session) -> tuple[int, int]: test_db, "test1@cpr.org", "TestInactive", - org.id, + cclw.id, hashed_pass="$2b$12$q.UbWEdeibUuApI2QDbmQeG5WmAPfNmooG1cAoCWjyJXvgiAVVdlK", is_active=False, ) _add_app_user( - test_db, "test2@cpr.org", "TestHashedPassEmpty", org.id, hashed_pass="" + test_db, "test2@cpr.org", "TestHashedPassEmpty", cclw.id, hashed_pass="" ) _add_app_user( test_db, "test3@cpr.org", "TestPassMismatch", - org.id, + cclw.id, hashed_pass="$2b$12$WZq1rRMvU.Tv1VutLw.rju/Ez5ETkYqP3KufdcSFJm3GTRZP8E52C", ) _add_app_user( test_db, "admin@cpr.org", "Admin", - org.id, + cclw.id, hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", is_admin=True, ) @@ -301,12 +309,12 @@ def _setup_organisation(test_db: Session) -> tuple[int, int]: test_db, "super@cpr.org", "Super", - org.id, + cclw.id, hashed_pass="$2b$12$XXMr7xoEY2fzNiMR3hq.PeJBUUchJyiTfJP.Rt2eq9hsPzt9SXzFC", is_super=True, ) - return cast(int, org.id), cast(int, another_org.id) + return cast(int, cclw.id), cast(int, another_org.id) def _setup_collection_data( diff --git a/tests/mocks/services/family_service.py b/tests/mocks/services/family_service.py index 7e73b5b4..ec80a2e4 100644 --- a/tests/mocks/services/family_service.py +++ b/tests/mocks/services/family_service.py @@ -23,7 +23,7 @@ def maybe_timeout(): if family_service.throw_timeout_error: raise TimeoutError - def mock_get_all_families(): + def mock_get_all_families(_): return [create_family_read_dto("test", collections=["x.y.z.1", "x.y.z.2"])] def mock_get_family(import_id: str) -> Optional[FamilyReadDTO]: