Skip to content

Commit

Permalink
Refactor families + add family participants to graphql (#740)
Browse files Browse the repository at this point in the history
* Initial commit

* Implement more family_participant + remove schema types

* Genercise the querying for families

* Cleanup for the family refactor

* Fix tests + filters

* Add codecov test value

* Apply suggestions from code review

* Fix minor formatting oddity

* Linting

* Add a GraphQLMetaFilter test + improve filter hashing

* Apply review feedback

* Linting

---------

Co-authored-by: Michael Franklin <[email protected]>
  • Loading branch information
illusional and illusional authored May 1, 2024
1 parent 6f46082 commit 5451b07
Show file tree
Hide file tree
Showing 15 changed files with 1,101 additions and 690 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ jobs:
uses: codecov/codecov-action@v3
with:
files: ./coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}


- name: "build web front-end"
run: |
Expand Down
59 changes: 42 additions & 17 deletions api/graphql/filters.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Any, Callable, Generic, TypeVar
from typing import Callable, Generic, TypeVar

import strawberry

from db.python.utils import GenericFilter
from db.python.utils import GenericFilter, GenericMetaFilter

T = TypeVar('T')
Y = TypeVar('Y')


@strawberry.input(description='Filter for GraphQL queries')
Expand Down Expand Up @@ -47,22 +48,8 @@ def all_values(self):

return v

def to_internal_filter(self, f: Callable[[T], Any] = None):
def to_internal_filter(self) -> GenericFilter[T]:
"""Convert from GraphQL to internal filter model"""

if f:
return GenericFilter(
eq=f(self.eq) if self.eq else None,
in_=list(map(f, self.in_)) if self.in_ else None,
nin=list(map(f, self.nin)) if self.nin else None,
gt=f(self.gt) if self.gt else None,
gte=f(self.gte) if self.gte else None,
lt=f(self.lt) if self.lt else None,
lte=f(self.lte) if self.lte else None,
contains=f(self.contains) if self.contains else None,
icontains=f(self.icontains) if self.icontains else None,
)

return GenericFilter(
eq=self.eq,
in_=self.in_,
Expand All @@ -75,5 +62,43 @@ def to_internal_filter(self, f: Callable[[T], Any] = None):
icontains=self.icontains,
)

def to_internal_filter_mapped(self, f: Callable[[T], Y]) -> GenericFilter[Y]:
"""
To internal filter, but apply a function to all values.
Separate this into a separate function to please linters and type checkers
"""
return GenericFilter(
eq=f(self.eq) if self.eq else None,
in_=list(map(f, self.in_)) if self.in_ else None,
nin=list(map(f, self.nin)) if self.nin else None,
gt=f(self.gt) if self.gt else None,
gte=f(self.gte) if self.gte else None,
lt=f(self.lt) if self.lt else None,
lte=f(self.lte) if self.lte else None,
contains=f(self.contains) if self.contains else None,
icontains=f(self.icontains) if self.icontains else None,
)


GraphQLMetaFilter = strawberry.scalars.JSON


def graphql_meta_filter_to_internal_filter(
f: GraphQLMetaFilter | None,
) -> GenericMetaFilter | None:
"""Convert from GraphQL to internal filter model
Args:
f (GraphQLMetaFilter | None): GraphQL filter
Returns:
GenericMetaFilter | None: internal filter
"""
if not f:
return None

d: GenericMetaFilter = {}
f_to_d: dict[str, GraphQLMetaFilter] = dict(f) # type: ignore
for k, v in f_to_d.items():
d[k] = GenericFilter(**v) if isinstance(v, dict) else GenericFilter(eq=v)
return d
88 changes: 61 additions & 27 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
)
from db.python.tables.analysis import AnalysisFilter
from db.python.tables.assay import AssayFilter
from db.python.tables.family import FamilyFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import GenericFilter, NotFoundError
from db.python.utils import GenericFilter, NotFoundError, get_hashable_value
from models.models import (
AnalysisInternal,
AssayInternal,
Expand All @@ -36,6 +37,7 @@
SequencingGroupInternal,
)
from models.models.audit_log import AuditLogInternal
from models.models.family import PedRowInternal


