From d2604a81be0047fa44e0276bc1c309ec6c0f0c74 Mon Sep 17 00:00:00 2001 From: Michael Franklin <22381693+illusional@users.noreply.github.com> Date: Wed, 24 Apr 2024 16:13:30 +1000 Subject: [PATCH] Fix GenericFilter creating conditional with empty list (#744) * Test that fails * Add fix to check empty IN in the GenericFilters * Linting * Better logic for FilterModel.is_false --------- Co-authored-by: Michael Franklin --- db/python/tables/sequencing_group.py | 9 ++++++--- db/python/utils.py | 30 +++++++++++++++++++++++++++- test/test_analysis.py | 8 ++++++++ test/test_sequencing_groups.py | 8 ++++++++ 4 files changed, 51 insertions(+), 4 deletions(-) diff --git a/db/python/tables/sequencing_group.py b/db/python/tables/sequencing_group.py index 1b82da962..975d9890d 100644 --- a/db/python/tables/sequencing_group.py +++ b/db/python/tables/sequencing_group.py @@ -61,6 +61,9 @@ async def query( self, filter_: SequencingGroupFilter ) -> tuple[set[ProjectId], list[SequencingGroupInternal]]: """Query samples""" + if filter_.is_false(): + return set(), [] + sql_overrides = { 'project': 's.project', 'sample_id': 'sg.sample_id', @@ -166,9 +169,9 @@ async def get_all_sequencing_group_ids_by_sample_ids_by_type( WHERE project = :project """ rows = await self.connection.fetch_all(_query, {'project': self.project}) - sequencing_group_ids_by_sample_ids_by_type: dict[ - int, dict[str, list[int]] - ] = defaultdict(lambda: defaultdict(list)) + sequencing_group_ids_by_sample_ids_by_type: dict[int, dict[str, list[int]]] = ( + defaultdict(lambda: defaultdict(list)) + ) for row in rows: sample_id = row['sid'] sg_id = row['sgid'] diff --git a/db/python/utils.py b/db/python/utils.py index 0b0846dcb..7ac94e06c 100644 --- a/db/python/utils.py +++ b/db/python/utils.py @@ -146,6 +146,12 @@ def generate_field_name(name): """ return NONFIELD_CHARS_REGEX.sub('_', name) + def is_false(self) -> bool: + """ + The filter will resolve to False (usually because the in_ is an empty list) + """ + return 
self.in_ is not None and len(self.in_) == 0 + def to_sql( self, column: str, column_name: str = None ) -> tuple[str, dict[str, T | list[T]]]: @@ -161,6 +167,9 @@ def to_sql( conditionals.append(f'{column} = :{k}') values[k] = self._sql_value_prep(self.eq) if self.in_ is not None: + if len(self.in_) == 0: + # in an empty list is always false + return 'FALSE', {} if not isinstance(self.in_, list): raise ValueError('IN filter must be a list') if len(self.in_) == 1: @@ -171,7 +180,8 @@ def to_sql( k = self.generate_field_name(_column_name + '_in') conditionals.append(f'{column} IN :{k}') values[k] = self._sql_value_prep(self.in_) - if self.nin is not None: + if self.nin is not None and len(self.nin) > 0: + # not in an empty list is always true if not isinstance(self.nin, list): raise ValueError('NIN filter must be a list') k = self.generate_field_name(column + '_nin') @@ -223,6 +233,21 @@ def __hash__(self): """Hash the GenericFilterModel, this doesn't override well""" return hash(dataclasses.astuple(self)) + def is_false(self) -> bool: + """ + Returns True if any of the internal filters is FALSE + """ + for field in dataclasses.fields(self): + value = getattr(self, field.name) + if isinstance(value, GenericFilter) and value.is_false(): + return True + + if isinstance(value, dict): + if any(f.is_false() for f in value.values()): + return True + + return False + def __post_init__(self): for field in dataclasses.fields(self): value = getattr(self, field.name) @@ -259,6 +284,9 @@ def to_sql( self, field_overrides: dict[str, str] = None ) -> tuple[str, dict[str, Any]]: """Convert the model to SQL, and avoid SQL injection""" + if self.is_false(): + return 'FALSE', {} + _foverrides = field_overrides or {} # check for bad field_overrides diff --git a/test/test_analysis.py b/test/test_analysis.py index be7a15e90..aadf7fdb0 100644 --- a/test/test_analysis.py +++ b/test/test_analysis.py @@ -95,6 +95,14 @@ async def test_get_analysis_by_id(self): self.assertEqual('cram', 
analysis.type) self.assertEqual(AnalysisStatus.COMPLETED, analysis.status) + @run_as_sync + async def test_empty_query(self): + """ + Test empty IDs to see the query construction + """ + analyses = await self.al.query(AnalysisFilter(id=GenericFilter(in_=[]))) + self.assertEqual(len(analyses), 0) + @run_as_sync async def test_add_cram(self): """ diff --git a/test/test_sequencing_groups.py b/test/test_sequencing_groups.py index d4a146375..4b32043f2 100644 --- a/test/test_sequencing_groups.py +++ b/test/test_sequencing_groups.py @@ -52,6 +52,14 @@ async def setUp(self) -> None: self.sglayer = SequencingGroupLayer(self.connection) self.slayer = SampleLayer(self.connection) + @run_as_sync + async def test_empty_query(self): + """ + Test empty IDs to see the query construction + """ + sgs = await self.sglayer.query(SequencingGroupFilter(id=GenericFilter(in_=[]))) + self.assertEqual(len(sgs), 0) + @run_as_sync async def test_insert_sequencing_group(self): """Test inserting and fetching a sequencing group"""