diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 2a54fc514..905297009 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -13,26 +13,26 @@ from db.python.connect import NotFoundError from db.python.layers import ( AnalysisLayer, - SampleLayer, AssayLayer, + FamilyLayer, ParticipantLayer, + SampleLayer, SequencingGroupLayer, - FamilyLayer, ) from db.python.tables.analysis import AnalysisFilter from db.python.tables.assay import AssayFilter from db.python.tables.project import ProjectPermissionsTable from db.python.tables.sample import SampleFilter from db.python.tables.sequencing_group import SequencingGroupFilter -from db.python.utils import ProjectId, GenericFilter +from db.python.utils import GenericFilter, ProjectId from models.models import ( - AssayInternal, - SampleInternal, - SequencingGroupInternal, AnalysisInternal, - ParticipantInternal, + AssayInternal, FamilyInternal, + ParticipantInternal, Project, + SampleInternal, + SequencingGroupInternal, ) @@ -53,6 +53,8 @@ class LoaderKeys(enum.Enum): SAMPLES_FOR_PARTICIPANTS = 'samples_for_participants' SAMPLES_FOR_PROJECTS = 'samples_for_projects' + PHENOTYPES_FOR_PARTICIPANTS = 'phenotypes_for_participants' + PARTICIPANTS_FOR_IDS = 'participants_for_ids' PARTICIPANTS_FOR_FAMILIES = 'participants_for_families' PARTICIPANTS_FOR_PROJECTS = 'participants_for_projects' @@ -291,9 +293,7 @@ async def load_participants_for_ids( p_by_id = {p.id: p for p in persons} missing_pids = set(participant_ids) - set(p_by_id.keys()) if missing_pids: - raise NotFoundError( - f'Could not find participants with ids {missing_pids}' - ) + raise NotFoundError(f'Could not find participants with ids {missing_pids}') return [p_by_id.get(p) for p in participant_ids] @@ -400,7 +400,23 @@ async def load_analyses_for_sequencing_groups( return by_sg_id -async def get_context(request: Request, connection=get_projectless_db_connection): # pylint: disable=unused-argument +@connected_data_loader(LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS) +async def load_phenotypes_for_participants( + participant_ids: list[int], connection +) -> list[dict]: + """ + Data loader for phenotypes for participants + """ + player = ParticipantLayer(connection) + participant_phenotypes = await player.get_phenotypes_for_participants( + participant_ids=participant_ids + ) + return [participant_phenotypes.get(pid, {}) for pid in participant_ids] + + +async def get_context( + request: Request, connection=get_projectless_db_connection +): # pylint: disable=unused-argument """Get loaders / cache context for strawberyy GraphQL""" mapped_loaders = {k: fn(connection) for k, fn in loaders.items()} return { diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 8befa8867..7255821ed 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -14,16 +14,10 @@ from strawberry.fastapi import GraphQLRouter from strawberry.types import Info -from api.graphql.filters import ( - GraphQLFilter, - GraphQLMetaFilter, -) -from api.graphql.loaders import ( - get_context, - LoaderKeys, -) +from api.graphql.filters import GraphQLFilter, GraphQLMetaFilter +from api.graphql.loaders import LoaderKeys, get_context from db.python import enum_tables -from db.python.layers import AnalysisLayer, SequencingGroupLayer, SampleLayer +from db.python.layers import AnalysisLayer, SampleLayer, SequencingGroupLayer from db.python.layers.assay import AssayLayer from db.python.layers.family import FamilyLayer from db.python.tables.analysis import AnalysisFilter @@ -34,21 +28,19 @@ from db.python.utils import GenericFilter from models.enums import AnalysisStatus from models.models import ( - SampleInternal, - ParticipantInternal, - Project, AnalysisInternal, + AssayInternal, FamilyInternal, + ParticipantInternal, + Project, + SampleInternal, SequencingGroupInternal, - AssayInternal, ) from models.models.sample import sample_id_transform_to_raw -from models.utils.sample_id_format import ( - sample_id_format, -) +from models.utils.sample_id_format import sample_id_format from models.utils.sequencing_group_id_format import ( - sequencing_group_id_transform_to_raw, sequencing_group_id_format, + sequencing_group_id_transform_to_raw, ) enum_methods = {} @@ -336,6 +328,13 @@ async def samples( samples = await info.context[LoaderKeys.SAMPLES_FOR_PARTICIPANTS].load(q) return [GraphQLSample.from_internal(s) for s in samples] + @strawberry.field + async def phenotypes( + self, info: Info, root: 'GraphQLParticipant' + ) -> strawberry.scalars.JSON: + loader = info.context[LoaderKeys.PHENOTYPES_FOR_PARTICIPANTS] + return await loader.load(root.id) + @strawberry.field async def families( self, info: Info, root: 'GraphQLParticipant' diff --git a/db/python/layers/participant.py b/db/python/layers/participant.py index 216427e31..44d6d4db2 100644 --- a/db/python/layers/participant.py +++ b/db/python/layers/participant.py @@ -2,9 +2,9 @@ import re from collections import defaultdict from enum import Enum -from typing import Dict, List, Tuple, Optional, Any +from typing import Any, Dict, List, Optional, Tuple -from db.python.connect import NotFoundError, NoOpAenter +from db.python.connect import NoOpAenter, NotFoundError from db.python.layers.base import BaseLayer from db.python.layers.sample import SampleLayer from db.python.tables.family import FamilyTable @@ -335,6 +335,21 @@ async def fill_in_missing_participants(self): return f'Updated {len(sample_ids_to_update)} records' + async def insert_participant_phenotypes( + self, participant_phenotypes: dict[int, dict] + ): + """ + Insert participant phenotypes, with format: {pid: {key: value}} + """ + ppttable = ParticipantPhenotypeTable(self.connection) + return await ppttable.add_key_value_rows( + [ + (pid, pk, pv) + for pid, phenotypes in participant_phenotypes.items() + for pk, pv in phenotypes.items() + ] + ) + async def generic_individual_metadata_importer( self, headers: List[str], @@ -653,6 +668,17 @@ async def update_many_participant_external_ids( # region PHENOTYPES / SEQR + async def get_phenotypes_for_participants( + self, participant_ids: list[int] + ) -> dict[int, dict[str, Any]]: + """ + Get phenotypes for participants keyed by by pid + """ + ppttable = ParticipantPhenotypeTable(self.connection) + return await ppttable.get_key_value_rows_for_participant_ids( + participant_ids=participant_ids + ) + async def get_seqr_individual_template( self, project: int, diff --git a/scripts/parse_ped.py b/scripts/parse_ped.py index f587e483b..831ddd402 100644 --- a/scripts/parse_ped.py +++ b/scripts/parse_ped.py @@ -15,7 +15,7 @@ def main(ped_file_path: str, project: str): fapi = FamilyApi() # pylint: disable=no-member - with AnyPath(ped_file_path).open() as ped_file: # type: ignore + with AnyPath(ped_file_path).open() as ped_file: fapi.import_pedigree( file=ped_file, has_header=True, diff --git a/test/test_graphql.py b/test/test_graphql.py index c61eed238..a52a5f947 100644 --- a/test/test_graphql.py +++ b/test/test_graphql.py @@ -229,3 +229,30 @@ async def test_sg_analyses_query(self): self.assertIn('id', analyses[0]) self.assertIn('meta', analyses[0]) self.assertIn('output', analyses[0]) + + @run_as_sync + async def test_participant_phenotypes(self): + """ + Test getting participant phentypes in graphql + """ + # insert participant + p = await self.player.upsert_participant( + ParticipantUpsertInternal(external_id='Demeter', meta={}, samples=[]) + ) + + phenotypes = {'phenotype1': 'value1', 'phenotype2': {'number': 123}} + # insert participant_phenotypes + await self.player.insert_participant_phenotypes({p.id: phenotypes}) + + q = """ +query MyQuery($pid: Int!) { + participant(id: $pid) { + phenotypes + } +}""" + + resp = await self.run_graphql_query_async(q, {'pid': p.id}) + + self.assertIn('participant', resp) + self.assertIn('phenotypes', resp['participant']) + self.assertDictEqual(phenotypes, resp['participant']['phenotypes'])