Skip to content

Commit

Permalink
Permissions system update (redux) (#857)
Browse files Browse the repository at this point in the history
Update metamist permissions system

Notable changes:

- Switch from just read/write permissions to defined roles
- Moved permissions from being defined on groups to being defined on projects to better reflect how they are being used.
- Moved permissions check helpers and methods out from the project table code and onto the `connection` so that they are available everywhere needed
- Made permissions checks faster by calculating access at the start of the request and then looking up an in-memory map of permissions when checking.
- Removed lots of slightly risky code that didn't check permissions in certain cases, now that permissions are fast we can always check them

commits: 

* add initial db migration for new project_groups table

* Update project based db connections

* update usage of project and group permission functions

* fix typo

* fix merge problems and circular import problems

* Update GraphQLFilters to fix generics for all_values method

* simplify migration, liquibase rollback support is shaky

So best to limit the destructive updates in the migration and do them manually

* Move project permission checks from project table to connection

This way they are accessible pretty much everywhere, but are only calculated once. The permission checks themselves are now synchronous and should be really fast, so there is no need to avoid checking project ids

* Update graphql loaders and schema to work with new permissions

Also some QOL fixes for graphql types so that the context is now properly typed

* Update routes, layers and table files to work with new permissions

* add route to update project members

* update project table methods to incorporate admin group roles

Rather than having these roles separately - incorporate them into
project level roles where it makes sense. That way the same permission
constructs can be used for checking admin roles rather than having to
have separate ones.

* remove allow all access setting, it is better to use real controls

even when running locally, it is better to lean on the actual access
controls rather than allowing all access. This way we can catch issues
with permission checks during development

* change graphql project list to only list projects with certain roles

to avoid listing absolutely every project for users with admin roles

* fix data generation scripts to work with new permissions

* fix tests to work with new permissions structures

* fix merge issues

* update docs

* fix permission checks in project routes

* make query uppercase for consistency

* use test environment for tests

* make connection class variables protected

* add check for default user in generate data scripts

* fix linting errors

* simplify roles to allow better management from cpg-infrastructure

* re-reorder auth checks

* merge cleanup

* update seqr project listing to raise if user can't access all seqr projs

* Bump version: 7.1.1 → 7.2.0
  • Loading branch information
dancoates authored Jul 8, 2024
1 parent fd2c933 commit 66e233d
Show file tree
Hide file tree
Showing 59 changed files with 1,528 additions and 1,506 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 7.1.1
current_version = 7.2.0
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>[A-z0-9-]+)
Expand Down
4 changes: 2 additions & 2 deletions api/graphql/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ class GraphQLFilter(Generic[T]):
icontains: T | None = None
startswith: T | None = None

def all_values(self):
def all_values(self) -> list[T]:
"""
Get all values used anywhere in a filter, useful for getting values to map later
"""
v = []
v: list[T] = []
if self.eq:
v.append(self.eq)
if self.in_:
Expand Down
88 changes: 46 additions & 42 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import dataclasses
import enum
from collections import defaultdict
from typing import Any
from typing import Any, TypedDict

from fastapi import Request
from strawberry.dataloader import DataLoader

