diff --git a/db/python/tables/bq/generic_bq_filter.py b/db/python/tables/bq/generic_bq_filter.py index 7e75b7d00..8aeabd729 100644 --- a/db/python/tables/bq/generic_bq_filter.py +++ b/db/python/tables/bq/generic_bq_filter.py @@ -38,13 +38,17 @@ def to_sql( values[k] = self._sql_value_prep(k, self.in_[0]) else: k = self.generate_field_name(_column_name + '_in') - conditionals.append(f'{column} IN ({self._sql_cond_prep(k, self.in_)})') + conditionals.append( + f'{column} IN UNNEST({self._sql_cond_prep(k, self.in_)})' + ) values[k] = self._sql_value_prep(k, self.in_) if self.nin is not None: if not isinstance(self.nin, list): raise ValueError('NIN filter must be a list') k = self.generate_field_name(column + '_nin') - conditionals.append(f'{column} NOT IN ({self._sql_cond_prep(k, self.nin)})') + conditionals.append( + f'{column} NOT IN UNNEST({self._sql_cond_prep(k, self.nin)})' + ) values[k] = self._sql_value_prep(k, self.nin) if self.gt is not None: k = self.generate_field_name(column + '_gt') @@ -83,9 +87,14 @@ def _sql_value_prep(key, value): Overrides the default _sql_value_prep to handle BQ parameters """ if isinstance(value, list): - return bigquery.ArrayQueryParameter( - key, 'STRING', ','.join([str(v) for v in value]) - ) + if value and isinstance(value[0], int): + return bigquery.ArrayQueryParameter(key, 'INT64', value) + if value and isinstance(value[0], float): + return bigquery.ArrayQueryParameter(key, 'FLOAT64', value) + + # otherwise all list records as string + return bigquery.ArrayQueryParameter(key, 'STRING', [str(v) for v in value]) + if isinstance(value, Enum): return bigquery.ScalarQueryParameter(key, 'STRING', value.value) if isinstance(value, int): diff --git a/test/test_bq_generic_filters.py b/test/test_bq_generic_filters.py new file mode 100644 index 000000000..38719573b --- /dev/null +++ b/test/test_bq_generic_filters.py @@ -0,0 +1,294 @@ +import dataclasses +import unittest +from datetime import datetime +from enum import Enum + +from google.cloud import bigquery + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel + + +@dataclasses.dataclass(kw_only=True) +class GenericBQFilterTest(GenericBQFilterModel): + """Test model for GenericBQFilter""" + + test_string: GenericBQFilter[str] | None = None + test_int: GenericBQFilter[int] | None = None + test_float: GenericBQFilter[float] | None = None + test_dt: GenericBQFilter[datetime] | None = None + test_dict: dict[str, GenericBQFilter[str]] | None = None + test_enum: GenericBQFilter[Enum] | None = None + + +class BGFilterTestEnum(str, Enum): + """Simple Enum classs""" + + ID = 'id' + VALUE = 'value' + + +class TestGenericBQFilters(unittest.TestCase): + """Test generic filters SQL generation""" + + def test_basic_no_override(self): + """Test that the basic filter converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_eq', sql) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_basic_override(self): + """Test that the basic filter with an override converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql({'test_string': 't.string'}) + + self.assertEqual('t.string = @t_string_eq', sql) + self.assertDictEqual( + { + 't_string_eq': bigquery.ScalarQueryParameter( + 't_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_string(self): + """ + Test that a single value filtered using the "in" operator + gets converted to an eq operation + """ + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=['test'])) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_in_eq', sql) + self.assertDictEqual( + { + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123 + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(gt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > @test_int_gt', sql) + self.assertDictEqual( + { + 'test_int_gt': bigquery.ScalarQueryParameter( + 'test_int_gt', 'INT64', value + ) + }, + values, + ) + + def test_single_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123.456 + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(gte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float >= @test_float_gte', sql) + self.assertDictEqual( + { + 'test_float_gte': bigquery.ScalarQueryParameter( + 'test_float_gte', 'FLOAT64', value + ) + }, + values, + ) + + def test_single_datetime(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + datetime_str = '2021-10-08 01:02:03' + value = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S') + filter_ = GenericBQFilterTest(test_dt=GenericBQFilter(lt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_dt < TIMESTAMP(@test_dt_lt)', sql) + self.assertDictEqual( + { + 'test_dt_lt': bigquery.ScalarQueryParameter( + 'test_dt_lt', 'STRING', datetime_str + ) + }, + values, + ) + + def test_single_enum(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = BGFilterTestEnum.ID + filter_ = GenericBQFilterTest(test_enum=GenericBQFilter(lte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_enum <= @test_enum_lte', sql) + self.assertDictEqual( + { + 'test_enum_lte': bigquery.ScalarQueryParameter( + 'test_enum_lte', 'STRING', value.value + ) + }, + values, + ) + + def test_in_multiple_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1, 2] + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int IN UNNEST(@test_int_in)', sql) + self.assertDictEqual( + { + 'test_int_in': bigquery.ArrayQueryParameter( + 'test_int_in', 'INT64', value + ) + }, + values, + ) + + def test_in_multiple_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1.0, 2.0] + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float IN UNNEST(@test_float_in)', sql) + self.assertDictEqual( + { + 'test_float_in': bigquery.ArrayQueryParameter( + 'test_float_in', 'FLOAT64', value + ) + }, + values, + ) + + def test_in_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string IN UNNEST(@test_string_in)', sql) + self.assertDictEqual( + { + 'test_string_in': bigquery.ArrayQueryParameter( + 'test_string_in', 'STRING', value + ) + }, + values, + ) + + def test_nin_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string NOT IN UNNEST(@test_string_nin)', sql) + self.assertDictEqual( + { + 'test_string_nin': bigquery.ArrayQueryParameter( + 'test_string_nin', 'STRING', value + ) + }, + values, + ) + + def test_in_and_eq_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value, eq='B')) + sql, values = filter_.to_sql() + + self.assertEqual( + 'test_string = @test_string_eq AND test_string = @test_string_in_eq', + sql, + ) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'B' + ), + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'A' + ), + }, + values, + ) + + def test_failed_in_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 'Not a list' + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value)) + + # check if ValueError is raised + with self.assertRaises(ValueError) as context: + filter_.to_sql() + + self.assertTrue('IN filter must be a list' in str(context.exception)) + + def test_failed_not_in_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 'Not a list' + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value)) + + # check if ValueError is raised + with self.assertRaises(ValueError) as context: + filter_.to_sql() + + self.assertTrue('NIN filter must be a list' in str(context.exception)) + + def test_fail_none_in_tuple(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = (None,) + + # check if ValueError is raised + with self.assertRaises(ValueError) as context: + filter_ = GenericBQFilterTest(test_string=value) + filter_.to_sql() + + self.assertTrue( + 'There is very likely a trailing comma on the end of ' + 'GenericBQFilterTest.test_string. ' + 'If you actually want a tuple of length one with the value = (None,), ' + 'then use dataclasses.field(default_factory=lambda: (None,))' + in str(context.exception) + ) diff --git a/test/test_generic_filters.py b/test/test_generic_filters.py index 2c1348076..b5598be54 100644 --- a/test/test_generic_filters.py +++ b/test/test_generic_filters.py @@ -53,3 +53,54 @@ def test_in_multiple(self): self.assertEqual('test_int IN :test_int_in', sql) self.assertDictEqual({'test_int_in': value}, values) + + def test_gt_single(self): + """ + Test that a single value filtered using the "gt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > :test_int_gt', sql) + self.assertDictEqual({'test_int_gt': 123}, values) + + def test_gte_single(self): + """ + Test that a single value filtered using the "gte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int >= :test_int_gte', sql) + self.assertDictEqual({'test_int_gte': 123}, values) + + def test_lt_single(self): + """ + Test that a single value filtered using the "lt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int < :test_int_lt', sql) + self.assertDictEqual({'test_int_lt': 123}, values) + + def test_lte_single(self): + """ + Test that a single value filtered using the "lte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int <= :test_int_lte', sql) + self.assertDictEqual({'test_int_lte': 123}, values) + + def test_not_in_multiple(self): + """ + Test that values filtered using the "nin" operator convert as expected + """ + value = [1, 2] + filter_ = GenericFilterTest(test_int=GenericFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int NOT IN :test_int_nin', sql) + self.assertDictEqual({'test_int_nin': value}, values)