diff --git a/api/graphql/loaders.py b/api/graphql/loaders.py index 2b4980a27..8f131ca95 100644 --- a/api/graphql/loaders.py +++ b/api/graphql/loaders.py @@ -182,7 +182,11 @@ async def load_audit_logs_by_analysis_ids( @connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list) async def load_assays_by_samples( - connection, ids, filter: AssayFilter + connection, + ids, + filter: AssayFilter, + include_meta: bool, + meta_slices: list[dict[str, Any]] | None, ) -> dict[int, list[AssayInternal]]: """ DataLoader: get_assays_for_sample_ids @@ -191,7 +195,9 @@ async def load_assays_by_samples( assaylayer = AssayLayer(connection) # maybe this is dangerous, but I don't think it should matter filter.sample_id = GenericFilter(in_=ids) - assays = await assaylayer.query(filter) + assays = await assaylayer.query( + filter, include_meta=include_meta, meta_slices=meta_slices + ) assay_map = group_by(assays, lambda a: a.sample_id) return assay_map diff --git a/api/graphql/schema.py b/api/graphql/schema.py index 0fcbe90fb..5f460a6df 100644 --- a/api/graphql/schema.py +++ b/api/graphql/schema.py @@ -7,6 +7,7 @@ """ import datetime from inspect import isclass +from typing import Any import strawberry from strawberry.extensions import QueryDepthLimiter @@ -754,13 +755,28 @@ async def assays( type: GraphQLFilter[str] | None = None, meta: GraphQLMetaFilter | None = None, ) -> list['GraphQLAssay']: + selected_fields = info.selected_fields[0].selections + has_meta_selected = any(f.name == 'meta' for f in selected_fields) + + # Find if there are any slices of meta selected, we can load these as part of the same query + meta_slices = [ + {'path': f.arguments['path'], 'alias': f.alias} + for f in selected_fields + if f.name == 'metaValue' + ] + loader_assays_for_sample_ids = info.context[LoaderKeys.ASSAYS_FOR_SAMPLES] filter_ = AssayFilter( type=type.to_internal_filter() if type else None, meta=meta, ) assays = await loader_assays_for_sample_ids.load( - {'id': root.internal_id, 'filter': filter_} + { + 'id': root.internal_id, + 'filter': filter_, + 'include_meta': has_meta_selected, + 'meta_slices': meta_slices, + } ) return [GraphQLAssay.from_internal(assay) for assay in assays] @@ -907,6 +923,7 @@ class GraphQLAssay: external_ids: strawberry.scalars.JSON sample_id: strawberry.Private[int] + meta_slices: strawberry.Private[dict[str, Any] | None] @staticmethod def from_internal(internal: AssayInternal) -> 'GraphQLAssay': @@ -917,6 +934,7 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay': id=internal.id, type=internal.type, meta=internal.meta, + meta_slices=internal.meta_slices, external_ids=internal.external_ids or {}, # internal sample_id=internal.sample_id, @@ -928,6 +946,13 @@ async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample: sample = await loader.load(root.sample_id) return GraphQLSample.from_internal(sample) + @strawberry.field + async def metaValue( + self, info: Info, root: 'GraphQLAssay', path: str + ) -> strawberry.scalars.JSON: + alias = info.selected_fields[0].alias + return root.meta_slices.get('_meta_' + alias) if root.meta_slices else None + @strawberry.type class GraphQLAnalysisRunner: diff --git a/db/python/layers/assay.py b/db/python/layers/assay.py index 3328acd82..60bb4d2dc 100644 --- a/db/python/layers/assay.py +++ b/db/python/layers/assay.py @@ -17,10 +17,18 @@ def __init__(self, connection: Connection): self.sampt: SampleTable = SampleTable(connection) # GET - async def query(self, filter_: AssayFilter = None, check_project_id=True): + async def query( + self, + filter_: AssayFilter = None, + check_project_id=True, + include_meta: bool = True, + meta_slices: list[dict[str, Any]] | None = None, + ): """Query for samples""" - projects, assays = await self.seqt.query(filter_) + projects, assays = await self.seqt.query( + filter_, include_meta=include_meta, meta_slices=meta_slices + ) if not assays: return [] diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py index 3416f7451..04d17fb45 100644 --- a/db/python/tables/assay.py +++ b/db/python/tables/assay.py @@ -67,7 +67,10 @@ class AssayTable(DbBase): # region GETS async def query( - self, filter_: AssayFilter + self, + filter_: AssayFilter, + include_meta: bool = True, + meta_slices: list[dict[str, Any]] | None = None, ) -> tuple[set[ProjectId], list[AssayInternal]]: """Query assays""" sql_overides = { @@ -83,15 +86,30 @@ async def query( raise ValueError('Must provide a project if filtering by external_id') conditions, values = filter_.to_sql(sql_overides) - keys = ', '.join(self.COMMON_GET_KEYS) + keys = ', '.join( + [k for k in self.COMMON_GET_KEYS if include_meta or k != 'a.meta'] + ) + meta_slice_keys = ( + ','.join( + [ + f'JSON_VALUE(a.meta, \'{m["path"]}\') as _meta_{m["alias"]}' + for m in meta_slices + ] + ) + if meta_slices + else None + ) + + if meta_slice_keys: + meta_slice_keys = ',' + meta_slice_keys + _query = f""" - SELECT {keys} + SELECT {keys} {meta_slice_keys} FROM assay a LEFT JOIN sample s ON s.id = a.sample_id LEFT JOIN assay_external_id aeid ON aeid.assay_id = a.id WHERE {conditions} """ - assay_rows = await self.connection.fetch_all(_query, values) # this will unique on the id, which we want due to joining on 1:many eid table diff --git a/models/models/assay.py b/models/models/assay.py index 139d60327..ad69a28ca 100644 --- a/models/models/assay.py +++ b/models/models/assay.py @@ -11,6 +11,7 @@ class AssayInternal(SMBase): id: int | None sample_id: int meta: dict[str, Any] | None + meta_slices: dict[str, Any] | None type: str external_ids: dict[str, str] | None = {} @@ -23,16 +24,18 @@ def __eq__(self, other): return False @staticmethod - def from_db(d: dict): + def from_db(d: dict[str, Any]): """Take DB mapping object, and return SampleSequencing""" meta = d.pop('meta', None) + keys = [k for k in d] + meta_slices = {k: d.pop(k) for k in keys if k.startswith('_meta')} if meta: if isinstance(meta, bytes): meta = meta.decode() if isinstance(meta, str): meta = json.loads(meta) - return AssayInternal(meta=meta, **d) + return AssayInternal(meta=meta, meta_slices=meta_slices, **d) def to_external(self): """Convert to transport model""" @@ -72,6 +75,7 @@ class Assay(SMBase): external_ids: dict[str, str] sample_id: str meta: dict[str, Any] + type: str def to_internal(self): @@ -79,6 +83,7 @@ def to_internal(self): return AssayInternal( id=self.id, type=self.type, + meta_slices=None, external_ids=self.external_ids, sample_id=sample_id_transform_to_raw(self.sample_id), meta=self.meta,