Skip to content

Commit

Permalink
Bump testcontainers and test image (#698)
Browse files Browse the repository at this point in the history
* Bump testcontainers and test image

* Bump databases + SQLAlchemy

* Remove unpacking of new row (as it unpacks keys)

* Fix db iterator

* More iterator changes to push

---------

Co-authored-by: Michael Franklin <[email protected]>
  • Loading branch information
illusional and illusional authored Mar 10, 2024
1 parent 8b7b69e commit 45d2245
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 27 deletions.
3 changes: 2 additions & 1 deletion db/python/tables/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,8 @@ async def get_assay_type_numbers_by_batch_for_project(self, project: ProjectId):
"""
rows = await self.connection.fetch_all(_query, {'project': project})
batch_result: dict[str, dict[str, str]] = defaultdict(dict)
for batch, seqType, count in rows:
for row in rows:
batch, seqType, count = row['batch'], row['type'], row['n']
batch = str(batch).strip('\"') if batch != 'null' else 'no-batch'
batch_result[batch][seqType] = str(count)
if len(batch_result) == 1 and 'no-batch' in batch_result:
Expand Down
22 changes: 15 additions & 7 deletions db/python/tables/family_participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def create_row(
paternal_id: int,
maternal_id: int,
affected: int,
notes: str = None,
notes: str | None = None,
) -> Tuple[int, int]:
"""
Create a new sample, and add it to database
Expand Down Expand Up @@ -111,8 +111,8 @@ async def get_rows(
keys = [
'fp.family_id',
'p.id as individual_id',
'fp.paternal_participant_id',
'fp.maternal_participant_id',
'fp.paternal_participant_id as paternal_id',
'fp.maternal_participant_id as maternal_id',
'p.reported_sex as sex',
'fp.affected',
]
Expand Down Expand Up @@ -153,30 +153,38 @@ async def get_rows(
'sex',
'affected',
]
ds = [dict(zip(ordered_keys, row)) for row in rows]
ds = [{k: row[k] for k in ordered_keys} for row in rows]

return ds

async def get_row(
self,
family_id: int,
participant_id: int,
):
) -> dict | None:
"""Get a single row from the family_participant table"""
values: Dict[str, Any] = {
'family_id': family_id,
'participant_id': participant_id,
}

_query = """
SELECT fp.family_id, p.id as individual_id, fp.paternal_participant_id, fp.maternal_participant_id, p.reported_sex as sex, fp.affected
SELECT
fp.family_id as family_id,
p.id as individual_id,
fp.paternal_participant_id as paternal_id,
fp.maternal_participant_id as maternal_id,
p.reported_sex as sex,
fp.affected
FROM family_participant fp
INNER JOIN family f ON f.id = fp.family_id
INNER JOIN participant p on fp.participant_id = p.id
WHERE f.id = :family_id AND p.id = :participant_id
"""

row = await self.connection.fetch_one(_query, values)
if not row:
return None

ordered_keys = [
'family_id',
Expand All @@ -186,7 +194,7 @@ async def get_row(
'sex',
'affected',
]
ds = dict(zip(ordered_keys, row))
ds = {k: row[k] for k in ordered_keys}

return ds

Expand Down
10 changes: 7 additions & 3 deletions db/python/tables/participant_phenotype.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async def add_key_value_rows(self, rows: List[Tuple[int, str, Any]]) -> None:
)

async def get_key_value_rows_for_participant_ids(
self, participant_ids=List[int]
self, participant_ids: List[int]
) -> Dict[int, Dict[str, Any]]:
"""
Get (participant_id, description, value),
Expand All @@ -64,7 +64,9 @@ async def get_key_value_rows_for_participant_ids(
)
formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict)
for row in rows:
pid, key, value = row
pid = row['participant_id']
key = row['description']
value = row['value']
formed_key_value_pairs[pid][key] = json.loads(value)

return formed_key_value_pairs
Expand All @@ -86,7 +88,9 @@ async def get_key_value_rows_for_all_participants(
rows = await self.connection.fetch_all(_query, {'project': project})
formed_key_value_pairs: Dict[int, Dict[str, Any]] = defaultdict(dict)
for row in rows:
pid, key, value = row
pid = row['participant_id']
key = row['description']
value = row['value']
formed_key_value_pairs[pid][key] = json.loads(value)

return formed_key_value_pairs
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ flake8-bugbear
nest-asyncio
pre-commit
pylint
testcontainers[mariadb]==3.7.1
testcontainers[mariadb]>=4.0.0
types-PyMySQL
# some strawberry dependency
strawberry-graphql[debug-server]==0.206.0
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ uvicorn==0.18.3
fastapi[all]==0.85.1
strawberry-graphql[fastapi]==0.206.0
python-multipart==0.0.5
databases[mysql]==0.6.1
SQLAlchemy==1.4.41
databases[mysql]==0.9.0
SQLAlchemy==2.0.28
cryptography>=41.0.0
python-dateutil==2.8.2
slack-sdk==3.20.2
16 changes: 9 additions & 7 deletions test/test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ async def test_project_summary_empty(self):
seqr_sync_types=[],
)

self.assertEqual(expected, result)
self.assertDataclassEqual(expected, result)

@run_as_sync
async def test_project_summary_single_entry(self):
Expand All @@ -232,7 +232,7 @@ async def test_project_summary_single_entry(self):
result = await self.webl.get_project_summary(token=0, grid_filter=[])

result.participants = []
self.assertEqual(SINGLE_PARTICIPANT_RESULT, result)
self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, result)

@run_as_sync
async def test_project_summary_to_external(self):
Expand Down Expand Up @@ -289,7 +289,7 @@ async def project_summary_with_filter_with_results(self):
],
)
filtered_result_success.participants = []
self.assertEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success)
self.assertDataclassEqual(SINGLE_PARTICIPANT_RESULT, filtered_result_success)

@run_as_sync
async def project_summary_with_filter_no_results(self):
Expand Down Expand Up @@ -323,7 +323,7 @@ async def project_summary_with_filter_no_results(self):
seqr_sync_types=[],
)

self.assertEqual(empty_result, filtered_result_empty)
self.assertDataclassEqual(empty_result, filtered_result_empty)

@run_as_sync
async def test_project_summary_multiple_participants(self):
Expand Down Expand Up @@ -376,7 +376,7 @@ async def test_project_summary_multiple_participants(self):

two_samples_result.participants = []

self.assertEqual(expected_data_two_samples, two_samples_result)
self.assertDataclassEqual(expected_data_two_samples, two_samples_result)

@run_as_sync
async def test_project_summary_multiple_participants_and_filter(self):
Expand Down Expand Up @@ -436,7 +436,7 @@ async def test_project_summary_multiple_participants_and_filter(self):
)
two_samples_result_filtered.participants = []

self.assertEqual(
self.assertDataclassEqual(
expected_data_two_samples_filtered, two_samples_result_filtered
)

Expand Down Expand Up @@ -499,7 +499,9 @@ async def test_field_with_space(self):
seqr_sync_types=[],
)

self.assertEqual(expected_data_two_samples_filtered, test_field_with_space)
self.assertDataclassEqual(
expected_data_two_samples_filtered, test_field_with_space
)

@run_as_sync
async def test_project_summary_inactive_sequencing_group(self):
Expand Down
26 changes: 20 additions & 6 deletions test/testbase.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=invalid-overridden-method

import asyncio
import dataclasses
import logging
import os
import socket
Expand Down Expand Up @@ -96,10 +97,10 @@ async def setup():
logger = logging.getLogger()
try:
set_all_access(True)
db = MySqlContainer('mariadb:10.8.3')
db = MySqlContainer('mariadb:11.2.2')
port_to_expose = find_free_port()
# override the default port to map the container to
db.with_bind_ports(db.port_to_expose, port_to_expose)
db.with_bind_ports(db.port, port_to_expose)
logger.disabled = True
db.start()
logger.disabled = False
Expand All @@ -111,7 +112,7 @@ async def setup():

con_string = db.get_connection_url()
con_string = 'mysql://' + con_string.split('://', maxsplit=1)[1]
lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.MYSQL_DATABASE}'
lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.dbname}'
# apply the liquibase schema
command = [
'liquibase',
Expand All @@ -120,8 +121,8 @@ async def setup():
*('--url', lcon_string),
*('--driver', 'org.mariadb.jdbc.Driver'),
*('--classpath', db_prefix + '/mariadb-java-client-3.0.3.jar'),
*('--username', db.MYSQL_USER),
*('--password', db.MYSQL_PASSWORD),
*('--username', db.username),
*('--password', db.password),
'update',
]
subprocess.check_output(command, stderr=subprocess.STDOUT)
Expand Down Expand Up @@ -175,7 +176,7 @@ async def setup():
def tearDownClass(cls) -> None:
db = cls.dbs.get(cls.__name__)
if db:
db.exec(f'DROP DATABASE {db.MYSQL_DATABASE};')
db.exec(f'DROP DATABASE {db.dbname};')
db.stop()

def setUp(self) -> None:
Expand Down Expand Up @@ -224,6 +225,19 @@ async def audit_log_id(self):
"""Get audit_log_id for the test"""
return await self.connection.audit_log_id()

def assertDataclassEqual(self, a, b):
    """Assert two dataclass instances are equal, showing a full dict diff on failure.

    Both values are converted with ``dataclasses.asdict`` and compared via
    ``assertDictEqual`` so unittest renders a readable key-by-key diff.

    Note: ``dataclasses.asdict`` already recurses into nested dataclasses
    (including ones held in lists, tuples, and dicts), so no manual
    re-conversion of nested values is needed — the previous hand-rolled
    recursion over ``asdict`` output was dead code.
    """
    # Disable unittest's diff truncation so large nested results are fully shown.
    self.maxDiff = None
    self.assertDictEqual(dataclasses.asdict(a), dataclasses.asdict(b))


class DbIsolatedTest(DbTest):
"""
Expand Down

0 comments on commit 45d2245

Please sign in to comment.