Skip to content

Commit

Permalink
proof of concept of allowing selecting of slices of meta in graphql
Browse files Browse the repository at this point in the history
  • Loading branch information
dancoates committed Jun 17, 2024
1 parent aa08309 commit 310dd85
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 11 deletions.
10 changes: 8 additions & 2 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ async def load_audit_logs_by_analysis_ids(

@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
connection,
ids,
filter: AssayFilter,
include_meta: bool,
meta_slices: list[dict[str, Any]] | None,
) -> dict[int, list[AssayInternal]]:
"""
DataLoader: get_assays_for_sample_ids
Expand All @@ -191,7 +195,9 @@ async def load_assays_by_samples(
assaylayer = AssayLayer(connection)
# maybe this is dangerous, but I don't think it should matter
filter.sample_id = GenericFilter(in_=ids)
assays = await assaylayer.query(filter)
assays = await assaylayer.query(
filter, include_meta=include_meta, meta_slices=meta_slices
)
assay_map = group_by(assays, lambda a: a.sample_id)
return assay_map

Expand Down
27 changes: 26 additions & 1 deletion api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
import datetime
from inspect import isclass
from typing import Any

import strawberry
from strawberry.extensions import QueryDepthLimiter
Expand Down Expand Up @@ -754,13 +755,28 @@ async def assays(
type: GraphQLFilter[str] | None = None,
meta: GraphQLMetaFilter | None = None,
) -> list['GraphQLAssay']:
selected_fields = info.selected_fields[0].selections
has_meta_selected = any(f.name == 'meta' for f in selected_fields)

# Find if there are any slices of meta selected, we can load these as part of the same query
meta_slices = [
{'path': f.arguments['path'], 'alias': f.alias}
for f in selected_fields
if f.name == 'metaValue'
]

loader_assays_for_sample_ids = info.context[LoaderKeys.ASSAYS_FOR_SAMPLES]
filter_ = AssayFilter(
type=type.to_internal_filter() if type else None,
meta=meta,
)
assays = await loader_assays_for_sample_ids.load(
{'id': root.internal_id, 'filter': filter_}
{
'id': root.internal_id,
'filter': filter_,
'include_meta': has_meta_selected,
'meta_slices': meta_slices,
}
)
return [GraphQLAssay.from_internal(assay) for assay in assays]

Expand Down Expand Up @@ -907,6 +923,7 @@ class GraphQLAssay:
external_ids: strawberry.scalars.JSON

sample_id: strawberry.Private[int]
meta_slices: strawberry.Private[dict[str, Any] | None]

@staticmethod
def from_internal(internal: AssayInternal) -> 'GraphQLAssay':
Expand All @@ -917,6 +934,7 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay':
id=internal.id,
type=internal.type,
meta=internal.meta,
meta_slices=internal.meta_slices,
external_ids=internal.external_ids or {},
# internal
sample_id=internal.sample_id,
Expand All @@ -928,6 +946,13 @@ async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample:
sample = await loader.load(root.sample_id)
return GraphQLSample.from_internal(sample)

@strawberry.field
async def metaValue(
self, info: Info, root: 'GraphQLAssay', path: str
) -> strawberry.scalars.JSON:
alias = info.selected_fields[0].alias
return root.meta_slices.get('_meta_' + alias) if root.meta_slices else None


@strawberry.type
class GraphQLAnalysisRunner:
Expand Down
12 changes: 10 additions & 2 deletions db/python/layers/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@ def __init__(self, connection: Connection):
self.sampt: SampleTable = SampleTable(connection)

# GET
async def query(self, filter_: AssayFilter = None, check_project_id=True):
async def query(
self,
filter_: AssayFilter = None,
check_project_id=True,
include_meta: bool = True,
meta_slices: list[dict[str, Any]] | None = None,
):
"""Query for samples"""

projects, assays = await self.seqt.query(filter_)
projects, assays = await self.seqt.query(
filter_, include_meta=include_meta, meta_slices=meta_slices
)

if not assays:
return []
Expand Down
26 changes: 22 additions & 4 deletions db/python/tables/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ class AssayTable(DbBase):
# region GETS

async def query(
self, filter_: AssayFilter
self,
filter_: AssayFilter,
include_meta: bool = True,
meta_slices: list[dict[str, Any]] | None = None,
) -> tuple[set[ProjectId], list[AssayInternal]]:
"""Query assays"""
sql_overides = {
Expand All @@ -83,15 +86,30 @@ async def query(
raise ValueError('Must provide a project if filtering by external_id')

conditions, values = filter_.to_sql(sql_overides)
keys = ', '.join(self.COMMON_GET_KEYS)
keys = ', '.join(
[k for k in self.COMMON_GET_KEYS if include_meta or k != 'a.meta']
)
meta_slice_keys = (
','.join(
[
f'JSON_VALUE(a.meta, \'{m["path"]}\') as _meta_{m["alias"]}'
for m in meta_slices
]
)
if meta_slices
else None
)

if meta_slice_keys:
meta_slice_keys = ',' + meta_slice_keys

_query = f"""
SELECT {keys}
SELECT {keys} {meta_slice_keys}
FROM assay a
LEFT JOIN sample s ON s.id = a.sample_id
LEFT JOIN assay_external_id aeid ON aeid.assay_id = a.id
WHERE {conditions}
"""

assay_rows = await self.connection.fetch_all(_query, values)

# this will unique on the id, which we want due to joining on 1:many eid table
Expand Down
9 changes: 7 additions & 2 deletions models/models/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class AssayInternal(SMBase):
id: int | None
sample_id: int
meta: dict[str, Any] | None
meta_slices: dict[str, Any] | None
type: str
external_ids: dict[str, str] | None = {}

Expand All @@ -23,16 +24,18 @@ def __eq__(self, other):
return False

@staticmethod
def from_db(d: dict):
def from_db(d: dict[str, Any]):
"""Take DB mapping object, and return SampleSequencing"""
meta = d.pop('meta', None)
keys = [k for k in d]
meta_slices = {k: d.pop(k) for k in keys if k.startswith('_meta')}

if meta:
if isinstance(meta, bytes):
meta = meta.decode()
if isinstance(meta, str):
meta = json.loads(meta)
return AssayInternal(meta=meta, **d)
return AssayInternal(meta=meta, meta_slices=meta_slices, **d)

def to_external(self):
"""Convert to transport model"""
Expand Down Expand Up @@ -72,13 +75,15 @@ class Assay(SMBase):
external_ids: dict[str, str]
sample_id: str
meta: dict[str, Any]

type: str

def to_internal(self):
"""Convert to internal model"""
return AssayInternal(
id=self.id,
type=self.type,
meta_slices=None,
external_ids=self.external_ids,
sample_id=sample_id_transform_to_raw(self.sample_id),
meta=self.meta,
Expand Down

0 comments on commit 310dd85

Please sign in to comment.