Skip to content

Commit

Permalink
Add participant phenotypes to graphql (#545)
Browse files Browse the repository at this point in the history
* Add participant phenotypes to graphql

* Add test for graphql phenotypes

* Fix unrelated linting issues

* Slight linting updates

* PR cleanup

---------

Co-authored-by: Michael Franklin <[email protected]>
  • Loading branch information
illusional and illusional authored Sep 18, 2023
1 parent 57ca4d3 commit 78cafc7
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 31 deletions.
38 changes: 27 additions & 11 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@
from db.python.connect import NotFoundError
from db.python.layers import (
AnalysisLayer,
SampleLayer,
AssayLayer,
FamilyLayer,
ParticipantLayer,
SampleLayer,
SequencingGroupLayer,
FamilyLayer,
)
from db.python.tables.analysis import AnalysisFilter
from db.python.tables.assay import AssayFilter
from db.python.tables.project import ProjectPermissionsTable
from db.python.tables.sample import SampleFilter
from db.python.tables.sequencing_group import SequencingGroupFilter
from db.python.utils import ProjectId, GenericFilter
from db.python.utils import GenericFilter, ProjectId
from models.models import (
AssayInternal,
SampleInternal,
SequencingGroupInternal,
AnalysisInternal,
ParticipantInternal,
AssayInternal,
FamilyInternal,
ParticipantInternal,
Project,
SampleInternal,
SequencingGroupInternal,
)


Expand All @@ -53,6 +53,8 @@ class LoaderKeys(enum.Enum):
SAMPLES_FOR_PARTICIPANTS = 'samples_for_participants'
SAMPLES_FOR_PROJECTS = 'samples_for_projects'

PHENOTYPES_FOR_PARTICIPANTS = 'phenotypes_for_participants'

PARTICIPANTS_FOR_IDS = 'participants_for_ids'
PARTICIPANTS_FOR_FAMILIES = 'participants_for_families'
PARTICIPANTS_FOR_PROJECTS = 'participants_for_projects'
Expand Down Expand Up @@ -291,9 +293,7 @@ async def load_participants_for_ids(
p_by_id = {p.id: p for p in persons}
missing_pids = set(participant_ids) - set(p_by_id.keys())
if missing_pids:
raise NotFoundError(
f'Could not find participants with ids {missing_pids}'
)
raise NotFoundError(f'Could not find participants with ids {missing_pids}')
return [p_by_id.get(p) for p in participant_ids]


Expand Down Expand Up @@ -400,7 +400,23 @@ async def load_analyses_for_sequencing_groups(
return by_sg_id


async def get_context(request: Request, connection=get_projectless_db_connection): # pylint: disable=unused-argument
@connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS)
async def load_phenotypes_for_participants(
participant_ids: list[int], connection
) -> list[dict]:
"""
Data loader for phenotypes for participants
"""
player = ParticipantLayer(connection)
participant_phenotypes = await player.get_phenotypes_for_participants(
participant_ids=participant_ids
)
return [participant_phenotypes.get(pid, {}) for pid in participant_ids]


async def get_context(
request: Request, connection=get_projectless_db_connection
): # pylint: disable=unused-argument
"""Get loaders / cache context for strawberyy GraphQL"""
mapped_loaders = {k: fn(connection) for k, fn in loaders.items()}
return {
Expand Down
33 changes: 16 additions & 17 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,10 @@
from strawberry.fastapi import GraphQLRouter
from strawberry.types import Info

from api.graphql.filters import (
GraphQLFilter,
GraphQLMetaFilter,
)
from api.graphql.loaders import (
get_context,
LoaderKeys,
)
from api.graphql.filters import GraphQLFilter, GraphQLMetaFilter
from api.graphql.loaders import LoaderKeys, get_context
from db.python import enum_tables
from db.python.layers import AnalysisLayer, SequencingGroupLayer, SampleLayer
from db.python.layers import AnalysisLayer, SampleLayer, SequencingGroupLayer
from db.python.layers.assay import AssayLayer
from db.python.layers.family import FamilyLayer
from db.python.tables.analysis import AnalysisFilter
Expand All @@ -34,21 +28,19 @@
from db.python.utils import GenericFilter
from models.enums import AnalysisStatus
from models.models import (
SampleInternal,
ParticipantInternal,
Project,
AnalysisInternal,
AssayInternal,
FamilyInternal,
ParticipantInternal,
Project,
SampleInternal,
SequencingGroupInternal,
AssayInternal,
)
from models.models.sample import sample_id_transform_to_raw
from models.utils.sample_id_format import (
sample_id_format,
)
from models.utils.sample_id_format import sample_id_format
from models.utils.sequencing_group_id_format import (
sequencing_group_id_transform_to_raw,
sequencing_group_id_format,
sequencing_group_id_transform_to_raw,
)

enum_methods = {}
Expand Down Expand Up @@ -336,6 +328,13 @@ async def samples(
samples = await info.context[LoaderKeys.SAMPLES_FOR_PARTICIPANTS].load(q)
return [GraphQLSample.from_internal(s) for s in samples]

@strawberry.field
async def phenotypes(
self, info: Info, root: 'GraphQLParticipant'
) -> strawberry.scalars.JSON:
loader = info.context[LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS]
return await loader.load(root.id)

@strawberry.field
async def families(
self, info: Info, root: 'GraphQLParticipant'
Expand Down
30 changes: 28 additions & 2 deletions db/python/layers/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import re
from collections import defaultdict
from enum import Enum
from typing import Dict, List, Tuple, Optional, Any
from typing import Any, Dict, List, Optional, Tuple

from db.python.connect import NotFoundError, NoOpAenter
from db.python.connect import NoOpAenter, NotFoundError
from db.python.layers.base import BaseLayer
from db.python.layers.sample import SampleLayer
from db.python.tables.family import FamilyTable
Expand Down Expand Up @@ -335,6 +335,21 @@ async def fill_in_missing_participants(self):

return f'Updated {len(sample_ids_to_update)} records'

async def insert_participant_phenotypes(
self, participant_phenotypes: dict[int, dict]
):
"""
Insert participant phenotypes, with format: {pid: {key: value}}
"""
ppttable = ParticipantPhenotypeTable(self.connection)
return await ppttable.add_key_value_rows(
[
(pid, pk, pv)
for pid, phenotypes in participant_phenotypes.items()
for pk, pv in phenotypes.items()
]
)

async def generic_individual_metadata_importer(
self,
headers: List[str],
Expand Down Expand Up @@ -653,6 +668,17 @@ async def update_many_participant_external_ids(

# region PHENOTYPES / SEQR

async def get_phenotypes_for_participants(
self, participant_ids: list[int]
) -> dict[int, dict[str, Any]]:
"""
Get phenotypes for participants keyed by by pid
"""
ppttable = ParticipantPhenotypeTable(self.connection)
return await ppttable.get_key_value_rows_for_participant_ids(
participant_ids=participant_ids
)

async def get_seqr_individual_template(
self,
project: int,
Expand Down
2 changes: 1 addition & 1 deletion scripts/parse_ped.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main(ped_file_path: str, project: str):
fapi = FamilyApi()

# pylint: disable=no-member
with AnyPath(ped_file_path).open() as ped_file: # type: ignore
with AnyPath(ped_file_path).open() as ped_file:
fapi.import_pedigree(
file=ped_file,
has_header=True,
Expand Down
27 changes: 27 additions & 0 deletions test/test_graphql.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,30 @@ async def test_sg_analyses_query(self):
self.assertIn('id', analyses[0])
self.assertIn('meta', analyses[0])
self.assertIn('output', analyses[0])

@run_as_sync
async def test_participant_phenotypes(self):
"""
Test getting participant phentypes in graphql
"""
# insert participant
p = await self.player.upsert_participant(
ParticipantUpsertInternal(external_id='Demeter', meta={}, samples=[])
)

phenotypes = {'phenotype1': 'value1', 'phenotype2': {'number': 123}}
# insert participant_phenotypes
await self.player.insert_participant_phenotypes({p.id: phenotypes})

q = """
query MyQuery($pid: Int!) {
participant(id: $pid) {
phenotypes
}
}"""

resp = await self.run_graphql_query_async(q, {'pid': p.id})

self.assertIn('participant', resp)
self.assertIn('phenotypes', resp['participant'])
self.assertDictEqual(phenotypes, resp['participant']['phenotypes'])

0 comments on commit 78cafc7

Please sign in to comment.