Skip to content

Commit

Permalink
Check user belongs to dataset organization before creation
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Sep 21, 2022
1 parent 6d25a22 commit 26bffcb
Show file tree
Hide file tree
Showing 24 changed files with 257 additions and 128 deletions.
14 changes: 5 additions & 9 deletions server/api/auth/backends/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
)
from starlette.requests import HTTPConnection

from server.application.auth.queries import GetAccountByAPIToken
from server.config.di import resolve
from server.domain.auth.exceptions import AccountDoesNotExist
from server.seedwork.application.messages import MessageBus
from server.domain.auth.repositories import AccountRepository

from ..models import ApiUser

Expand All @@ -26,6 +24,8 @@ class TokenAuthBackend(AuthenticationBackend):
async def authenticate(
self, conn: HTTPConnection
) -> Optional[Tuple[AuthCredentials, ApiUser]]:
account_repository = resolve(AccountRepository)

# NOTE: we don't reuse fastapi.security.HTTPBearer as it does not distinguish
# "no Authorization" / "scheme is not Bearer" and "malformed Authorization".

Expand All @@ -42,13 +42,9 @@ async def authenticate(
if scheme.lower() != "bearer":
return AuthCredentials(), ApiUser(None)

bus = resolve(MessageBus)

query = GetAccountByAPIToken(api_token=api_token)
account = await account_repository.get_by_api_token(api_token)

try:
account = await bus.execute(query)
except AccountDoesNotExist:
if account is None:
raise AuthenticationError()

return AuthCredentials(scopes=["authenticated"]), ApiUser(account)
6 changes: 3 additions & 3 deletions server/api/auth/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

from starlette.authentication import BaseUser

from server.application.auth.views import AccountView
from server.domain.auth.entities import Account


class ApiUser(BaseUser):
def __init__(self, account: Optional[AccountView]) -> None:
def __init__(self, account: Optional[Account]) -> None:
self._account = account

@property
def account(self) -> AccountView:
def account(self) -> Account:
if self._account is None:
raise RuntimeError(
"Cannot access .account, as the user is anonymous. "
Expand Down
2 changes: 1 addition & 1 deletion server/api/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ async def login_password_user(data: PasswordUserLogin) -> AuthenticatedAccountVi
dependencies=[Depends(IsAuthenticated())],
)
async def get_connected_user(request: APIRequest) -> AccountView:
return request.user.account
return AccountView(**request.user.account.dict())


@router.delete(
Expand Down
12 changes: 8 additions & 4 deletions server/api/datasets/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@
DeleteDataset,
UpdateDataset,
)
from server.application.datasets.exceptions import CannotCreateDataset
from server.application.datasets.queries import GetAllDatasets, GetDatasetByID
from server.application.datasets.views import DatasetView
from server.config.di import resolve
from server.domain.auth.entities import UserRole
from server.domain.catalogs.exceptions import CatalogDoesNotExist
from server.domain.common.pagination import Page, Pagination
from server.domain.common.types import ID
from server.domain.datasets.exceptions import DatasetDoesNotExist
from server.domain.datasets.specifications import DatasetSpec
from server.domain.organizations.exceptions import OrganizationDoesNotExist
from server.seedwork.application.messages import MessageBus

from ..auth.permissions import HasRole, IsAuthenticated
from ..types import APIRequest
from . import filters
from .schemas import DatasetCreate, DatasetListParams, DatasetUpdate

Expand Down Expand Up @@ -77,15 +79,17 @@ async def get_dataset_by_id(id: ID) -> DatasetView:
response_model=DatasetView,
status_code=201,
)
async def create_dataset(data: DatasetCreate) -> DatasetView:
async def create_dataset(data: DatasetCreate, request: "APIRequest") -> DatasetView:
bus = resolve(MessageBus)

command = CreateDataset(**data.dict())
command = CreateDataset(account=request.user.account, **data.dict())

