Skip to content

Commit

Permalink
Fix GenericFilter creating conditional with empty list (#744)
Browse files Browse the repository at this point in the history
* Test that fails

* Add fix to check empty IN in the GenericFilters

* Linting

* Better logic for FilterModel.is_false

---------

Co-authored-by: Michael Franklin <[email protected]>
  • Loading branch information
illusional and illusional authored Apr 24, 2024
1 parent 8ee989f commit d2604a8
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 4 deletions.
9 changes: 6 additions & 3 deletions db/python/tables/sequencing_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ async def query(
self, filter_: SequencingGroupFilter
) -> tuple[set[ProjectId], list[SequencingGroupInternal]]:
"""Query samples"""
if filter_.is_false():
return set(), []

sql_overrides = {
'project': 's.project',
'sample_id': 'sg.sample_id',
Expand Down Expand Up @@ -166,9 +169,9 @@ async def get_all_sequencing_group_ids_by_sample_ids_by_type(
WHERE project = :project
"""
rows = await self.connection.fetch_all(_query, {'project': self.project})
sequencing_group_ids_by_sample_ids_by_type: dict[
int, dict[str, list[int]]
] = defaultdict(lambda: defaultdict(list))
sequencing_group_ids_by_sample_ids_by_type: dict[int, dict[str, list[int]]] = (
defaultdict(lambda: defaultdict(list))
)
for row in rows:
sample_id = row['sid']
sg_id = row['sgid']
Expand Down
30 changes: 29 additions & 1 deletion db/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ def generate_field_name(name):
"""
return NONFIELD_CHARS_REGEX.sub('_', name)

def is_false(self) -> bool:
"""
The filter will resolve to False (usually because the in_ is an empty list)
"""
return self.in_ is not None and len(self.in_) == 0

def to_sql(
self, column: str, column_name: str = None
) -> tuple[str, dict[str, T | list[T]]]:
Expand All @@ -161,6 +167,9 @@ def to_sql(
conditionals.append(f'{column} = :{k}')
values[k] = self._sql_value_prep(self.eq)
if self.in_ is not None:
if len(self.in_) == 0:
# in an empty list is always false
return 'FALSE', {}
if not isinstance(self.in_, list):
raise ValueError('IN filter must be a list')
if len(self.in_) == 1:
Expand All @@ -171,7 +180,8 @@ def to_sql(
k = self.generate_field_name(_column_name + '_in')
conditionals.append(f'{column} IN :{k}')
values[k] = self._sql_value_prep(self.in_)
if self.nin is not None:
if self.nin is not None and len(self.nin) > 0:
# not in an empty list is always true
if not isinstance(self.nin, list):
raise ValueError('NIN filter must be a list')
k = self.generate_field_name(column + '_nin')
Expand Down Expand Up @@ -223,6 +233,21 @@ def __hash__(self):
"""Hash the GenericFilterModel, this doesn't override well"""
return hash(dataclasses.astuple(self))

def is_false(self) -> bool:
"""
Returns False if any of the internal filters is FALSE
"""
for field in dataclasses.fields(self):
value = getattr(self, field.name)
if isinstance(value, GenericFilter) and value.is_false():
return True

if isinstance(value, dict):
if any(f.is_false() for f in value.values()):
return True

return False

def __post_init__(self):
for field in dataclasses.fields(self):
value = getattr(self, field.name)
Expand Down Expand Up @@ -259,6 +284,9 @@ def to_sql(
self, field_overrides: dict[str, str] = None
) -> tuple[str, dict[str, Any]]:
"""Convert the model to SQL, and avoid SQL injection"""
if self.is_false():
return 'FALSE', {}

_foverrides = field_overrides or {}

# check for bad field_overrides
Expand Down
8 changes: 8 additions & 0 deletions test/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ async def test_get_analysis_by_id(self):
self.assertEqual('cram', analysis.type)
self.assertEqual(AnalysisStatus.COMPLETED, analysis.status)

@run_as_sync
async def test_empty_query(self):
"""
Test empty IDs to see the query construction
"""
analyses = await self.al.query(AnalysisFilter(id=GenericFilter(in_=[])))
self.assertEqual(len(analyses), 0)

@run_as_sync
async def test_add_cram(self):
"""
Expand Down
8 changes: 8 additions & 0 deletions test/test_sequencing_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ async def setUp(self) -> None:
self.sglayer = SequencingGroupLayer(self.connection)
self.slayer = SampleLayer(self.connection)

@run_as_sync
async def test_empty_query(self):
"""
Test empty IDs to see the query construction
"""
sgs = await self.sglayer.query(SequencingGroupFilter(id=GenericFilter(in_=[])))
self.assertEqual(len(sgs), 0)

@run_as_sync
async def test_insert_sequencing_group(self):
"""Test inserting and fetching a sequencing group"""
Expand Down

0 comments on commit d2604a8

Please sign in to comment.