From 45d22459cb8ef787a877dc7a5e7ecbf7274557e3 Mon Sep 17 00:00:00 2001
From: Michael Franklin <22381693+illusional@users.noreply.github.com>
Date: Sun, 10 Mar 2024 12:42:30 +1100
Subject: [PATCH] Bump testcontainers and test image (#698)

* Bump testcontainers and test image

* Bump databases + SQLAlchemy

* Remove unpacking of new row (as it unpacks keys)

* Fix db iterator

* More iterator changes to push

---------

Co-authored-by: Michael Franklin
---
 db/python/tables/assay.py                  |  3 ++-
 db/python/tables/family_participant.py     | 22 +++++++++++++------
 db/python/tables/participant_phenotype.py  | 10 ++++++---
 requirements-dev.txt                       |  2 +-
 requirements.txt                           |  4 ++--
 test/test_web.py                           | 16 ++++++++------
 test/testbase.py                           | 26 +++++++++++++++++------
 7 files changed, 56 insertions(+), 27 deletions(-)

diff --git a/db/python/tables/assay.py b/db/python/tables/assay.py
index 95837e0f3..1b5645d7a 100644
--- a/db/python/tables/assay.py
+++ b/db/python/tables/assay.py
@@ -197,7 +197,8 @@ async def get_assay_type_numbers_by_batch_for_project(self, project: ProjectId):
         """
         rows = await self.connection.fetch_all(_query, {'project': project})
         batch_result: dict[str, dict[str, str]] = defaultdict(dict)
-        for batch, seqType, count in rows:
+        for row in rows:
+            batch, seqType, count = row['batch'], row['type'], row['n']
             batch = str(batch).strip('\"') if batch != 'null' else 'no-batch'
             batch_result[batch][seqType] = str(count)
         if len(batch_result) == 1 and 'no-batch' in batch_result:
diff --git a/db/python/tables/family_participant.py b/db/python/tables/family_participant.py
index 8aab115a8..e782038ab 100644
--- a/db/python/tables/family_participant.py
+++ b/db/python/tables/family_participant.py
@@ -20,7 +20,7 @@ async def create_row(
         paternal_id: int,
         maternal_id: int,
         affected: int,
-        notes: str = None,
+        notes: str | None = None,
     ) -> Tuple[int, int]:
         """
         Create a new sample, and add it to database
@@ -111,8 +111,8 @@ async def get_rows(
         keys = [
             'fp.family_id',
             'p.id as individual_id',
-            'fp.paternal_participant_id',
-            'fp.maternal_participant_id',
+            'fp.paternal_participant_id as paternal_id',
+            'fp.maternal_participant_id as maternal_id',
             'p.reported_sex as sex',
             'fp.affected',
         ]
@@ -153,7 +153,7 @@ async def get_rows(
             'sex',
             'affected',
         ]
-        ds = [dict(zip(ordered_keys, row)) for row in rows]
+        ds = [{k: row[k] for k in ordered_keys} for row in rows]
 
         return ds
 
@@ -161,7 +161,7 @@ async def get_row(
         self,
         family_id: int,
         participant_id: int,
-    ):
+    ) -> dict | None:
         """Get a single row from the family_participant table"""
         values: Dict[str, Any] = {
             'family_id': family_id,
@@ -169,7 +169,13 @@ async def get_row(
         }
 
         _query = """
-SELECT fp.family_id, p.id as individual_id, fp.paternal_participant_id, fp.maternal_participant_id, p.reported_sex as sex, fp.affected
+SELECT
+    fp.family_id as family_id,
+    p.id as individual_id,
+    fp.paternal_participant_id as paternal_id,
+    fp.maternal_participant_id as maternal_id,
+    p.reported_sex as sex,
+    fp.affected
 FROM family_participant fp
 INNER JOIN family f ON f.id = fp.family_id
 INNER JOIN participant p on fp.participant_id = p.id
@@ -177,6 +183,8 @@ async def get_row(
 """
 
         row = await self.connection.fetch_one(_query, values)
+        if not row:
+            return None
 
         ordered_keys = [
             'family_id',
@@ -186,7 +194,7 @@ async def get_row(
             'sex',
             'affected',
         ]
-        ds = dict(zip(ordered_keys, row))
+        ds = {k: row[k] for k in ordered_keys}
 
         return ds
 
diff --git a/db/python/tables/participant_phenotype.py b/db/python/tables/participant_phenotype.py
index ea010588d..d50b5bbb7 100644
--- a/db/python/tables/participant_phenotype.py
+++ b/db/python/tables/participant_phenotype.py
@@ -44,7 +44,7 @@ async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None:
         )
 
     async def get_key_value_rows_for_participant_ids(
-        self, participant_ids=List[int]
+        self, participant_ids: List[int]
     ) -> Dict[int, Dict[str, Any]]:
         """
         Get (participant_id, description, value),
@@ -64,7 +64,9 @@ async def get_key_value_rows_for_participant_ids(
         )
         formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict)
         for row in rows:
-            pid, key, value = row
+            pid = row['participant_id']
+            key = row['description']
+            value = row['value']
             formed_key_value_pairs[pid][key] = json.loads(value)
 
         return formed_key_value_pairs
@@ -86,7 +88,9 @@ async def get_key_value_rows_for_all_participants(
         rows = await self.connection.fetch_all(_query, {'project': project})
         formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict)
         for row in rows:
-            pid, key, value = row
+            pid = row['participant_id']
+            key = row['description']
+            value = row['value']
             formed_key_value_pairs[pid][key] = json.loads(value)
 
         return formed_key_value_pairs
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 3bb86aa87..c73f19f70 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,7 +7,7 @@ flake8-bugbear
 nest-asyncio
 pre-commit
 pylint
-testcontainers[mariadb]==3.7.1
+testcontainers[mariadb]>=4.0.0
 types-PyMySQL
 # some strawberry dependency
 strawberry-graphql[debug-server]==0.206.0
diff --git a/requirements.txt b/requirements.txt
index 1f9640ea2..172fa8fe0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,8 +17,8 @@ uvicorn==0.18.3
 fastapi[all]==0.85.1
 strawberry-graphql[fastapi]==0.206.0
 python-multipart==0.0.5
-databases[mysql]==0.6.1
-SQLAlchemy==1.4.41
+databases[mysql]==0.9.0
+SQLAlchemy==2.0.28
 cryptography>=41.0.0
 python-dateutil==2.8.2
 slack-sdk==3.20.2
diff --git a/test/test_web.py b/test/test_web.py
index cc5a0bb63..cc7ffad99 100644
--- a/test/test_web.py
+++ b/test/test_web.py
@@ -221,7 +221,7 @@ async def test_project_summary_empty(self):
             seqr_sync_types=[],
         )
 
-        self.assertEqual(expected, result)
+        self.assertDataclassEqual(expected, result)
 
     @run_as_sync
     async def test_project_summary_single_entry(self):
@@ -232,7 +232,7 @@ async def test_project_summary_single_entry(self):
         result = await self.webl.get_project_summary(token=0, grid_filter=[])
 
         result.participants = []
-        self.assertEqual(SINGLE_PARTICIPANT_RESULT, result)
+        self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, result)
 
     @run_as_sync
     async def test_project_summary_to_external(self):
@@ -289,7 +289,7 @@ async def project_summary_with_filter_with_results(self):
             ],
         )
         filtered_result_success.participants = []
-        self.assertEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success)
+        self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success)
 
     @run_as_sync
     async def project_summary_with_filter_no_results(self):
@@ -323,7 +323,7 @@ async def project_summary_with_filter_no_results(self):
             seqr_sync_types=[],
         )
 
-        self.assertEqual(empty_result, filtered_result_empty)
+        self.assertDataclassEqual(empty_result, filtered_result_empty)
 
     @run_as_sync
     async def test_project_summary_multiple_participants(self):
@@ -376,7 +376,7 @@ async def test_project_summary_multiple_participants(self):
 
         two_samples_result.participants = []
 
-        self.assertEqual(expected_data_two_samples, two_samples_result)
+        self.assertDataclassEqual(expected_data_two_samples, two_samples_result)
 
     @run_as_sync
     async def test_project_summary_multiple_participants_and_filter(self):
@@ -436,7 +436,7 @@ async def test_project_summary_multiple_participants_and_filter(self):
         )
         two_samples_result_filtered.participants = []
 
-        self.assertEqual(
+        self.assertDataclassEqual(
             expected_data_two_samples_filtered, two_samples_result_filtered
         )
 
@@ -499,7 +499,9 @@ async def test_field_with_space(self):
             seqr_sync_types=[],
         )
 
-        self.assertEqual(expected_data_two_samples_filtered, test_field_with_space)
+        self.assertDataclassEqual(
+            expected_data_two_samples_filtered, test_field_with_space
+        )
 
     @run_as_sync
     async def test_project_summary_inactive_sequencing_group(self):
diff --git a/test/testbase.py b/test/testbase.py
index 04c8e1d2d..fa60ee66f 100644
--- a/test/testbase.py
+++ b/test/testbase.py
@@ -1,6 +1,7 @@
 # pylint: disable=invalid-overridden-method
 
 import asyncio
+import dataclasses
 import logging
 import os
 import socket
@@ -96,10 +97,10 @@ async def setup():
             logger = logging.getLogger()
             try:
                 set_all_access(True)
-                db = MySqlContainer('mariadb:10.8.3')
+                db = MySqlContainer('mariadb:11.2.2')
                 port_to_expose = find_free_port()
                 # override the default port to map the container to
-                db.with_bind_ports(db.port_to_expose, port_to_expose)
+                db.with_bind_ports(db.port, port_to_expose)
                 logger.disabled = True
                 db.start()
                 logger.disabled = False
@@ -111,7 +112,7 @@ async def setup():
 
                 con_string = db.get_connection_url()
                 con_string = 'mysql://' + con_string.split('://', maxsplit=1)[1]
-                lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.MYSQL_DATABASE}'
+                lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.dbname}'
                 # apply the liquibase schema
                 command = [
                     'liquibase',
@@ -120,8 +121,8 @@ async def setup():
                     *('--url', lcon_string),
                     *('--driver', 'org.mariadb.jdbc.Driver'),
                     *('--classpath', db_prefix + '/mariadb-java-client-3.0.3.jar'),
-                    *('--username', db.MYSQL_USER),
-                    *('--password', db.MYSQL_PASSWORD),
+                    *('--username', db.username),
+                    *('--password', db.password),
                     'update',
                 ]
                 subprocess.check_output(command, stderr=subprocess.STDOUT)
@@ -175,7 +176,7 @@ async def setup():
     def tearDownClass(cls) -> None:
         db = cls.dbs.get(cls.__name__)
         if db:
-            db.exec(f'DROP DATABASE {db.MYSQL_DATABASE};')
+            db.exec(f'DROP DATABASE {db.dbname};')
             db.stop()
 
     def setUp(self) -> None:
@@ -224,6 +225,19 @@ async def audit_log_id(self):
         """Get audit_log_id for the test"""
         return await self.connection.audit_log_id()
 
+    def assertDataclassEqual(self, a, b):
+        """Assert two dataclasses are equal"""
+
+        def to_dict(obj):
+            d = dataclasses.asdict(obj)
+            for k, v in d.items():
+                if dataclasses.is_dataclass(v):
+                    d[k] = to_dict(v)
+            return d
+
+        self.maxDiff = None
+        self.assertDictEqual(to_dict(a), to_dict(b))
+
 
 class DbIsolatedTest(DbTest):
     """
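
Note (illustrative sketch, not part of the patch to apply): the change driving the db/python/tables/ edits is that, per the commit message above, rows returned by databases 0.9 / SQLAlchemy 2.0 no longer tuple-unpack into values (unpacking yields the keys), so call sites now read columns by name, e.g. row['batch']. Below is a minimal, self-contained sketch of that access pattern. It assumes the optional aiosqlite backend of the databases package is installed (the project itself pins databases[mysql]==0.9.0), and the table and column names are invented for the example.

    import asyncio

    from databases import Database


    async def main() -> None:
        # force_rollback=True keeps every statement on a single connection inside
        # one transaction, so the in-memory SQLite table stays visible throughout.
        db = Database('sqlite+aiosqlite:///:memory:', force_rollback=True)
        await db.connect()
        await db.execute('CREATE TABLE batch_counts (batch TEXT, type TEXT, n INTEGER)')
        await db.execute(
            'INSERT INTO batch_counts VALUES (:batch, :type, :n)',
            {'batch': 'b01', 'type': 'genome', 'n': 3},
        )

        rows = await db.fetch_all('SELECT batch, type, n FROM batch_counts')
        for row in rows:
            # Tuple unpacking (batch, seq_type, n = row) is what the patch removes;
            # read columns by name instead, as assay.py now does.
            batch, seq_type, n = row['batch'], row['type'], row['n']
            print(batch, seq_type, n)

        await db.disconnect()


    asyncio.run(main())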
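
The test/testbase.py hunks track the testcontainers 3.x to 4.x attribute renames: port_to_expose, MYSQL_USER, MYSQL_PASSWORD and MYSQL_DATABASE become port, username, password and dbname. A short sketch of those 4.x attributes, assuming a local Docker daemon is available and using the same image tag as the patch:

    from testcontainers.mysql import MySqlContainer

    # The context manager starts the container and tears it down on exit.
    with MySqlContainer('mariadb:11.2.2') as db:
        print(db.username, db.password, db.dbname, db.port)
        print(db.get_connection_url())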