Skip to content

Commit 146e6b2

Browse files
committed
Merge branch 'custom-cohorts' of github.com:populationgenomics/sample-metadata into custom-cohorts
2 parents bc5e0b9 + 32fac2b commit 146e6b2

12 files changed

+1300
-255
lines changed

api/graphql/schema.py

+29-12
Original file line numberDiff line numberDiff line change
@@ -738,6 +738,7 @@ async def sample(
738738
samples = await slayer.query(filter_)
739739
return [GraphQLSample.from_internal(sample) for sample in samples]
740740

741+
# pylint: disable=too-many-arguments
741742
@strawberry.field
742743
async def sequencing_groups(
743744
self,
@@ -749,6 +750,10 @@ async def sequencing_groups(
749750
technology: GraphQLFilter[str] | None = None,
750751
platform: GraphQLFilter[str] | None = None,
751752
active_only: GraphQLFilter[bool] | None = None,
753+
created_on: GraphQLFilter[datetime.date] | None = None,
754+
assay_meta: GraphQLMetaFilter | None = None,
755+
has_cram: bool | None = None,
756+
has_gvcf: bool | None = None,
752757
) -> list[GraphQLSequencingGroup]:
753758
connection = info.context['connection']
754759
sglayer = SequencingGroupLayer(connection)
@@ -766,21 +771,33 @@ async def sequencing_groups(
766771
project_id_map = {p.name: p.id for p in projects}
767772

768773
filter_ = SequencingGroupFilter(
769-
project=project.to_internal_filter(lambda val: project_id_map[val])
770-
if project
771-
else None,
772-
sample_id=sample_id.to_internal_filter(sample_id_transform_to_raw)
773-
if sample_id
774-
else None,
775-
id=id.to_internal_filter(sequencing_group_id_transform_to_raw)
776-
if id
777-
else None,
774+
project=(
775+
project.to_internal_filter(lambda val: project_id_map[val])
776+
if project
777+
else None
778+
),
779+
sample_id=(
780+
sample_id.to_internal_filter(sample_id_transform_to_raw)
781+
if sample_id
782+
else None
783+
),
784+
id=(
785+
id.to_internal_filter(sequencing_group_id_transform_to_raw)
786+
if id
787+
else None
788+
),
778789
type=type.to_internal_filter() if type else None,
779790
technology=technology.to_internal_filter() if technology else None,
780791
platform=platform.to_internal_filter() if platform else None,
781-
active_only=active_only.to_internal_filter()
782-
if active_only
783-
else GenericFilter(eq=True),
792+
active_only=(
793+
active_only.to_internal_filter()
794+
if active_only
795+
else GenericFilter(eq=True)
796+
),
797+
created_on=created_on.to_internal_filter() if created_on else None,
798+
assay_meta=assay_meta,
799+
has_cram=has_cram,
800+
has_gvcf=has_gvcf,
784801
)
785802
sgs = await sglayer.query(filter_)
786803
return [GraphQLSequencingGroup.from_internal(sg) for sg in sgs]

db/python/tables/sequencing_group.py

+106-8
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ class SequencingGroupFilter(GenericFilterModel):
3232
active_only: GenericFilter[bool] | None = GenericFilter(eq=True)
3333
meta: GenericMetaFilter | None = None
3434

35+
# These fields are manually handled in the query to speed things up, because multiple table
36+
# joins and dynamic computation are required.
37+
created_on: GenericFilter[date] | None = None
38+
assay_meta: GenericMetaFilter | None = None
39+
has_cram: bool | None = None
40+
has_gvcf: bool | None = None
41+
3542
def __hash__(self): # pylint: disable=useless-super-delegation
3643
return super().__hash__()
3744

@@ -69,17 +76,108 @@ async def query(
6976
'platform': 'sg.platform',
7077
'active_only': 'NOT sg.archived',
7178
'external_id': 'sgexid.external_id',
79+
'created_on': 'DATE(row_start)',
80+
'assay_meta': 'meta',
7281
}
7382

74-
wheres, values = filter_.to_sql(sql_overrides)
75-
_query = f"""
76-
SELECT {self.common_get_keys_str}
77-
FROM sequencing_group sg
83+
# Progressively build up the query and query values based on the filters provided to
84+
# avoid uneccessary joins and improve performance.
85+
_query: list[str] = []
86+
query_values: dict[str, Any] = {}
87+
# These fields are manually handled in the query
88+
exclude_fields: list[str] = []
89+
90+
# Base query
91+
_query.append(
92+
f"""
93+
SELECT
94+
{self.common_get_keys_str}
95+
FROM sequencing_group AS sg
7896
LEFT JOIN sample s ON s.id = sg.sample_id
79-
LEFT JOIN sequencing_group_external_id sgexid ON sg.id = sgexid.sequencing_group_id
80-
WHERE {wheres}
81-
"""
82-
rows = await self.connection.fetch_all(_query, values)
97+
LEFT JOIN sequencing_group_external_id sgexid ON sg.id = sgexid.sequencing_group_id"""
98+
)
99+
100+
if filter_.assay_meta is not None:
101+
exclude_fields.append('assay_meta')
102+
wheres, values = filter_.to_sql(sql_overrides, only=['assay_meta'])
103+
query_values.update(values)
104+
_query.append(
105+
f"""
106+
INNER JOIN (
107+
SELECT DISTINCT
108+
sequencing_group_id
109+
FROM
110+
sequencing_group_assay
111+
INNER JOIN (
112+
SELECT
113+
id
114+
FROM
115+
assay
116+
WHERE
117+
{wheres}
118+
) AS assay_subquery ON sequencing_group_assay.assay_id = assay_subquery.id
119+
) AS sga_subquery ON sg.id = sga_subquery.sequencing_group_id
120+
"""
121+
)
122+
123+
if filter_.created_on is not None:
124+
exclude_fields.append('created_on')
125+
wheres, values = filter_.to_sql(sql_overrides, only=['created_on'])
126+
query_values.update(values)
127+
_query.append(
128+
f"""
129+
INNER JOIN (
130+
SELECT
131+
id,
132+
TIMESTAMP(min(row_start)) AS created_on
133+
FROM
134+
sequencing_group FOR SYSTEM_TIME ALL
135+
WHERE
136+
{wheres}
137+
GROUP BY
138+
id
139+
) AS sg_timequery ON sg.id = sg_timequery.id
140+
"""
141+
)
142+
143+
if filter_.has_cram is not None or filter_.has_gvcf is not None:
144+
exclude_fields.extend(['has_cram', 'has_gvcf'])
145+
wheres, values = filter_.to_sql(
146+
sql_overrides, only=['has_cram', 'has_gvcf']
147+
)
148+
query_values.update(values)
149+
_query.append(
150+
f"""
151+
INNER JOIN (
152+
SELECT
153+
sequencing_group_id,
154+
FIND_IN_SET('cram', GROUP_CONCAT(LOWER(anlysis_query.type))) > 0 AS has_cram,
155+
FIND_IN_SET('gvcf', GROUP_CONCAT(LOWER(anlysis_query.type))) > 0 AS has_gvcf
156+
FROM
157+
analysis_sequencing_group
158+
INNER JOIN (
159+
SELECT
160+
id, type
161+
FROM
162+
analysis
163+
) AS anlysis_query ON analysis_sequencing_group.analysis_id = anlysis_query.id
164+
GROUP BY
165+
sequencing_group_id
166+
HAVING
167+
{wheres}
168+
) AS sg_filequery ON sg.id = sg_filequery.sequencing_group_id
169+
"""
170+
)
171+
172+
# Add the rest of the filters
173+
wheres, values = filter_.to_sql(sql_overrides, exclude=exclude_fields)
174+
_query.append(
175+
f"""
176+
WHERE {wheres}"""
177+
)
178+
query_values.update(values)
179+
180+
rows = await self.connection.fetch_all('\n'.join(_query), query_values)
83181
sgs = [SequencingGroupInternal.from_db(**dict(r)) for r in rows]
84182
projects = set(sg.project for sg in sgs)
85183
return projects, sgs

db/python/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,10 @@ def __post_init__(self):
273273
setattr(self, field.name, GenericFilter(eq=value))
274274

275275
def to_sql(
276-
self, field_overrides: dict[str, str] = None
276+
self,
277+
field_overrides: dict[str, str] = None,
278+
only: list[str] | None = None,
279+
exclude: list[str] | None = None,
277280
) -> tuple[str, dict[str, Any]]:
278281
"""Convert the model to SQL, and avoid SQL injection"""
279282
_foverrides = field_overrides or {}
@@ -290,6 +293,11 @@ def to_sql(
290293
fields = dataclasses.fields(self)
291294
conditionals, values = [], {}
292295
for field in fields:
296+
if only and field.name not in only:
297+
continue
298+
if exclude and field.name in exclude:
299+
continue
300+
293301
fcolumn = _foverrides.get(field.name, field.name)
294302
if filter_ := getattr(self, field.name):
295303
if isinstance(filter_, dict):

models/models/sequencing_group.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,15 @@ def from_db(cls, **kwargs):
5353

5454
_archived = kwargs.pop('archived', None)
5555
if _archived is not None:
56-
_archived = _archived != b'\x00'
56+
if isinstance(_archived, int):
57+
_archived = _archived != 0
58+
elif isinstance(_archived, bytes):
59+
_archived = _archived != b'\x00'
60+
else:
61+
raise TypeError(
62+
f"Received type '{type(_archived)}' for SequencingGroup column 'archived'. "
63+
+ "Allowed types are either 'int' or 'bytes'."
64+
)
5765

5866
return SequencingGroupInternal(**kwargs, archived=_archived, meta=meta)
5967

test/test_generic_filters.py

+16
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,22 @@ def test_basic_no_override(self):
2424
self.assertEqual('test_string = :test_string_eq', sql)
2525
self.assertDictEqual({'test_string_eq': 'test'}, values)
2626

27+
def test_contains_case_sensitive(self):
28+
"""Test that the basic filter converts to SQL as expected"""
29+
filter_ = GenericFilterTest(test_string=GenericFilter(contains='test'))
30+
sql, values = filter_.to_sql()
31+
32+
self.assertEqual('test_string LIKE :test_string_contains', sql)
33+
self.assertDictEqual({'test_string_contains': '%test%'}, values)
34+
35+
def test_icontains_is_not_case_sensitive(self):
36+
"""Test that the basic filter converts to SQL as expected"""
37+
filter_ = GenericFilterTest(test_string=GenericFilter(icontains='test'))
38+
sql, values = filter_.to_sql()
39+
40+
self.assertEqual('LOWER(test_string) LIKE LOWER(:test_string_icontains)', sql)
41+
self.assertDictEqual({'test_string_icontains': '%test%'}, values)
42+
2743
def test_basic_override(self):
2844
"""Test that the basic filter with an override converts to SQL as expected"""
2945
filter_ = GenericFilterTest(test_string=GenericFilter(eq='test'))

0 commit comments

Comments
 (0)