from api.utils import get_projectless_db_connection, group_by
from api.utils import group_by
from api.utils.db import get_projectless_db_connection
from db.python.connect import Connection
from db.python.filters import GenericFilter, get_hashable_value
from db.python.layers import (
AnalysisLayer,
Expand All @@ -24,7 +26,6 @@
from db.python.tables.assay import AssayFilter
from db.python.tables.family import FamilyFilter
from db.python.tables.participant import ParticipantFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import NotFoundError
Expand Down Expand Up @@ -79,14 +80,14 @@ class LoaderKeys(enum.Enum):
SEQUENCING_GROUPS_FOR_ANALYSIS = 'sequencing_groups_for_analysis'


loaders = {}
loaders: dict[LoaderKeys, Any] = {}


def connected_data_loader(id_: LoaderKeys, cache=True):
def connected_data_loader(id_: LoaderKeys, cache: bool = True):
"""Provide connection to a data loader"""

def connected_data_loader_caller(fn):
def inner(connection):
def inner(connection: Connection):
async def wrapped(*args, **kwargs):
return await fn(*args, **kwargs, connection=connection)

Expand All @@ -110,7 +111,7 @@ def connected_data_loader_with_params(
"""

def connected_data_loader_caller(fn):
def inner(connection):
def inner(connection: Connection):
async def wrapped(query: list[dict[str, Any]]) -> list[Any]:
by_key: dict[tuple, Any] = {}

Expand Down Expand Up @@ -159,7 +160,7 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]:

@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS)
async def load_audit_logs_by_ids(
audit_log_ids: list[int], connection
audit_log_ids: list[int], connection: Connection
) -> list[AuditLogInternal | None]:
"""
DataLoader: get_audit_logs_by_ids
Expand All @@ -172,7 +173,7 @@ async def load_audit_logs_by_ids(

@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS)
async def load_audit_logs_by_analysis_ids(
analysis_ids: list[int], connection
analysis_ids: list[int], connection: Connection
) -> list[list[AuditLogInternal]]:
"""
DataLoader: get_audit_logs_by_analysis_ids
Expand All @@ -184,7 +185,7 @@ async def load_audit_logs_by_analysis_ids(

@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
connection: Connection, ids, filter: AssayFilter
) -> dict[int, list[AssayInternal]]:
"""
DataLoader: get_assays_for_sample_ids
Expand All @@ -200,7 +201,7 @@ async def load_assays_by_samples(

@connected_data_loader(LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS)
async def load_assays_by_sequencing_groups(
sequencing_group_ids: list[int], connection
sequencing_group_ids: list[int], connection: Connection
) -> list[list[AssayInternal]]:
"""
Get all assays belong to the sequencing groups
Expand All @@ -209,7 +210,7 @@ async def load_assays_by_sequencing_groups(

# group by all last fields, in case we add more
assays = await assaylayer.get_assays_for_sequencing_group_ids(
sequencing_group_ids=sequencing_group_ids, check_project_ids=False
sequencing_group_ids=sequencing_group_ids
)

return [assays.get(sg, []) for sg in sequencing_group_ids]
Expand All @@ -219,7 +220,7 @@ async def load_assays_by_sequencing_groups(
LoaderKeys.SAMPLES_FOR_PARTICIPANTS, default_factory=list
)
async def load_samples_for_participant_ids(
ids: list[int], filter: SampleFilter, connection
ids: list[int], filter: SampleFilter, connection: Connection
) -> dict[int, list[SampleInternal]]:
"""
DataLoader: get_samples_for_participant_ids
Expand All @@ -232,7 +233,7 @@ async def load_samples_for_participant_ids(

@connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_IDS)
async def load_sequencing_groups_for_ids(
sequencing_group_ids: list[int], connection
sequencing_group_ids: list[int], connection: Connection
) -> list[SequencingGroupInternal]:
"""
DataLoader: get_sequencing_groups_by_ids
Expand All @@ -249,7 +250,7 @@ async def load_sequencing_groups_for_ids(
LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES, default_factory=list
)
async def load_sequencing_groups_for_samples(
connection, ids: list[int], filter: SequencingGroupFilter
connection: Connection, ids: list[int], filter: SequencingGroupFilter
) -> dict[int, list[SequencingGroupInternal]]:
"""
Has format [(sample_id: int, sequencing_type?: string)]
Expand All @@ -270,7 +271,7 @@ async def load_sequencing_groups_for_samples(

@connected_data_loader(LoaderKeys.SAMPLES_FOR_IDS)
async def load_samples_for_ids(
sample_ids: list[int], connection
sample_ids: list[int], connection: Connection
) -> list[SampleInternal]:
"""
DataLoader: get_samples_for_ids
Expand All @@ -286,7 +287,7 @@ async def load_samples_for_ids(
LoaderKeys.SAMPLES_FOR_PROJECTS, default_factory=list
)
async def load_samples_for_projects(
connection, ids: list[ProjectId], filter: SampleFilter
connection: Connection, ids: list[ProjectId], filter: SampleFilter
):
"""
DataLoader: get_samples_for_project_ids
Expand All @@ -300,7 +301,7 @@ async def load_samples_for_projects(

@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_IDS)
async def load_participants_for_ids(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[ParticipantInternal]:
"""
DataLoader: get_participants_by_ids
Expand All @@ -318,7 +319,7 @@ async def load_participants_for_ids(

@connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS)
async def load_sequencing_groups_for_analysis_ids(
analysis_ids: list[int], connection
analysis_ids: list[int], connection: Connection
) -> list[list[SequencingGroupInternal]]:
"""
DataLoader: get_samples_for_analysis_ids
Expand All @@ -333,7 +334,7 @@ async def load_sequencing_groups_for_analysis_ids(
LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS, default_factory=list
)
async def load_sequencing_groups_for_project_ids(
ids: list[int], filter: SequencingGroupFilter, connection
ids: list[int], filter: SequencingGroupFilter, connection: Connection
) -> dict[int, list[SequencingGroupInternal]]:
"""
DataLoader: get_sequencing_groups_for_project_ids
Expand All @@ -347,39 +348,33 @@ async def load_sequencing_groups_for_project_ids(


@connected_data_loader(LoaderKeys.PROJECTS_FOR_IDS)
async def load_projects_for_ids(project_ids: list[int], connection) -> list[Project]:
async def load_projects_for_ids(
project_ids: list[int], connection: Connection
) -> list[Project]:
"""
Get projects by IDs
"""
pttable = ProjectPermissionsTable(connection)
projects = await pttable.get_and_check_access_to_projects_for_ids(
user=connection.author, project_ids=project_ids, readonly=True
)

p_by_id = {p.id: p for p in projects}
projects = [p_by_id.get(p) for p in project_ids]
projects = [connection.project_id_map.get(p) for p in project_ids]

return [p for p in projects if p is not None]


@connected_data_loader(LoaderKeys.FAMILIES_FOR_PARTICIPANTS)
async def load_families_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[list[FamilyInternal]]:
"""
Get families of participants, noting a participant can be in multiple families
"""
flayer = FamilyLayer(connection)

fam_map = await flayer.get_families_by_participants(
participant_ids=participant_ids, check_project_ids=False
)
fam_map = await flayer.get_families_by_participants(participant_ids=participant_ids)
return [fam_map.get(p, []) for p in participant_ids]


@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_FAMILIES)
async def load_participants_for_families(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[list[ParticipantInternal]]:
"""Get all participants in a family, doesn't include affected statuses"""
player = ParticipantLayer(connection)
Expand All @@ -391,7 +386,7 @@ async def load_participants_for_families(
LoaderKeys.PARTICIPANTS_FOR_PROJECTS, default_factory=list
)
async def load_participants_for_projects(
ids: list[ProjectId], filter_: ParticipantFilter, connection
ids: list[ProjectId], filter_: ParticipantFilter, connection: Connection
) -> dict[ProjectId, list[ParticipantInternal]]:
"""
Get all participants in a project
Expand All @@ -411,7 +406,7 @@ async def load_participants_for_projects(
async def load_analyses_for_sequencing_groups(
ids: list[int],
filter_: AnalysisFilter,
connection,
connection: Connection,
) -> dict[int, list[AnalysisInternal]]:
"""
Type: (sequencing_group_id: int, status?: AnalysisStatus, type?: str)
Expand All @@ -429,7 +424,7 @@ async def load_analyses_for_sequencing_groups(

@connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS)
async def load_phenotypes_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[dict]:
"""
Data loader for phenotypes for participants
Expand All @@ -443,7 +438,7 @@ async def load_phenotypes_for_participants(

@connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS)
async def load_families_for_ids(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[FamilyInternal]:
"""
DataLoader: get_families_for_ids
Expand All @@ -456,7 +451,7 @@ async def load_families_for_ids(

@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES)
async def load_family_participants_for_families(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[list[PedRowInternal]]:
"""
DataLoader: get_family_participants_for_families
Expand All @@ -469,7 +464,7 @@ async def load_family_participants_for_families(

@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS)
async def load_family_participants_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[list[PedRowInternal]]:
"""data loader for family participants for participants
Expand All @@ -490,12 +485,21 @@ async def load_family_participants_for_participants(
return [fp_map.get(pid, []) for pid in participant_ids]


class GraphQLContext(TypedDict):
"""Basic dict type for GraphQL context to be passed to resolvers"""

loaders: dict[LoaderKeys, Any]
connection: Connection


async def get_context(
request: Request, connection=get_projectless_db_connection
): # pylint: disable=unused-argument
request: Request, # pylint: disable=unused-argument
connection: Connection = get_projectless_db_connection,
) -> GraphQLContext:
"""Get loaders / cache context for strawberyy GraphQL"""
mapped_loaders = {k: fn(connection) for k, fn in loaders.items()}

return {
'connection': connection,
**mapped_loaders,
'loaders': mapped_loaders,
}
Loading

0 comments on commit 66e233d

Please sign in to comment.