From 699583986aa248c393a238873d25e2f43574c9a2 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Tue, 20 Sep 2022 17:36:59 +0200 Subject: [PATCH] Check user belongs to dataset organization before creation --- server/api/datasets/routes.py | 12 +- server/application/datasets/commands.py | 7 +- server/application/datasets/exceptions.py | 2 + server/application/datasets/handlers.py | 27 +-- server/application/datasets/specifications.py | 9 + server/domain/common/types.py | 8 + tests/api/test_catalogs.py | 16 +- tests/api/test_datasets.py | 163 +++++++++++++----- tests/api/test_datasets_filters.py | 23 ++- tests/api/test_datasets_search.py | 20 ++- tests/api/test_licenses.py | 4 +- tests/conftest.py | 7 +- tests/factories.py | 17 +- tests/helpers.py | 8 +- tests/infrastructure/test_catalogs.py | 13 +- tests/infrastructure/test_datasets.py | 7 +- tools/addrandomdatasets.py | 5 +- tools/initdata.py | 3 +- 18 files changed, 259 insertions(+), 92 deletions(-) create mode 100644 server/application/datasets/exceptions.py create mode 100644 server/application/datasets/specifications.py diff --git a/server/api/datasets/routes.py b/server/api/datasets/routes.py index bcc616da..d599abf5 100644 --- a/server/api/datasets/routes.py +++ b/server/api/datasets/routes.py @@ -7,18 +7,20 @@ DeleteDataset, UpdateDataset, ) +from server.application.datasets.exceptions import CannotCreateDataset from server.application.datasets.queries import GetAllDatasets, GetDatasetByID from server.application.datasets.views import DatasetView from server.config.di import resolve from server.domain.auth.entities import UserRole +from server.domain.catalogs.exceptions import CatalogDoesNotExist from server.domain.common.pagination import Page, Pagination from server.domain.common.types import ID from server.domain.datasets.exceptions import DatasetDoesNotExist from server.domain.datasets.specifications import DatasetSpec -from server.domain.organizations.exceptions import OrganizationDoesNotExist from server.seedwork.application.messages import MessageBus from ..auth.permissions import HasRole, IsAuthenticated +from ..types import APIRequest from . import filters from .schemas import DatasetCreate, DatasetListParams, DatasetUpdate @@ -78,15 +80,17 @@ async def get_dataset_by_id(id: ID) -> DatasetView: response_model=DatasetView, status_code=201, ) -async def create_dataset(data: DatasetCreate) -> DatasetView: +async def create_dataset(data: DatasetCreate, request: "APIRequest") -> DatasetView: bus = resolve(MessageBus) - command = CreateDataset(**data.dict()) + command = CreateDataset(account=request.user.account, **data.dict()) try: id = await bus.execute(command) - except OrganizationDoesNotExist as exc: + except CatalogDoesNotExist as exc: raise HTTPException(400, detail=str(exc)) + except CannotCreateDataset: + raise HTTPException(403, detail="Permission denied") query = GetDatasetByID(id=id) return await bus.execute(query) diff --git a/server/application/datasets/commands.py b/server/application/datasets/commands.py index 1d4f12c6..4ae0ae54 100644 --- a/server/application/datasets/commands.py +++ b/server/application/datasets/commands.py @@ -1,10 +1,11 @@ import datetime as dt -from typing import List, Optional +from typing import List, Optional, Union from pydantic import EmailStr, Field +from server.domain.auth.entities import Account from server.domain.catalogs.entities import ExtraFieldValue -from server.domain.common.types import ID +from server.domain.common.types import ID, Skip from server.domain.datasets.entities import DataFormat, UpdateFrequency from server.domain.organizations.entities import LEGACY_ORGANIZATION from server.domain.organizations.types import Siret @@ -14,6 +15,8 @@ class CreateDataset(CreateDatasetValidationMixin, Command[ID]): + account: Union[Account, Skip] + organization_siret: Siret = LEGACY_ORGANIZATION.siret title: str description: str diff --git a/server/application/datasets/exceptions.py b/server/application/datasets/exceptions.py new file mode 100644 index 00000000..1c9aa84d --- /dev/null +++ b/server/application/datasets/exceptions.py @@ -0,0 +1,2 @@ +class CannotCreateDataset(Exception): + pass diff --git a/server/application/datasets/handlers.py b/server/application/datasets/handlers.py index 375e4226..b26a0f1f 100644 --- a/server/application/datasets/handlers.py +++ b/server/application/datasets/handlers.py @@ -4,41 +4,48 @@ from server.config.di import resolve from server.domain.catalog_records.entities import CatalogRecord from server.domain.catalog_records.repositories import CatalogRecordRepository +from server.domain.catalogs.exceptions import CatalogDoesNotExist +from server.domain.catalogs.repositories import CatalogRepository from server.domain.common.pagination import Pagination -from server.domain.common.types import ID +from server.domain.common.types import ID, Skip from server.domain.datasets.entities import DataFormat, Dataset from server.domain.datasets.exceptions import DatasetDoesNotExist from server.domain.datasets.repositories import DatasetRepository -from server.domain.organizations.exceptions import OrganizationDoesNotExist -from server.domain.organizations.repositories import OrganizationRepository from server.domain.tags.repositories import TagRepository from server.seedwork.application.messages import MessageBus from .commands import CreateDataset, DeleteDataset, UpdateDataset +from .exceptions import CannotCreateDataset from .queries import GetAllDatasets, GetDatasetByID, GetDatasetFilters +from .specifications import can_create_dataset from .views import DatasetFiltersView, DatasetView async def create_dataset(command: CreateDataset, *, id_: ID = None) -> ID: repository = resolve(DatasetRepository) - organization_repository = resolve(OrganizationRepository) + catalog_repository = resolve(CatalogRepository) catalog_record_repository = resolve(CatalogRecordRepository) tag_repository = resolve(TagRepository) if id_ is None: id_ = repository.make_id() - organization = await organization_repository.get_by_siret( - siret=command.organization_siret - ) + catalog = await catalog_repository.get_by_siret(siret=command.organization_siret) + + if catalog is None: + raise CatalogDoesNotExist(command.organization_siret) - if organization is None: - raise OrganizationDoesNotExist(command.organization_siret) + if not isinstance(command.account, Skip) and not can_create_dataset( + catalog, command.account + ): + raise CannotCreateDataset( + f"{command.account.organization_siret=}, {catalog.organization.siret=}" + ) catalog_record_id = await catalog_record_repository.insert( CatalogRecord( id=catalog_record_repository.make_id(), - organization=organization, + organization=catalog.organization, ) ) catalog_record = await catalog_record_repository.get_by_id(catalog_record_id) diff --git a/server/application/datasets/specifications.py b/server/application/datasets/specifications.py new file mode 100644 index 00000000..a1e47206 --- /dev/null +++ b/server/application/datasets/specifications.py @@ -0,0 +1,9 @@ +from server.domain.auth.entities import Account, UserRole +from server.domain.catalogs.entities import Catalog + + +def can_create_dataset(catalog: Catalog, account: Account) -> bool: + if account.role == UserRole.ADMIN: + return True + + return catalog.organization.siret == account.organization_siret diff --git a/server/domain/common/types.py b/server/domain/common/types.py index 839d4e32..3f7ec230 100644 --- a/server/domain/common/types.py +++ b/server/domain/common/types.py @@ -1,8 +1,16 @@ import uuid from typing import NewType +from pydantic import BaseModel + ID = NewType("ID", uuid.UUID) def id_factory() -> ID: return ID(uuid.uuid4()) + + +class Skip(BaseModel): + """ + A marker class for when an operation should be skipped. + """ diff --git a/tests/api/test_catalogs.py b/tests/api/test_catalogs.py index ad127592..272ada27 100644 --- a/tests/api/test_catalogs.py +++ b/tests/api/test_catalogs.py @@ -11,14 +11,22 @@ from server.domain.organizations.types import Siret from server.seedwork.application.messages import MessageBus -from ..factories import CreateDatasetFactory, CreateOrganizationFactory, fake -from ..helpers import TestPasswordUser, api_key_auth +from ..factories import ( + CreateDatasetFactory, + CreateOrganizationFactory, + CreatePasswordUserFactory, + fake, +) +from ..helpers import TestPasswordUser, api_key_auth, create_test_password_user @pytest.mark.asyncio async def test_catalog_create(client: httpx.AsyncClient) -> None: bus = resolve(MessageBus) siret = await bus.execute(CreateOrganizationFactory.build(name="Org 1")) + user = await create_test_password_user( + CreatePasswordUserFactory.build(organization_siret=siret) + ) response = await client.post( "/catalogs/", json={"organization_siret": str(siret)}, auth=api_key_auth @@ -34,7 +42,9 @@ async def test_catalog_create(client: httpx.AsyncClient) -> None: catalog = await bus.execute(GetCatalogBySiret(siret=siret)) assert catalog.organization.siret == siret - dataset_id = await bus.execute(CreateDatasetFactory.build(organization_siret=siret)) + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=user.account, organization_siret=siret) + ) dataset = await bus.execute(GetDatasetByID(id=dataset_id)) assert dataset.catalog_record.organization.siret == siret diff --git a/tests/api/test_datasets.py b/tests/api/test_datasets.py index 086f1211..88e53b29 100644 --- a/tests/api/test_datasets.py +++ b/tests/api/test_datasets.py @@ -1,5 +1,5 @@ import random -from typing import Any, List +from typing import Any, List, Tuple import httpx import pytest @@ -22,8 +22,14 @@ from server.seedwork.application.messages import MessageBus from tests.factories import CreateDatasetFactory -from ..factories import CreateOrganizationFactory, UpdateDatasetFactory, fake -from ..helpers import TestPasswordUser, to_payload +from ..factories import ( + CreateDatasetPayloadFactory, + CreateOrganizationFactory, + CreatePasswordUserFactory, + UpdateDatasetFactory, + fake, +) +from ..helpers import TestPasswordUser, create_test_password_user, to_payload @pytest.mark.asyncio @@ -99,15 +105,25 @@ async def test_create_dataset_invalid( @pytest.mark.asyncio -async def test_create_dataset_invalid_organization_does_not_exist( - client: httpx.AsyncClient, temp_user: TestPasswordUser +async def test_create_dataset_invalid_catalog_does_not_exist( + client: httpx.AsyncClient, ) -> None: - siret = Siret(fake.siret()) - payload = to_payload(CreateDatasetFactory.build(organization_siret=siret)) - response = await client.post("/datasets/", json=payload, auth=temp_user.auth) + bus = resolve(MessageBus) + siret = await bus.execute(CreateOrganizationFactory.build()) + # Catalog not created... + user = await create_test_password_user( + CreatePasswordUserFactory.build(organization_siret=siret) + ) + + payload = to_payload( + CreateDatasetPayloadFactory.build( + account=user.account, organization_siret=siret + ) + ) + response = await client.post("/datasets/", json=payload, auth=user.auth) assert response.status_code == 400 data = response.json() - assert data["detail"] == f"Organization not found: '{siret}'" + assert data["detail"] == f"Catalog not found: '{siret}'" @pytest.mark.asyncio @@ -117,7 +133,7 @@ async def test_dataset_crud( last_updated_at = fake.date_time_tz() payload = to_payload( - CreateDatasetFactory.build( + CreateDatasetPayloadFactory.build( title="Example title", description="Example description", service="Example service", @@ -193,10 +209,40 @@ class TestDatasetPermissions: async def test_create_not_authenticated(self, client: httpx.AsyncClient) -> None: response = await client.post( "/datasets/", - json=to_payload(CreateDatasetFactory.build()), + json=to_payload(CreateDatasetPayloadFactory.build()), ) assert response.status_code == 401 + async def test_create_in_other_org_denied( + self, client: httpx.AsyncClient, temp_user: TestPasswordUser + ) -> None: + bus = resolve(MessageBus) + + other_org_siret = await bus.execute(CreateOrganizationFactory.build()) + await bus.execute(CreateCatalog(organization_siret=other_org_siret)) + + payload = to_payload( + CreateDatasetPayloadFactory.build(organization_siret=other_org_siret) + ) + response = await client.post("/datasets/", json=payload, auth=temp_user.auth) + + assert response.status_code == 403 + + async def test_create_in_other_org_admin_ok( + self, client: httpx.AsyncClient, admin_user: TestPasswordUser + ) -> None: + bus = resolve(MessageBus) + + other_org_siret = await bus.execute(CreateOrganizationFactory.build()) + assert other_org_siret != admin_user.account.organization_siret + await bus.execute(CreateCatalog(organization_siret=other_org_siret)) + + payload = to_payload( + CreateDatasetPayloadFactory.build(organization_siret=other_org_siret) + ) + response = await client.post("/datasets/", json=payload, auth=admin_user.auth) + assert response.status_code == 201 + async def test_get_not_authenticated(self, client: httpx.AsyncClient) -> None: pk = id_factory() response = await client.get(f"/datasets/{pk}/") @@ -224,13 +270,17 @@ async def test_delete_not_admin( assert response.status_code == 403 -async def add_dataset_pagination_corpus(n: int, tags: list) -> None: +async def add_dataset_pagination_corpus( + user: TestPasswordUser, n: int, tags: list +) -> None: bus = resolve(MessageBus) for k in range(1, n + 1): tag_ids = [tag.id for tag in random.choices(tags, k=random.randint(0, 2))] await bus.execute( - CreateDatasetFactory.build(title=f"Dataset {k}", tag_ids=tag_ids) + CreateDatasetFactory.build( + account=user.account, title=f"Dataset {k}", tag_ids=tag_ids + ) ) @@ -284,7 +334,7 @@ async def test_dataset_pagination( expected_num_items: int, expected_dataset_titles: List[str], ) -> None: - await add_dataset_pagination_corpus(n=13, tags=tags) + await add_dataset_pagination_corpus(temp_user, n=13, tags=tags) response = await client.get("/datasets/", params=params, auth=temp_user.auth) assert response.status_code == 200 @@ -303,9 +353,15 @@ async def test_dataset_get_all_uses_reverse_chronological_order( client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) - await bus.execute(CreateDatasetFactory.build(title="Oldest")) - await bus.execute(CreateDatasetFactory.build(title="Intermediate")) - await bus.execute(CreateDatasetFactory.build(title="Newest")) + await bus.execute( + CreateDatasetFactory.build(account=temp_user.account, title="Oldest") + ) + await bus.execute( + CreateDatasetFactory.build(account=temp_user.account, title="Intermediate") + ) + await bus.execute( + CreateDatasetFactory.build(account=temp_user.account, title="Newest") + ) response = await client.get("/datasets/", auth=temp_user.auth) assert response.status_code == 200 @@ -332,7 +388,7 @@ async def test_optional_fields_missing_uses_defaults( field: str, default: Any, ) -> None: - payload = to_payload(CreateDatasetFactory.build()) + payload = to_payload(CreateDatasetFactory.build(account=temp_user.account)) payload.pop(field) response = await client.post("/datasets/", json=payload, auth=temp_user.auth) assert response.status_code == 201 @@ -345,7 +401,7 @@ async def test_optional_fields_invalid( response = await client.post( "/datasets/", json={ - **to_payload(CreateDatasetFactory.build()), + **to_payload(CreateDatasetFactory.build(account=temp_user.account)), "contact_emails": ["notanemail", "valid@mydomain.org"], "update_frequency": "not_in_enum", "last_updated_at": "not_a_datetime", @@ -383,7 +439,9 @@ async def test_full_entity_expected( self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) - dataset_id = await bus.execute(CreateDatasetFactory.build()) + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=temp_user.account) + ) # Apply PUT semantics, which expect a full entity. response = await client.put( @@ -418,7 +476,9 @@ async def test_fields_empty_invalid( bus = resolve(MessageBus) last_updated_at = fake.date_time_tz() - command = CreateDatasetFactory.build(last_updated_at=last_updated_at) + command = CreateDatasetFactory.build( + account=temp_user.account, last_updated_at=last_updated_at + ) dataset_id = await bus.execute(command) @@ -461,7 +521,9 @@ async def test_update( self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) - dataset_id = await bus.execute(CreateDatasetFactory.build()) + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=temp_user.account) + ) other_last_updated_at = fake.date_time_tz() @@ -538,7 +600,8 @@ async def test_formats_add( ) -> None: bus = resolve(MessageBus) command = CreateDatasetFactory.build( - formats=[DataFormat.WEBSITE, DataFormat.API] + account=temp_user.account, + formats=[DataFormat.WEBSITE, DataFormat.API], ) dataset_id = await bus.execute(command) @@ -561,7 +624,8 @@ async def test_formats_remove( ) -> None: bus = resolve(MessageBus) command = CreateDatasetFactory.build( - formats=[DataFormat.WEBSITE, DataFormat.API] + account=temp_user.account, + formats=[DataFormat.WEBSITE, DataFormat.API], ) dataset_id = await bus.execute(command) @@ -587,7 +651,7 @@ async def test_tags_add( ) -> None: bus = resolve(MessageBus) - command = CreateDatasetFactory.build() + command = CreateDatasetFactory.build(account=temp_user.account) dataset_id = await bus.execute(command) tag_architecture_id = await bus.execute(CreateTag(name="Architecture")) tag_architecture = await bus.execute(GetTagByID(id=tag_architecture_id)) @@ -616,7 +680,10 @@ async def test_tags_remove( bus = resolve(MessageBus) tag_architecture_id = await bus.execute(CreateTag(name="Architecture")) - command = CreateDatasetFactory.build(tag_ids=[str(tag_architecture_id)]) + command = CreateDatasetFactory.build( + account=temp_user.account, + tag_ids=[str(tag_architecture_id)], + ) dataset_id = await bus.execute(command) response = await client.put( @@ -638,7 +705,9 @@ async def test_tags_remove( @pytest.mark.asyncio class TestExtraFieldValues: - async def _create_extra_field_in_catalog(self) -> ID: + async def _setup( + self, + ) -> Tuple[Siret, TestPasswordUser, ID]: bus = resolve(MessageBus) siret = await bus.execute(CreateOrganizationFactory.build()) @@ -657,21 +726,28 @@ async def _create_extra_field_in_catalog(self) -> ID: ) catalog = await bus.execute(GetCatalogBySiret(siret=siret)) - return catalog.extra_fields[0].id + extra_field_id = catalog.extra_fields[0].id + + user = await create_test_password_user( + CreatePasswordUserFactory.build(organization_siret=siret) + ) + + return siret, user, extra_field_id async def test_create_dataset_with_extra_field_values( - self, client: httpx.AsyncClient, temp_user: TestPasswordUser + self, client: httpx.AsyncClient ) -> None: - extra_field_id = await self._create_extra_field_in_catalog() + siret, user, extra_field_id = await self._setup() payload = to_payload( - CreateDatasetFactory.build( + CreateDatasetPayloadFactory.build( + organization_siret=siret, extra_field_values=[ ExtraFieldValue(extra_field_id=extra_field_id, value="2.4 Go") - ] + ], ) ) - response = await client.post("/datasets/", json=payload, auth=temp_user.auth) + response = await client.post("/datasets/", json=payload, auth=user.auth) assert response.status_code == 201 data = response.json() assert data["extra_field_values"] == [ @@ -685,9 +761,11 @@ async def test_add_extra_field_value( self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) - extra_field_id = await self._create_extra_field_in_catalog() + siret, user, extra_field_id = await self._setup() - command = CreateDatasetFactory.build() + command = CreateDatasetFactory.build( + account=user.account, organization_siret=siret + ) dataset_id = await bus.execute(command) dataset = await bus.execute(GetDatasetByID(id=dataset_id)) assert not dataset.extra_field_values @@ -720,15 +798,17 @@ async def test_remove_extra_field_value( self, client: httpx.AsyncClient, temp_user: TestPasswordUser ) -> None: bus = resolve(MessageBus) - extra_field_id = await self._create_extra_field_in_catalog() + siret, user, extra_field_id = await self._setup() command = CreateDatasetFactory.build( + account=user.account, + organization_siret=siret, extra_field_values=[ ExtraFieldValue( extra_field_id=extra_field_id, value="2.4 Go", ) - ] + ], ) dataset_id = await bus.execute(command) dataset = await bus.execute(GetDatasetByID(id=dataset_id)) @@ -761,11 +841,16 @@ async def test_remove_extra_field_value( @pytest.mark.asyncio class TestDeleteDataset: async def test_delete( - self, client: httpx.AsyncClient, admin_user: TestPasswordUser + self, + client: httpx.AsyncClient, + temp_user: TestPasswordUser, + admin_user: TestPasswordUser, ) -> None: bus = resolve(MessageBus) - dataset_id = await bus.execute(CreateDatasetFactory.build()) + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=temp_user.account) + ) response = await client.delete(f"/datasets/{dataset_id}/", auth=admin_user.auth) assert response.status_code == 204 diff --git a/tests/api/test_datasets_filters.py b/tests/api/test_datasets_filters.py index 02b9baaf..6f946b62 100644 --- a/tests/api/test_datasets_filters.py +++ b/tests/api/test_datasets_filters.py @@ -15,9 +15,10 @@ from ..factories import ( CreateDatasetFactory, CreateOrganizationFactory, + CreatePasswordUserFactory, CreateTagFactory, ) -from ..helpers import TestPasswordUser +from ..helpers import TestPasswordUser, create_test_password_user @pytest.mark.asyncio @@ -42,11 +43,16 @@ async def test_dataset_filters_info( CreateOrganizationFactory.build(name="C - Organization without a catalog") ) + user = await create_test_password_user( + CreatePasswordUserFactory.build(organization_siret=siret_non_empty) + ) + tag_id = await bus.execute(CreateTagFactory.build(name="Architecture")) await bus.execute( CreateDatasetFactory.build( organization_siret=siret_non_empty, + account=user.account, geographical_coverage="France métropolitaine", service="Same example service", technical_source="Example database system", @@ -58,6 +64,7 @@ async def test_dataset_filters_info( await bus.execute( CreateDatasetFactory.build( organization_siret=siret_non_empty, + account=user.account, geographical_coverage="Région Nouvelle-Aquitaine", service="Same example service", technical_source=None, @@ -210,7 +217,13 @@ async def test_dataset_filters_apply( kwargs: dict = {"organization_siret": siret_any} kwargs.update(create_kwargs(env)) - dataset_id = await bus.execute(CreateDatasetFactory.build(**kwargs)) + user = await create_test_password_user( + CreatePasswordUserFactory.build(organization_siret=kwargs["organization_siret"]) + ) + + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=user.account, **kwargs) + ) params = {filtername: negative_value(env)} response = await client.get("/datasets/", params=params, auth=temp_user.auth) @@ -233,10 +246,12 @@ async def test_dataset_filters_license_any( bus = resolve(MessageBus) dataset1_id = await bus.execute( - CreateDatasetFactory.build(license="Licence Ouverte") + CreateDatasetFactory.build(account=temp_user.account, license="Licence Ouverte") ) dataset2_id = await bus.execute( - CreateDatasetFactory.build(license="ODC Open Database Licence v1.0") + CreateDatasetFactory.build( + account=temp_user.account, license="ODC Open Database Licence v1.0" + ) ) params = {"license": "*"} diff --git a/tests/api/test_datasets_search.py b/tests/api/test_datasets_search.py index 5e21ad01..150e8475 100644 --- a/tests/api/test_datasets_search.py +++ b/tests/api/test_datasets_search.py @@ -19,14 +19,18 @@ ] -async def add_corpus(items: List[Tuple[str, str]] = None) -> None: +async def add_corpus( + user: TestPasswordUser, *, items: List[Tuple[str, str]] = None +) -> None: if items is None: items = DEFAULT_CORPUS_ITEMS bus = resolve(MessageBus) for title, description in items: - command = CreateDatasetFactory.build(title=title, description=description) + command = CreateDatasetFactory.build( + account=user.account, title=title, description=description + ) pk = await bus.execute(command) query = GetDatasetByID(id=pk) await bus.execute(query) @@ -89,7 +93,7 @@ async def test_search( q: str, expected_titles: List[str], ) -> None: - await add_corpus() + await add_corpus(temp_user) response = await client.get( "/datasets/", @@ -121,7 +125,7 @@ async def test_search( async def test_search_robustness( client: httpx.AsyncClient, temp_user: TestPasswordUser, q_ref: str, q_other: str ) -> None: - await add_corpus() + await add_corpus(temp_user) response = await client.get( "/datasets/", @@ -150,7 +154,7 @@ async def test_search_results_change_when_data_changes( client: httpx.AsyncClient, temp_user: TestPasswordUser, ) -> None: - await add_corpus() + await add_corpus(temp_user) bus = resolve(MessageBus) @@ -165,7 +169,7 @@ async def test_search_results_change_when_data_changes( assert not data["items"] # Add new dataset - command = CreateDatasetFactory.build(title="Titre") + command = CreateDatasetFactory.build(account=temp_user.account, title="Titre") pk = await bus.execute(command) # New dataset is returned in search results response = await client.get( @@ -234,7 +238,7 @@ async def test_search_ranking( random.shuffle(items) # Ensure DB insert order is irrelevant. - await add_corpus(items) + await add_corpus(temp_user, items=items) q = "Forêt ancienne" # Lexemes: forêt, ancien @@ -281,7 +285,7 @@ async def test_search_highlight( q: str, expected_headlines: Optional[dict], ) -> None: - await add_corpus(corpus) + await add_corpus(temp_user, items=corpus) q = "restaurant" diff --git a/tests/api/test_licenses.py b/tests/api/test_licenses.py index 4679bb39..a0b9161f 100644 --- a/tests/api/test_licenses.py +++ b/tests/api/test_licenses.py @@ -18,7 +18,9 @@ async def test_license_list( assert response.status_code == 200 assert response.json() == ["Licence Ouverte", "ODC Open Database License"] - await bus.execute(CreateDatasetFactory.build(license="Autre licence")) + await bus.execute( + CreateDatasetFactory.build(account=temp_user.account, license="Autre licence") + ) response = await client.get("/licenses/", auth=temp_user.auth) assert response.status_code == 200 diff --git a/tests/conftest.py b/tests/conftest.py index bad40c20..5475d238 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,6 +20,7 @@ from server.seedwork.application.messages import MessageBus from tests.factories import CreateTagFactory +from .factories import CreatePasswordUserFactory from .helpers import TestPasswordUser, create_client, create_test_password_user if TYPE_CHECKING: @@ -111,9 +112,11 @@ async def client(app: "App") -> AsyncIterator[httpx.AsyncClient]: @pytest_asyncio.fixture(name="temp_user") async def fixture_temp_user() -> TestPasswordUser: - return await create_test_password_user(UserRole.USER) + command = CreatePasswordUserFactory.build() + return await create_test_password_user(command, role=UserRole.USER) @pytest_asyncio.fixture async def admin_user() -> TestPasswordUser: - return await create_test_password_user(UserRole.ADMIN) + command = CreatePasswordUserFactory.build() + return await create_test_password_user(command, role=UserRole.ADMIN) diff --git a/tests/factories.py b/tests/factories.py index cff89313..a7a5c9d5 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -5,8 +5,9 @@ import faker from faker.providers import BaseProvider from pydantic import BaseModel -from pydantic_factories import ModelFactory, Use +from pydantic_factories import ModelFactory, Require, Use +from server.api.datasets.schemas import DatasetCreate from server.application.auth.commands import CreateDataPassUser, CreatePasswordUser from server.application.datasets.commands import CreateDataset, UpdateDataset from server.application.organizations.commands import CreateOrganization @@ -62,9 +63,7 @@ class CreateTagFactory(Factory[CreateTag]): ] -class CreateDatasetFactory(Factory[CreateDataset]): - __model__ = CreateDataset - +class _BaseCreateDatasetFactory: organization_siret = Use(lambda: LEGACY_ORGANIZATION.siret) title = Use(fake.sentence) description = Use(fake.text) @@ -84,6 +83,16 @@ class CreateDatasetFactory(Factory[CreateDataset]): extra_field_values = Use(lambda: []) +class CreateDatasetFactory(_BaseCreateDatasetFactory, Factory[CreateDataset]): + __model__ = CreateDataset + + account = Require() + + +class CreateDatasetPayloadFactory(_BaseCreateDatasetFactory, Factory[DatasetCreate]): + __model__ = DatasetCreate + + class UpdateDatasetFactory(Factory[UpdateDataset]): __model__ = UpdateDataset diff --git a/tests/helpers.py b/tests/helpers.py index 42e2d340..1073f003 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -4,13 +4,12 @@ import httpx from pydantic import BaseModel +from server.application.auth.commands import CreatePasswordUser from server.config.di import resolve from server.domain.auth.entities import PasswordUser, UserRole from server.domain.auth.repositories import PasswordUserRepository from server.seedwork.application.messages import MessageBus -from .factories import CreatePasswordUserFactory - def create_client(app: Callable) -> httpx.AsyncClient: transport = httpx.ASGITransport( @@ -46,11 +45,12 @@ def auth(self, request: httpx.Request) -> httpx.Request: return request -async def create_test_password_user(role: UserRole) -> TestPasswordUser: +async def create_test_password_user( + command: CreatePasswordUser, *, role: UserRole = UserRole.USER +) -> TestPasswordUser: bus = resolve(MessageBus) password_user_repository = resolve(PasswordUserRepository) - command = CreatePasswordUserFactory.build() await bus.execute(command, role=role) user = await password_user_repository.get_by_email(command.email) diff --git a/tests/infrastructure/test_catalogs.py b/tests/infrastructure/test_catalogs.py index 398c403a..6dee850d 100644 --- a/tests/infrastructure/test_catalogs.py +++ b/tests/infrastructure/test_catalogs.py @@ -1,7 +1,5 @@ import pytest -from pydantic import EmailStr -from server.application.auth.queries import GetAccountByEmail from server.application.datasets.queries import GetDatasetByID from server.config.di import resolve from server.domain.organizations.types import Siret @@ -9,6 +7,7 @@ from server.infrastructure.database import Database from server.infrastructure.organizations.models import OrganizationModel from server.seedwork.application.messages import MessageBus +from tests.helpers import create_test_password_user from ..factories import CreateDatasetFactory, CreatePasswordUserFactory @@ -32,15 +31,15 @@ async def test_catalog_creation_and_relationships() -> None: # Add a user to the organization... email = "test@mydomain.org" - await bus.execute( + user = await create_test_password_user( CreatePasswordUserFactory.build(organization_siret=siret, email=email) ) - - account = await bus.execute(GetAccountByEmail(email=EmailStr("test@mydomain.org"))) - assert account.organization_siret == siret + assert user.account.organization_siret == siret # Add a dataset to the catalog... - dataset_id = await bus.execute(CreateDatasetFactory.build(organization_siret=siret)) + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=user.account, organization_siret=siret) + ) dataset = await bus.execute(GetDatasetByID(id=dataset_id)) assert dataset.catalog_record.organization.siret == siret diff --git a/tests/infrastructure/test_datasets.py b/tests/infrastructure/test_datasets.py index 1aba7fb2..6fb73bc2 100644 --- a/tests/infrastructure/test_datasets.py +++ b/tests/infrastructure/test_datasets.py @@ -10,16 +10,19 @@ from server.infrastructure.datasets.models import DatasetModel from server.infrastructure.tags.models import TagModel, dataset_tag from server.seedwork.application.messages import MessageBus +from tests.helpers import TestPasswordUser from ..factories import CreateDatasetFactory, CreateTagFactory @pytest.mark.asyncio -async def test_dataset_cascades() -> None: +async def test_dataset_cascades(temp_user: TestPasswordUser) -> None: bus = resolve(MessageBus) tag_id = await bus.execute(CreateTagFactory.build(name="Architecture")) - dataset_id = await bus.execute(CreateDatasetFactory.build(tag_ids=[tag_id])) + dataset_id = await bus.execute( + CreateDatasetFactory.build(account=temp_user.account, tag_ids=[tag_id]) + ) dataset = await bus.execute(GetDatasetByID(id=dataset_id)) diff --git a/tools/addrandomdatasets.py b/tools/addrandomdatasets.py index 8023b941..ce5c6b77 100644 --- a/tools/addrandomdatasets.py +++ b/tools/addrandomdatasets.py @@ -8,6 +8,7 @@ from server.application.tags.queries import GetAllTags from server.config.di import bootstrap, resolve +from server.domain.common.types import Skip from server.domain.organizations.types import Siret from server.seedwork.application.messages import MessageBus from tests.factories import CreateDatasetFactory @@ -26,7 +27,9 @@ async def main(n: int, siret: Siret) -> None: tag_id_set, k=random.randint(1, min(3, len(tag_id_set))) ) await bus.execute( - CreateDatasetFactory.build(organization_siret=siret, tag_ids=tag_ids) + CreateDatasetFactory.build( + account=Skip(), organization_siret=siret, tag_ids=tag_ids + ) ) print(f"{success('created')}: {n} datasets") diff --git a/tools/initdata.py b/tools/initdata.py index 5cadbf39..bed226a8 100644 --- a/tools/initdata.py +++ b/tools/initdata.py @@ -20,6 +20,7 @@ from server.domain.auth.entities import UserRole from server.domain.auth.repositories import PasswordUserRepository from server.domain.catalogs.repositories import CatalogRepository +from server.domain.common.types import Skip from server.domain.datasets.entities import Dataset from server.domain.datasets.repositories import DatasetRepository from server.domain.organizations.repositories import OrganizationRepository @@ -175,7 +176,7 @@ def _get_dataset_attr(dataset: Dataset, attr: str) -> Any: print(f"{info('ok')}: {dataset_repr}") return - create_command = CreateDataset(**item["params"]) + create_command = CreateDataset(account=Skip(), **item["params"]) await bus.execute(create_command, id_=id_) print(f"{success('created')}: {create_command!r}")