Skip to content

Commit

Permalink
Merge pull request #829 from populationgenomics/permissions-system-up…
Browse files Browse the repository at this point in the history
…date

Metamist permissions system update
  • Loading branch information
dancoates authored Jul 8, 2024
2 parents fd2c933 + 56054fa commit ec54f43
Show file tree
Hide file tree
Showing 59 changed files with 1,528 additions and 1,506 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 7.1.1
current_version = 7.2.0
commit = True
tag = False
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>[A-z0-9-]+)
Expand Down
4 changes: 2 additions & 2 deletions api/graphql/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ class GraphQLFilter(Generic[T]):
icontains: T | None = None
startswith: T | None = None

def all_values(self):
def all_values(self) -> list[T]:
"""
Get all values used anywhere in a filter, useful for getting values to map later
"""
v = []
v: list[T] = []
if self.eq:
v.append(self.eq)
if self.in_:
Expand Down
88 changes: 46 additions & 42 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import dataclasses
import enum
from collections import defaultdict
from typing import Any
from typing import Any, TypedDict

from fastapi import Request
from strawberry.dataloader import DataLoader

from api.utils import get_projectless_db_connection, group_by
from api.utils import group_by
from api.utils.db import get_projectless_db_connection
from db.python.connect import Connection
from db.python.filters import GenericFilter, get_hashable_value
from db.python.layers import (
AnalysisLayer,
Expand All @@ -24,7 +26,6 @@
from db.python.tables.assay import AssayFilter
from db.python.tables.family import FamilyFilter
from db.python.tables.participant import ParticipantFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import NotFoundError
Expand Down Expand Up @@ -79,14 +80,14 @@ class LoaderKeys(enum.Enum):
SEQUENCING_GROUPS_FOR_ANALYSIS = 'sequencing_groups_for_analysis'


loaders = {}
loaders: dict[LoaderKeys, Any] = {}


def connected_data_loader(id_: LoaderKeys, cache=True):
def connected_data_loader(id_: LoaderKeys, cache: bool = True):
"""Provide connection to a data loader"""

def connected_data_loader_caller(fn):
def inner(connection):
def inner(connection: Connection):
async def wrapped(*args, **kwargs):
return await fn(*args, **kwargs, connection=connection)

Expand All @@ -110,7 +111,7 @@ def connected_data_loader_with_params(
"""

def connected_data_loader_caller(fn):
def inner(connection):
def inner(connection: Connection):
async def wrapped(query: list[dict[str, Any]]) -> list[Any]:
by_key: dict[tuple, Any] = {}

Expand Down Expand Up @@ -159,7 +160,7 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]:

@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS)
async def load_audit_logs_by_ids(
audit_log_ids: list[int], connection
audit_log_ids: list[int], connection: Connection
) -> list[AuditLogInternal | None]:
"""
DataLoader: get_audit_logs_by_ids
Expand All @@ -172,7 +173,7 @@ async def load_audit_logs_by_ids(

@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS)
async def load_audit_logs_by_analysis_ids(
analysis_ids: list[int], connection
analysis_ids: list[int], connection: Connection
) -> list[list[AuditLogInternal]]:
"""
DataLoader: get_audit_logs_by_analysis_ids
Expand All @@ -184,7 +185,7 @@ async def load_audit_logs_by_analysis_ids(

@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
connection: Connection, ids, filter: AssayFilter
) -> dict[int, list[AssayInternal]]:
"""
DataLoader: get_assays_for_sample_ids
Expand All @@ -200,7 +201,7 @@ async def load_assays_by_samples(

@connected_data_loader(LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS)
async def load_assays_by_sequencing_groups(
sequencing_group_ids: list[int], connection
sequencing_group_ids: list[int], connection: Connection
) -> list[list[AssayInternal]]:
"""
Get all assays belong to the sequencing groups
Expand All @@ -209,7 +210,7 @@ async def load_assays_by_sequencing_groups(

# group by all last fields, in case we add more
assays = await assaylayer.get_assays_for_sequencing_group_ids(
sequencing_group_ids=sequencing_group_ids, check_project_ids=False
sequencing_group_ids=sequencing_group_ids
)

return [assays.get(sg, []) for sg in sequencing_group_ids]
Expand All @@ -219,7 +220,7 @@ async def load_assays_by_sequencing_groups(
LoaderKeys.SAMPLES_FOR_PARTICIPANTS, default_factory=list
)
async def load_samples_for_participant_ids(
ids: list[int], filter: SampleFilter, connection
ids: list[int], filter: SampleFilter, connection: Connection
) -> dict[int, list[SampleInternal]]:
"""
DataLoader: get_samples_for_participant_ids
Expand All @@ -232,7 +233,7 @@ async def load_samples_for_participant_ids(

@connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_IDS)
async def load_sequencing_groups_for_ids(
sequencing_group_ids: list[int], connection
sequencing_group_ids: list[int], connection: Connection
) -> list[SequencingGroupInternal]:
"""
DataLoader: get_sequencing_groups_by_ids
Expand All @@ -249,7 +250,7 @@ async def load_sequencing_groups_for_ids(
LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES, default_factory=list
)
async def load_sequencing_groups_for_samples(
connection, ids: list[int], filter: SequencingGroupFilter
connection: Connection, ids: list[int], filter: SequencingGroupFilter
) -> dict[int, list[SequencingGroupInternal]]:
"""
Has format [(sample_id: int, sequencing_type?: string)]
Expand All @@ -270,7 +271,7 @@ async def load_sequencing_groups_for_samples(

@connected_data_loader(LoaderKeys.SAMPLES_FOR_IDS)
async def load_samples_for_ids(
sample_ids: list[int], connection
sample_ids: list[int], connection: Connection
) -> list[SampleInternal]:
"""
DataLoader: get_samples_for_ids
Expand All @@ -286,7 +287,7 @@ async def load_samples_for_ids(
LoaderKeys.SAMPLES_FOR_PROJECTS, default_factory=list
)
async def load_samples_for_projects(
connection, ids: list[ProjectId], filter: SampleFilter
connection: Connection, ids: list[ProjectId], filter: SampleFilter
):
"""
DataLoader: get_samples_for_project_ids
Expand All @@ -300,7 +301,7 @@ async def load_samples_for_projects(

@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_IDS)
async def load_participants_for_ids(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[ParticipantInternal]:
"""
DataLoader: get_participants_by_ids
Expand All @@ -318,7 +319,7 @@ async def load_participants_for_ids(

@connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS)
async def load_sequencing_groups_for_analysis_ids(
analysis_ids: list[int], connection
analysis_ids: list[int], connection: Connection
) -> list[list[SequencingGroupInternal]]:
"""
DataLoader: get_samples_for_analysis_ids
Expand All @@ -333,7 +334,7 @@ async def load_sequencing_groups_for_analysis_ids(
LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS, default_factory=list
)
async def load_sequencing_groups_for_project_ids(
ids: list[int], filter: SequencingGroupFilter, connection
ids: list[int], filter: SequencingGroupFilter, connection: Connection
) -> dict[int, list[SequencingGroupInternal]]:
"""
DataLoader: get_sequencing_groups_for_project_ids
Expand All @@ -347,39 +348,33 @@ async def load_sequencing_groups_for_project_ids(


@connected_data_loader(LoaderKeys.PROJECTS_FOR_IDS)
async def load_projects_for_ids(
    project_ids: list[int], connection: Connection
) -> list[Project]:
    """
    DataLoader: resolve Project objects for the given project IDs.

    Projects are looked up in the connection's ``project_id_map``, which is
    populated when the connection is established, so no extra permission
    query is made here.
    """
    projects = [connection.project_id_map.get(p) for p in project_ids]

    # NOTE(review): unknown IDs are dropped rather than returned as None, so
    # the result can be shorter than the input — presumably every requested
    # ID is expected to be present in the map; confirm against callers.
    return [p for p in projects if p is not None]


@connected_data_loader(LoaderKeys.FAMILIES_FOR_PARTICIPANTS)
async def load_families_for_participants(
    participant_ids: list[int], connection: Connection
) -> list[list[FamilyInternal]]:
    """
    DataLoader: get the families of each participant.

    A participant can belong to multiple families, so each input ID maps to
    a (possibly empty) list of FamilyInternal, returned in input order.
    """
    flayer = FamilyLayer(connection)

    fam_map = await flayer.get_families_by_participants(
        participant_ids=participant_ids
    )
    # Preserve input order; participants with no family get an empty list.
    return [fam_map.get(pid, []) for pid in participant_ids]


@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_FAMILIES)
async def load_participants_for_families(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[list[ParticipantInternal]]:
"""Get all participants in a family, doesn't include affected statuses"""
player = ParticipantLayer(connection)
Expand All @@ -391,7 +386,7 @@ async def load_participants_for_families(
LoaderKeys.PARTICIPANTS_FOR_PROJECTS, default_factory=list
)
async def load_participants_for_projects(
ids: list[ProjectId], filter_: ParticipantFilter, connection
ids: list[ProjectId], filter_: ParticipantFilter, connection: Connection
) -> dict[ProjectId, list[ParticipantInternal]]:
"""
Get all participants in a project
Expand All @@ -411,7 +406,7 @@ async def load_participants_for_projects(
async def load_analyses_for_sequencing_groups(
ids: list[int],
filter_: AnalysisFilter,
connection,
connection: Connection,
) -> dict[int, list[AnalysisInternal]]:
"""
Type: (sequencing_group_id: int, status?: AnalysisStatus, type?: str)
Expand All @@ -429,7 +424,7 @@ async def load_analyses_for_sequencing_groups(

@connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS)
async def load_phenotypes_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[dict]:
"""
Data loader for phenotypes for participants
Expand All @@ -443,7 +438,7 @@ async def load_phenotypes_for_participants(

@connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS)
async def load_families_for_ids(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[FamilyInternal]:
"""
DataLoader: get_families_for_ids
Expand All @@ -456,7 +451,7 @@ async def load_families_for_ids(

@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES)
async def load_family_participants_for_families(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[list[PedRowInternal]]:
"""
DataLoader: get_family_participants_for_families
Expand All @@ -469,7 +464,7 @@ async def load_family_participants_for_families(

@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS)
async def load_family_participants_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[list[PedRowInternal]]:
"""data loader for family participants for participants
Expand All @@ -490,12 +485,21 @@ async def load_family_participants_for_participants(
return [fp_map.get(pid, []) for pid in participant_ids]


class GraphQLContext(TypedDict):
    """Basic dict type for GraphQL context to be passed to resolvers"""

    # Per-request DataLoader instances, keyed by LoaderKeys
    loaders: dict[LoaderKeys, Any]
    # Database connection for the current request
    connection: Connection


async def get_context(
    request: Request,  # pylint: disable=unused-argument
    connection: Connection = get_projectless_db_connection,
) -> GraphQLContext:
    """
    Build the per-request strawberry GraphQL context.

    Instantiates every registered data loader against this request's DB
    connection, so resolvers share dataloader caches within a request but
    not across requests.
    """
    mapped_loaders = {k: fn(connection) for k, fn in loaders.items()}

    return {
        'connection': connection,
        'loaders': mapped_loaders,
    }
Loading

0 comments on commit ec54f43

Please sign in to comment.