try:
id = await bus.execute(command)
except OrganizationDoesNotExist as exc:
except CatalogDoesNotExist as exc:
raise HTTPException(400, detail=str(exc))
except CannotCreateDataset:
raise HTTPException(403, detail="Permission denied")

query = GetDatasetByID(id=id)
return await bus.execute(query)
Expand Down
18 changes: 1 addition & 17 deletions server/application/auth/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,7 @@
DeletePasswordUser,
)
from .passwords import PasswordEncoder, generate_api_token
from .queries import (
GetAccountByAPIToken,
GetAccountByEmail,
LoginDataPassUser,
LoginPasswordUser,
)
from .queries import GetAccountByEmail, LoginDataPassUser, LoginPasswordUser


async def create_password_user(
Expand Down Expand Up @@ -157,17 +152,6 @@ async def get_account_by_email(query: GetAccountByEmail) -> AccountView:
return AccountView(**account.dict())


async def get_account_by_api_token(query: GetAccountByAPIToken) -> AccountView:
repository = resolve(AccountRepository)

account = await repository.get_by_api_token(query.api_token)

if account is None:
raise AccountDoesNotExist("__token__")

return AccountView(**account.dict())


async def change_password(command: ChangePassword) -> None:
repository = resolve(PasswordUserRepository)
password_encoder = resolve(PasswordEncoder)
Expand Down
4 changes: 0 additions & 4 deletions server/application/auth/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,3 @@ class LoginDataPassUser(Query[AuthenticatedAccountView]):

class GetAccountByEmail(Query[AccountView]):
email: EmailStr


class GetAccountByAPIToken(Query[AccountView]):
api_token: str
7 changes: 5 additions & 2 deletions server/application/datasets/commands.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import datetime as dt
from typing import List, Optional
from typing import List, Optional, Union

from pydantic import EmailStr, Field

from server.domain.auth.entities import Account
from server.domain.catalogs.entities import ExtraFieldValue
from server.domain.common.types import ID
from server.domain.common.types import ID, Skip
from server.domain.datasets.entities import DataFormat, UpdateFrequency
from server.domain.organizations.entities import LEGACY_ORGANIZATION
from server.domain.organizations.types import Siret
Expand All @@ -14,6 +15,8 @@


class CreateDataset(CreateDatasetValidationMixin, Command[ID]):
account: Union[Account, Skip]

organization_siret: Siret = LEGACY_ORGANIZATION.siret
title: str
description: str
Expand Down
2 changes: 2 additions & 0 deletions server/application/datasets/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class CannotCreateDataset(Exception):
pass
25 changes: 15 additions & 10 deletions server/application/datasets/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,46 @@
from server.config.di import resolve
from server.domain.catalog_records.entities import CatalogRecord
from server.domain.catalog_records.repositories import CatalogRecordRepository
from server.domain.catalogs.exceptions import CatalogDoesNotExist
from server.domain.catalogs.repositories import CatalogRepository
from server.domain.common.pagination import Pagination
from server.domain.common.types import ID
from server.domain.common.types import ID, Skip
from server.domain.datasets.entities import DataFormat, Dataset
from server.domain.datasets.exceptions import DatasetDoesNotExist
from server.domain.datasets.repositories import DatasetRepository
from server.domain.organizations.exceptions import OrganizationDoesNotExist
from server.domain.organizations.repositories import OrganizationRepository
from server.domain.tags.repositories import TagRepository
from server.seedwork.application.messages import MessageBus

from .commands import CreateDataset, DeleteDataset, UpdateDataset
from .exceptions import CannotCreateDataset
from .queries import GetAllDatasets, GetDatasetByID, GetDatasetFilters
from .specifications import can_dataset_be_created
from .views import DatasetFiltersView, DatasetView


async def create_dataset(command: CreateDataset, *, id_: ID = None) -> ID:
repository = resolve(DatasetRepository)
organization_repository = resolve(OrganizationRepository)
catalog_repository = resolve(CatalogRepository)
catalog_record_repository = resolve(CatalogRecordRepository)
tag_repository = resolve(TagRepository)

if id_ is None:
id_ = repository.make_id()

organization = await organization_repository.get_by_siret(
siret=command.organization_siret
)
catalog = await catalog_repository.get_by_siret(siret=command.organization_siret)

if catalog is None:
raise CatalogDoesNotExist(command.organization_siret)

if organization is None:
raise OrganizationDoesNotExist(command.organization_siret)
if not isinstance(command.account, Skip) and not can_dataset_be_created(
catalog, command.account
):
raise CannotCreateDataset()

catalog_record_id = await catalog_record_repository.insert(
CatalogRecord(
id=catalog_record_repository.make_id(),
organization=organization,
organization=catalog.organization,
)
)
catalog_record = await catalog_record_repository.get_by_id(catalog_record_id)
Expand Down
9 changes: 9 additions & 0 deletions server/application/datasets/specifications.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from server.domain.auth.entities import Account, UserRole
from server.domain.catalogs.entities import Catalog


def can_dataset_be_created(catalog: Catalog, account: Account) -> bool:
if account.role == UserRole.ADMIN:
return True

return catalog.organization.siret == account.organization_siret
8 changes: 8 additions & 0 deletions server/domain/common/types.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import uuid
from typing import NewType

from pydantic import BaseModel

ID = NewType("ID", uuid.UUID)


def id_factory() -> ID:
return ID(uuid.uuid4())


class Skip(BaseModel):
"""
A marker class for when an operation should be skipped.
"""
3 changes: 0 additions & 3 deletions server/infrastructure/auth/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
create_datapass_user,
create_password_user,
delete_password_user,
get_account_by_api_token,
get_account_by_email,
login_datapass_user,
login_password_user,
)
from server.application.auth.queries import (
GetAccountByAPIToken,
GetAccountByEmail,
LoginDataPassUser,
LoginPasswordUser,
Expand All @@ -35,5 +33,4 @@ class AuthModule(Module):
LoginPasswordUser: login_password_user,
LoginDataPassUser: login_datapass_user,
GetAccountByEmail: get_account_by_email,
GetAccountByAPIToken: get_account_by_api_token,
}
16 changes: 13 additions & 3 deletions tests/api/test_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
from server.domain.organizations.types import Siret
from server.seedwork.application.messages import MessageBus

