Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of contributor role to metamist permissions #703

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions api/graphql/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ class GraphQLFilter(Generic[T]):
contains: T | None = None
icontains: T | None = None

def all_values(self):
def all_values(self) -> list[T]:
"""
Get all values used anywhere in a filter, useful for getting values to map later
"""
v = []
v: list[T] = []
if self.eq:
v.append(self.eq)
if self.in_:
Expand Down
88 changes: 46 additions & 42 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
import dataclasses
import enum
from collections import defaultdict
from typing import Any
from typing import Any, TypedDict

from fastapi import Request
from strawberry.dataloader import DataLoader

from api.utils import get_projectless_db_connection, group_by
from api.utils import group_by
from api.utils.db import get_projectless_db_connection
from db.python.connect import Connection
from db.python.layers import (
AnalysisLayer,
AssayLayer,
Expand All @@ -22,7 +24,6 @@
from db.python.tables.analysis import AnalysisFilter
from db.python.tables.assay import AssayFilter
from db.python.tables.family import FamilyFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import GenericFilter, NotFoundError, get_hashable_value
Expand Down Expand Up @@ -77,14 +78,14 @@ class LoaderKeys(enum.Enum):
SEQUENCING_GROUPS_FOR_ANALYSIS = 'sequencing_groups_for_analysis'


loaders = {}
loaders: dict[LoaderKeys, Any] = {}


def connected_data_loader(id_: LoaderKeys, cache=True):
def connected_data_loader(id_: LoaderKeys, cache: bool = True):
"""Provide connection to a data loader"""

def connected_data_loader_caller(fn):
def inner(connection):
def inner(connection: Connection):
async def wrapped(*args, **kwargs):
return await fn(*args, **kwargs, connection=connection)

Expand All @@ -108,7 +109,7 @@ def connected_data_loader_with_params(
"""

def connected_data_loader_caller(fn):
def inner(connection):
def inner(connection: Connection):
async def wrapped(query: list[dict[str, Any]]) -> list[Any]:
by_key: dict[tuple, Any] = {}

Expand Down Expand Up @@ -157,7 +158,7 @@ async def wrapped(query: list[dict[str, Any]]) -> list[Any]:

@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_IDS)
async def load_audit_logs_by_ids(
audit_log_ids: list[int], connection
audit_log_ids: list[int], connection: Connection
) -> list[AuditLogInternal | None]:
"""
DataLoader: get_audit_logs_by_ids
Expand All @@ -170,7 +171,7 @@ async def load_audit_logs_by_ids(

@connected_data_loader(LoaderKeys.AUDIT_LOGS_BY_ANALYSIS_IDS)
async def load_audit_logs_by_analysis_ids(
analysis_ids: list[int], connection
analysis_ids: list[int], connection: Connection
) -> list[list[AuditLogInternal]]:
"""
DataLoader: get_audit_logs_by_analysis_ids
Expand All @@ -182,7 +183,7 @@ async def load_audit_logs_by_analysis_ids(

@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
connection: Connection, ids, filter: AssayFilter
) -> dict[int, list[AssayInternal]]:
"""
DataLoader: get_assays_for_sample_ids
Expand All @@ -198,7 +199,7 @@ async def load_assays_by_samples(

@connected_data_loader(LoaderKeys.ASSAYS_FOR_SEQUENCING_GROUPS)
async def load_assays_by_sequencing_groups(
sequencing_group_ids: list[int], connection
sequencing_group_ids: list[int], connection: Connection
) -> list[list[AssayInternal]]:
"""
Get all assays belong to the sequencing groups
Expand All @@ -207,7 +208,7 @@ async def load_assays_by_sequencing_groups(

# group by all last fields, in case we add more
assays = await assaylayer.get_assays_for_sequencing_group_ids(
sequencing_group_ids=sequencing_group_ids, check_project_ids=False
sequencing_group_ids=sequencing_group_ids
)

return [assays.get(sg, []) for sg in sequencing_group_ids]
Expand All @@ -217,7 +218,7 @@ async def load_assays_by_sequencing_groups(
LoaderKeys.SAMPLES_FOR_PARTICIPANTS, default_factory=list
)
async def load_samples_for_participant_ids(
ids: list[int], filter: SampleFilter, connection
ids: list[int], filter: SampleFilter, connection: Connection
) -> dict[int, list[SampleInternal]]:
"""
DataLoader: get_samples_for_participant_ids
Expand All @@ -230,7 +231,7 @@ async def load_samples_for_participant_ids(

@connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_IDS)
async def load_sequencing_groups_for_ids(
sequencing_group_ids: list[int], connection
sequencing_group_ids: list[int], connection: Connection
) -> list[SequencingGroupInternal]:
"""
DataLoader: get_sequencing_groups_by_ids
Expand All @@ -247,7 +248,7 @@ async def load_sequencing_groups_for_ids(
LoaderKeys.SEQUENCING_GROUPS_FOR_SAMPLES, default_factory=list
)
async def load_sequencing_groups_for_samples(
connection, ids: list[int], filter: SequencingGroupFilter
connection: Connection, ids: list[int], filter: SequencingGroupFilter
) -> dict[int, list[SequencingGroupInternal]]:
"""
Has format [(sample_id: int, sequencing_type?: string)]
Expand All @@ -263,7 +264,7 @@ async def load_sequencing_groups_for_samples(

@connected_data_loader(LoaderKeys.SAMPLES_FOR_IDS)
async def load_samples_for_ids(
sample_ids: list[int], connection
sample_ids: list[int], connection: Connection
) -> list[SampleInternal]:
"""
DataLoader: get_samples_for_ids
Expand All @@ -279,7 +280,7 @@ async def load_samples_for_ids(
LoaderKeys.SAMPLES_FOR_PROJECTS, default_factory=list
)
async def load_samples_for_projects(
connection, ids: list[ProjectId], filter: SampleFilter
connection: Connection, ids: list[ProjectId], filter: SampleFilter
):
"""
DataLoader: get_samples_for_project_ids
Expand All @@ -293,7 +294,7 @@ async def load_samples_for_projects(

@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_IDS)
async def load_participants_for_ids(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[ParticipantInternal]:
"""
DataLoader: get_participants_by_ids
Expand All @@ -311,7 +312,7 @@ async def load_participants_for_ids(

@connected_data_loader(LoaderKeys.SEQUENCING_GROUPS_FOR_ANALYSIS)
async def load_sequencing_groups_for_analysis_ids(
analysis_ids: list[int], connection
analysis_ids: list[int], connection: Connection
) -> list[list[SequencingGroupInternal]]:
"""
DataLoader: get_samples_for_analysis_ids
Expand All @@ -326,7 +327,7 @@ async def load_sequencing_groups_for_analysis_ids(
LoaderKeys.SEQUENCING_GROUPS_FOR_PROJECTS, default_factory=list
)
async def load_sequencing_groups_for_project_ids(
ids: list[int], filter: SequencingGroupFilter, connection
ids: list[int], filter: SequencingGroupFilter, connection: Connection
) -> dict[int, list[SequencingGroupInternal]]:
"""
DataLoader: get_sequencing_groups_for_project_ids
Expand All @@ -340,39 +341,33 @@ async def load_sequencing_groups_for_project_ids(


@connected_data_loader(LoaderKeys.PROJECTS_FOR_IDS)
async def load_projects_for_ids(
    project_ids: list[int], connection: Connection
) -> list[Project]:
    """
    DataLoader: resolve Project objects for a list of project IDs.

    Looks projects up in the connection's pre-authorised project map
    (``connection.project_id_map``), so no extra permission check or DB
    round-trip is needed here.  IDs with no mapped project are silently
    dropped, so the returned list may be shorter than ``project_ids``.
    """
    # NOTE(review): assumes connection.project_id_map only contains projects
    # the current user may read — confirm against Connection's construction.
    projects = [connection.project_id_map.get(p) for p in project_ids]

    # Filter out unknown/inaccessible IDs rather than returning None entries.
    return [p for p in projects if p is not None]


@connected_data_loader(LoaderKeys.FAMILIES_FOR_PARTICIPANTS)
async def load_families_for_participants(
    participant_ids: list[int], connection: Connection
) -> list[list[FamilyInternal]]:
    """
    DataLoader: get families for participants.

    A participant can belong to multiple families, so each input ID maps to a
    (possibly empty) list of families.  Output order matches ``participant_ids``.
    """
    flayer = FamilyLayer(connection)

    # Returns a dict of participant_id -> list[FamilyInternal]; project access
    # is enforced by the layer/connection, not re-checked here.
    fam_map = await flayer.get_families_by_participants(
        participant_ids=participant_ids
    )
    return [fam_map.get(p, []) for p in participant_ids]


@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_FAMILIES)
async def load_participants_for_families(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[list[ParticipantInternal]]:
"""Get all participants in a family, doesn't include affected statuses"""
player = ParticipantLayer(connection)
Expand All @@ -382,7 +377,7 @@ async def load_participants_for_families(

@connected_data_loader(LoaderKeys.PARTICIPANTS_FOR_PROJECTS)
async def load_participants_for_projects(
project_ids: list[ProjectId], connection
project_ids: list[ProjectId], connection: Connection
) -> list[list[ParticipantInternal]]:
"""
Get all participants in a project
Expand All @@ -404,7 +399,7 @@ async def load_participants_for_projects(
async def load_analyses_for_sequencing_groups(
ids: list[int],
filter_: AnalysisFilter,
connection,
connection: Connection,
) -> dict[int, list[AnalysisInternal]]:
"""
Type: (sequencing_group_id: int, status?: AnalysisStatus, type?: str)
Expand All @@ -422,7 +417,7 @@ async def load_analyses_for_sequencing_groups(

@connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS)
async def load_phenotypes_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[dict]:
"""
Data loader for phenotypes for participants
Expand All @@ -436,7 +431,7 @@ async def load_phenotypes_for_participants(

@connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS)
async def load_families_for_ids(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[FamilyInternal]:
"""
DataLoader: get_families_for_ids
Expand All @@ -449,7 +444,7 @@ async def load_families_for_ids(

@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES)
async def load_family_participants_for_families(
family_ids: list[int], connection
family_ids: list[int], connection: Connection
) -> list[list[PedRowInternal]]:
"""
DataLoader: get_family_participants_for_families
Expand All @@ -462,7 +457,7 @@ async def load_family_participants_for_families(

@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS)
async def load_family_participants_for_participants(
participant_ids: list[int], connection
participant_ids: list[int], connection: Connection
) -> list[list[PedRowInternal]]:
"""data loader for family participants for participants

Expand All @@ -483,12 +478,21 @@ async def load_family_participants_for_participants(
return [fp_map.get(pid, []) for pid in participant_ids]


class GraphQLContext(TypedDict):
    """Shape of the GraphQL context dict passed to strawberry resolvers."""

    # mapping of LoaderKeys -> connection-bound DataLoader for this request
    loaders: dict[LoaderKeys, Any]
    # per-request database connection shared by all loaders/resolvers
    connection: Connection


async def get_context(
    request: Request,  # pylint: disable=unused-argument
    connection: Connection = get_projectless_db_connection,
) -> GraphQLContext:
    """
    Build the per-request context for strawberry GraphQL.

    Instantiates every registered data loader bound to this request's DB
    connection and returns them, together with the connection itself, as the
    context dict that resolvers receive.  ``request`` is required by the
    FastAPI dependency signature but unused here.
    """
    # Bind each registered loader factory to this request's connection.
    mapped_loaders = {k: fn(connection) for k, fn in loaders.items()}

    return {
        'connection': connection,
        'loaders': mapped_loaders,
    }
Loading
Loading