Skip to content

Commit

Permalink
Merge branch 'custom-cohorts' of github.com:populationgenomics/sample…
Browse files Browse the repository at this point in the history
…-metadata into custom-cohorts
  • Loading branch information
vivbak committed Jan 31, 2024
2 parents bc5e0b9 + 32fac2b commit 146e6b2
Show file tree
Hide file tree
Showing 12 changed files with 1,300 additions and 255 deletions.
41 changes: 29 additions & 12 deletions api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ async def sample(
samples = await slayer.query(filter_)
return [GraphQLSample.from_internal(sample) for sample in samples]

# pylint: disable=too-many-arguments
@strawberry.field
async def sequencing_groups(
self,
Expand All @@ -749,6 +750,10 @@ async def sequencing_groups(
technology: GraphQLFilter[str] | None = None,
platform: GraphQLFilter[str] | None = None,
active_only: GraphQLFilter[bool] | None = None,
created_on: GraphQLFilter[datetime.date] | None = None,
assay_meta: GraphQLMetaFilter | None = None,
has_cram: bool | None = None,
has_gvcf: bool | None = None,
) -> list[GraphQLSequencingGroup]:
connection = info.context['connection']
sglayer = SequencingGroupLayer(connection)
Expand All @@ -766,21 +771,33 @@ async def sequencing_groups(
project_id_map = {p.name: p.id for p in projects}

filter_ = SequencingGroupFilter(
project=project.to_internal_filter(lambda val: project_id_map[val])
if project
else None,
sample_id=sample_id.to_internal_filter(sample_id_transform_to_raw)
if sample_id
else None,
id=id.to_internal_filter(sequencing_group_id_transform_to_raw)
if id
else None,
project=(
project.to_internal_filter(lambda val: project_id_map[val])
if project
else None
),
sample_id=(
sample_id.to_internal_filter(sample_id_transform_to_raw)
if sample_id
else None
),
id=(
id.to_internal_filter(sequencing_group_id_transform_to_raw)
if id
else None
),
type=type.to_internal_filter() if type else None,
technology=technology.to_internal_filter() if technology else None,
platform=platform.to_internal_filter() if platform else None,
active_only=active_only.to_internal_filter()
if active_only
else GenericFilter(eq=True),
active_only=(
active_only.to_internal_filter()
if active_only
else GenericFilter(eq=True)
),
created_on=created_on.to_internal_filter() if created_on else None,
assay_meta=assay_meta,
has_cram=has_cram,
has_gvcf=has_gvcf,
)
sgs = await sglayer.query(filter_)
return [GraphQLSequencingGroup.from_internal(sg) for sg in sgs]
Expand Down
114 changes: 106 additions & 8 deletions db/python/tables/sequencing_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ class SequencingGroupFilter(GenericFilterModel):
active_only: GenericFilter[bool] | None = GenericFilter(eq=True)
meta: GenericMetaFilter | None = None

# These fields are manually handled in the query to speed things up, because multiple table
# joins and dynamic computation are required.
created_on: GenericFilter[date] | None = None
assay_meta: GenericMetaFilter | None = None
has_cram: bool | None = None
has_gvcf: bool | None = None

def __hash__(self): # pylint: disable=useless-super-delegation
return super().__hash__()

Expand Down Expand Up @@ -69,17 +76,108 @@ async def query(
'platform': 'sg.platform',
'active_only': 'NOT sg.archived',
'external_id': 'sgexid.external_id',
'created_on': 'DATE(row_start)',
'assay_meta': 'meta',
}

wheres, values = filter_.to_sql(sql_overrides)
_query = f"""
SELECT {self.common_get_keys_str}
FROM sequencing_group sg
# Progressively build up the query and query values based on the filters provided to
# avoid unnecessary joins and improve performance.
_query: list[str] = []
query_values: dict[str, Any] = {}
# These fields are manually handled in the query
exclude_fields: list[str] = []

# Base query
_query.append(
f"""
SELECT
{self.common_get_keys_str}
FROM sequencing_group AS sg
LEFT JOIN sample s ON s.id = sg.sample_id
LEFT JOIN sequencing_group_external_id sgexid ON sg.id = sgexid.sequencing_group_id
WHERE {wheres}
"""
rows = await self.connection.fetch_all(_query, values)
LEFT JOIN sequencing_group_external_id sgexid ON sg.id = sgexid.sequencing_group_id"""
)

if filter_.assay_meta is not None:
exclude_fields.append('assay_meta')
wheres, values = filter_.to_sql(sql_overrides, only=['assay_meta'])
query_values.update(values)
_query.append(
f"""
INNER JOIN (
SELECT DISTINCT
sequencing_group_id
FROM
sequencing_group_assay
INNER JOIN (
SELECT
id
FROM
assay
WHERE
{wheres}
) AS assay_subquery ON sequencing_group_assay.assay_id = assay_subquery.id
) AS sga_subquery ON sg.id = sga_subquery.sequencing_group_id
"""
)

if filter_.created_on is not None:
exclude_fields.append('created_on')
wheres, values = filter_.to_sql(sql_overrides, only=['created_on'])
query_values.update(values)
_query.append(
f"""
INNER JOIN (
SELECT
id,
TIMESTAMP(min(row_start)) AS created_on
FROM
sequencing_group FOR SYSTEM_TIME ALL
WHERE
{wheres}
GROUP BY
id
) AS sg_timequery ON sg.id = sg_timequery.id
"""
)

if filter_.has_cram is not None or filter_.has_gvcf is not None:
exclude_fields.extend(['has_cram', 'has_gvcf'])
wheres, values = filter_.to_sql(
sql_overrides, only=['has_cram', 'has_gvcf']
)
query_values.update(values)
_query.append(
f"""
INNER JOIN (
SELECT
sequencing_group_id,
FIND_IN_SET('cram', GROUP_CONCAT(LOWER(anlysis_query.type))) > 0 AS has_cram,
FIND_IN_SET('gvcf', GROUP_CONCAT(LOWER(anlysis_query.type))) > 0 AS has_gvcf
FROM
analysis_sequencing_group
INNER JOIN (
SELECT
id, type
FROM
analysis
) AS anlysis_query ON analysis_sequencing_group.analysis_id = anlysis_query.id
GROUP BY
sequencing_group_id
HAVING
{wheres}
) AS sg_filequery ON sg.id = sg_filequery.sequencing_group_id
"""
)

# Add the rest of the filters
wheres, values = filter_.to_sql(sql_overrides, exclude=exclude_fields)
_query.append(
f"""
WHERE {wheres}"""
)
query_values.update(values)

rows = await self.connection.fetch_all('\n'.join(_query), query_values)
sgs = [SequencingGroupInternal.from_db(**dict(r)) for r in rows]
projects = set(sg.project for sg in sgs)
return projects, sgs
Expand Down
10 changes: 9 additions & 1 deletion db/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ def __post_init__(self):
setattr(self, field.name, GenericFilter(eq=value))

def to_sql(
self, field_overrides: dict[str, str] = None
self,
field_overrides: dict[str, str] = None,
only: list[str] | None = None,
exclude: list[str] | None = None,
) -> tuple[str, dict[str, Any]]:
"""Convert the model to SQL, and avoid SQL injection"""
_foverrides = field_overrides or {}
Expand All @@ -290,6 +293,11 @@ def to_sql(
fields = dataclasses.fields(self)
conditionals, values = [], {}
for field in fields:
if only and field.name not in only:
continue
if exclude and field.name in exclude:
continue

fcolumn = _foverrides.get(field.name, field.name)
if filter_ := getattr(self, field.name):
if isinstance(filter_, dict):
Expand Down
10 changes: 9 additions & 1 deletion models/models/sequencing_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,15 @@ def from_db(cls, **kwargs):

_archived = kwargs.pop('archived', None)
if _archived is not None:
_archived = _archived != b'\x00'
if isinstance(_archived, int):
_archived = _archived != 0
elif isinstance(_archived, bytes):
_archived = _archived != b'\x00'
else:
raise TypeError(
f"Received type '{type(_archived)}' for SequencingGroup column 'archived'. "
+ "Allowed types are either 'int' or 'bytes'."
)

return SequencingGroupInternal(**kwargs, archived=_archived, meta=meta)

Expand Down
16 changes: 16 additions & 0 deletions test/test_generic_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ def test_basic_no_override(self):
self.assertEqual('test_string = :test_string_eq', sql)
self.assertDictEqual({'test_string_eq': 'test'}, values)

def test_contains_case_sensitive(self):
    """A `contains` filter renders as a plain LIKE with a %-wrapped bind value."""
    test_filter = GenericFilterTest(test_string=GenericFilter(contains='test'))
    generated_sql, generated_values = test_filter.to_sql()

    self.assertEqual('test_string LIKE :test_string_contains', generated_sql)
    self.assertDictEqual({'test_string_contains': '%test%'}, generated_values)

def test_icontains_is_not_case_sensitive(self):
    """An `icontains` filter lower-cases both the column and the bind value in the LIKE."""
    test_filter = GenericFilterTest(test_string=GenericFilter(icontains='test'))
    generated_sql, generated_values = test_filter.to_sql()

    self.assertEqual(
        'LOWER(test_string) LIKE LOWER(:test_string_icontains)', generated_sql
    )
    self.assertDictEqual({'test_string_icontains': '%test%'}, generated_values)

def test_basic_override(self):
"""Test that the basic filter with an override converts to SQL as expected"""
filter_ = GenericFilterTest(test_string=GenericFilter(eq='test'))
Expand Down
Loading

0 comments on commit 146e6b2

Please sign in to comment.