from ..factories import CreateDatasetFactory, CreateOrganizationFactory, fake
from ..helpers import TestPasswordUser, api_key_auth
from ..factories import (
CreateDatasetFactory,
CreateOrganizationFactory,
CreatePasswordUserFactory,
fake,
)
from ..helpers import TestPasswordUser, api_key_auth, create_test_password_user


@pytest.mark.asyncio
async def test_catalog_create(client: httpx.AsyncClient) -> None:
bus = resolve(MessageBus)
siret = await bus.execute(CreateOrganizationFactory.build(name="Org 1"))
user = await create_test_password_user(
CreatePasswordUserFactory.build(organization_siret=siret)
)

response = await client.post(
"/catalogs/", json={"organization_siret": str(siret)}, auth=api_key_auth
Expand All @@ -34,7 +42,9 @@ async def test_catalog_create(client: httpx.AsyncClient) -> None:
catalog = await bus.execute(GetCatalogBySiret(siret=siret))
assert catalog.organization.siret == siret

dataset_id = await bus.execute(CreateDatasetFactory.build(organization_siret=siret))
dataset_id = await bus.execute(
CreateDatasetFactory.build(account=user.account, organization_siret=siret)
)
dataset = await bus.execute(GetDatasetByID(id=dataset_id))
assert dataset.catalog_record.organization.siret == siret

Expand Down
Loading

0 comments on commit 26bffcb

Please sign in to comment.