class LoaderKeys(enum.Enum):
Expand Down Expand Up @@ -65,6 +67,9 @@ class LoaderKeys(enum.Enum):
PARTICIPANTS_FOR_PROJECTS = 'participants_for_projects'

FAMILIES_FOR_PARTICIPANTS = 'families_for_participants'
FAMILY_PARTICIPANTS_FOR_FAMILIES = 'family_participants_for_families'
FAMILY_PARTICIPANTS_FOR_PARTICIPANTS = 'family_participants_for_participants'
FAMILIES_FOR_IDS = 'families_for_ids'

SEQUENCING_GROUPS_FOR_IDS = 'sequencing_groups_for_ids'
SEQUENCING_GROUPS_FOR_SAMPLES = 'sequencing_groups_for_samples'
Expand All @@ -91,31 +96,8 @@ async def wrapped(*args, **kwargs):
return connected_data_loader_caller


def _prepare_partial_value_for_hashing(value):
if value is None:
return None
if isinstance(value, (int, str, float, bool)):
return value
if isinstance(value, enum.Enum):
return value.value
if isinstance(value, list):
# let's see if later we need to prepare the values in the list
return tuple(value)
if isinstance(value, dict):
return tuple(
sorted(
((k, _prepare_partial_value_for_hashing(v)) for k, v in value.items()),
key=lambda x: x[0],
)
)

return hash(value)


def _get_connected_data_loader_partial_key(kwargs):
return _prepare_partial_value_for_hashing(
{k: v for k, v in kwargs.items() if k != 'id'}
)
def _get_connected_data_loader_partial_key(kwargs) -> tuple:
return get_hashable_value({k: v for k, v in kwargs.items() if k != 'id'}) # type: ignore


def connected_data_loader_with_params(
Expand Down Expand Up @@ -381,7 +363,10 @@ async def load_families_for_participants(
Get families of participants, noting a participant can be in multiple families
"""
flayer = FamilyLayer(connection)
fam_map = await flayer.get_families_by_participants(participant_ids=participant_ids)

fam_map = await flayer.get_families_by_participants(
participant_ids=participant_ids, check_project_ids=False
)
return [fam_map.get(p, []) for p in participant_ids]


Expand Down Expand Up @@ -449,6 +434,55 @@ async def load_phenotypes_for_participants(
return [participant_phenotypes.get(pid, {}) for pid in participant_ids]


@connected_data_loader(LoaderKeys.FAMILIES_FOR_IDS)
async def load_families_for_ids(
family_ids: list[int], connection
) -> list[FamilyInternal]:
"""
DataLoader: get_families_for_ids
"""
flayer = FamilyLayer(connection)
families = await flayer.query(FamilyFilter(id=GenericFilter(in_=family_ids)))
f_by_id = {f.id: f for f in families}
return [f_by_id[f] for f in family_ids]


@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_FAMILIES)
async def load_family_participants_for_families(
family_ids: list[int], connection
) -> list[list[PedRowInternal]]:
"""
DataLoader: get_family_participants_for_families
"""
flayer = FamilyLayer(connection)
fp_map = await flayer.get_family_participants_by_family_ids(family_ids)

return [fp_map.get(fid, []) for fid in family_ids]


@connected_data_loader(LoaderKeys.FAMILY_PARTICIPANTS_FOR_PARTICIPANTS)
async def load_family_participants_for_participants(
participant_ids: list[int], connection
) -> list[list[PedRowInternal]]:
"""data loader for family participants for participants
Args:
participant_ids (list[int]): list of internal participant ids
connection (_type_): (this is automatically filled in by the loader decorator)
Returns:
list[list[PedRowInternal]]: list of family participants for each participant
(in order)
"""
flayer = FamilyLayer(connection)
family_participants = await flayer.get_family_participants_for_participants(
participant_ids
)
fp_map = group_by(family_participants, lambda fp: fp.individual_id)

return [fp_map.get(pid, []) for pid in participant_ids]


async def get_context(
request: Request, connection=get_projectless_db_connection
): # pylint: disable=unused-argument
Expand Down
Loading

0 comments on commit 5451b07

Please sign in to comment.