-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'custom-cohorts' of github.com:populationgenomics/sample…
…-metadata into custom-cohorts
- Loading branch information
Showing
1 changed file
with
73 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
import datetime | ||
from test.testbase import DbIsolatedTest, run_as_sync | ||
|
||
from pymysql.err import IntegrityError | ||
|
||
from db.python.layers import CohortLayer, SampleLayer | ||
from db.python.utils import Forbidden, NotFoundError | ||
from db.python.tables.cohort import CohortFilter | ||
from db.python.utils import Forbidden, GenericFilter, NotFoundError | ||
from models.models import SampleUpsertInternal, SequencingGroupUpsertInternal | ||
from models.models.cohort import CohortCriteria, CohortTemplate | ||
from models.utils.sequencing_group_id_format import sequencing_group_id_format | ||
|
@@ -124,6 +126,52 @@ async def test_create_template_then_cohorts(self): | |
) | ||
|
||
|
||
class TestCohortQueries(DbIsolatedTest): | ||
"""Test query-related custom cohort layer functions""" | ||
|
||
@run_as_sync | ||
async def setUp(self): | ||
super().setUp() | ||
self.cohortl = CohortLayer(self.connection) | ||
|
||
@run_as_sync | ||
async def test_id_query(self): | ||
"""Exercise querying id against an empty database""" | ||
result = await self.cohortl.query(CohortFilter(id=GenericFilter(eq=42))) | ||
self.assertEqual([], result) | ||
|
||
@run_as_sync | ||
async def test_name_query(self): | ||
"""Exercise querying name against an empty database""" | ||
result = await self.cohortl.query(CohortFilter(name=GenericFilter(eq='Unknown cohort'))) | ||
self.assertEqual([], result) | ||
|
||
@run_as_sync | ||
async def test_author_query(self): | ||
"""Exercise querying author against an empty database""" | ||
result = await self.cohortl.query(CohortFilter(author=GenericFilter(eq='Alan Smithee'))) | ||
self.assertEqual([], result) | ||
|
||
@run_as_sync | ||
async def test_template_id_query(self): | ||
"""Exercise querying template_id against an empty database""" | ||
result = await self.cohortl.query(CohortFilter(template_id=GenericFilter(eq=28))) | ||
self.assertEqual([], result) | ||
|
||
@run_as_sync | ||
async def test_timestamp_query(self): | ||
"""Exercise querying timestamp against an empty database""" | ||
new_years_day = datetime.datetime(2024, 1, 1) | ||
result = await self.cohortl.query(CohortFilter(timestamp=GenericFilter(eq=new_years_day))) | ||
self.assertEqual([], result) | ||
|
||
@run_as_sync | ||
async def test_project_query(self): | ||
"""Exercise querying project against an empty database""" | ||
result = await self.cohortl.query(CohortFilter(project=GenericFilter(eq=37))) | ||
self.assertEqual([], result) | ||
|
||
|
||
def get_sample_model(eid, s_type='blood', sg_type='genome', tech='short-read', plat='illumina'): | ||
"""Create a minimal sample""" | ||
return SampleUpsertInternal( | ||
|
@@ -322,3 +370,27 @@ async def test_reevaluate_cohort(self): | |
|
||
self.assertNotIn(sgD, coh1['sequencing_group_ids']) | ||
self.assertIn(sgD, coh2['sequencing_group_ids']) | ||
|
||
@run_as_sync | ||
async def test_query_cohort(self): | ||
"""Create a cohort and test that it is populated when queried""" | ||
created = await self.cohortl.create_cohort_from_criteria( | ||
project_to_write=self.project_id, | ||
author='[email protected]', | ||
description='Cohort with two samples', | ||
cohort_name='Duo cohort', | ||
dry_run=False, | ||
cohort_criteria=CohortCriteria( | ||
projects=['test'], | ||
sg_ids_internal=[self.sgA, self.sgB], | ||
), | ||
) | ||
self.assertEqual(2, len(created['sequencing_group_ids'])) | ||
|
||
queried = await self.cohortl.query(CohortFilter(name=GenericFilter(eq='Duo cohort'))) | ||
self.assertEqual(1, len(queried)) | ||
|
||
result = await self.cohortl.get_cohort_sequencing_group_ids(int(queried[0].id)) | ||
self.assertEqual(2, len(result)) | ||
self.assertIn(self.sA.sequencing_groups[0].id, result) | ||
self.assertIn(self.sB.sequencing_groups[0].id, result) |