Skip to content

Commit

Permalink
Added unit tests for BQ filters.
Browse files Browse the repository at this point in the history
  • Loading branch information
milo-hyben committed Jan 19, 2024
1 parent 3f298cc commit 41e2ba6
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 5 deletions.
19 changes: 14 additions & 5 deletions db/python/tables/bq/generic_bq_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,17 @@ def to_sql(
values[k] = self._sql_value_prep(k, self.in_[0])
else:
k = self.generate_field_name(_column_name + '_in')
conditionals.append(f'{column} IN ({self._sql_cond_prep(k, self.in_)})')
conditionals.append(
f'{column} IN UNNEST({self._sql_cond_prep(k, self.in_)})'
)
values[k] = self._sql_value_prep(k, self.in_)
if self.nin is not None:
if not isinstance(self.nin, list):
raise ValueError('NIN filter must be a list')
k = self.generate_field_name(column + '_nin')
conditionals.append(f'{column} NOT IN ({self._sql_cond_prep(k, self.nin)})')
conditionals.append(
f'{column} NOT IN UNNEST({self._sql_cond_prep(k, self.nin)})'
)
values[k] = self._sql_value_prep(k, self.nin)
if self.gt is not None:
k = self.generate_field_name(column + '_gt')
Expand Down Expand Up @@ -83,9 +87,14 @@ def _sql_value_prep(key, value):
Overrides the default _sql_value_prep to handle BQ parameters
"""
if isinstance(value, list):
return bigquery.ArrayQueryParameter(
key, 'STRING', ','.join([str(v) for v in value])
)
if value and isinstance(value[0], int):
return bigquery.ArrayQueryParameter(key, 'INT64', value)
if value and isinstance(value[0], float):
return bigquery.ArrayQueryParameter(key, 'FLOAT64', value)

# otherwise all list records as string
return bigquery.ArrayQueryParameter(key, 'STRING', [str(v) for v in value])

if isinstance(value, Enum):
return bigquery.ScalarQueryParameter(key, 'STRING', value.value)
if isinstance(value, int):
Expand Down
294 changes: 294 additions & 0 deletions test/test_bq_generic_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
import dataclasses
import unittest
from datetime import datetime
from enum import Enum

from google.cloud import bigquery

from db.python.tables.bq.generic_bq_filter import GenericBQFilter
from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel


@dataclasses.dataclass(kw_only=True)
class GenericBQFilterTest(GenericBQFilterModel):
"""Test model for GenericBQFilter"""

test_string: GenericBQFilter[str] | None = None
test_int: GenericBQFilter[int] | None = None
test_float: GenericBQFilter[float] | None = None
test_dt: GenericBQFilter[datetime] | None = None
test_dict: dict[str, GenericBQFilter[str]] | None = None
test_enum: GenericBQFilter[Enum] | None = None


class BGFilterTestEnum(str, Enum):
"""Simple Enum classs"""

ID = 'id'
VALUE = 'value'


class TestGenericBQFilters(unittest.TestCase):
"""Test generic filters SQL generation"""

def test_basic_no_override(self):
"""Test that the basic filter converts to SQL as expected"""
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test'))
sql, values = filter_.to_sql()

self.assertEqual('test_string = @test_string_eq', sql)
self.assertDictEqual(
{
'test_string_eq': bigquery.ScalarQueryParameter(
'test_string_eq', 'STRING', 'test'
)
},
values,
)

def test_basic_override(self):
"""Test that the basic filter with an override converts to SQL as expected"""
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test'))
sql, values = filter_.to_sql({'test_string': 't.string'})

self.assertEqual('t.string = @t_string_eq', sql)
self.assertDictEqual(
{
't_string_eq': bigquery.ScalarQueryParameter(
't_string_eq', 'STRING', 'test'
)
},
values,
)

def test_single_string(self):
"""
Test that a single value filtered using the "in" operator
gets converted to an eq operation
"""
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=['test']))
sql, values = filter_.to_sql()

self.assertEqual('test_string = @test_string_in_eq', sql)
self.assertDictEqual(
{
'test_string_in_eq': bigquery.ScalarQueryParameter(
'test_string_in_eq', 'STRING', 'test'
)
},
values,
)

def test_single_int(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = 123
filter_ = GenericBQFilterTest(test_int=GenericBQFilter(gt=value))
sql, values = filter_.to_sql()

self.assertEqual('test_int > @test_int_gt', sql)
self.assertDictEqual(
{
'test_int_gt': bigquery.ScalarQueryParameter(
'test_int_gt', 'INT64', value
)
},
values,
)

def test_single_float(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = 123.456
filter_ = GenericBQFilterTest(test_float=GenericBQFilter(gte=value))
sql, values = filter_.to_sql()

self.assertEqual('test_float >= @test_float_gte', sql)
self.assertDictEqual(
{
'test_float_gte': bigquery.ScalarQueryParameter(
'test_float_gte', 'FLOAT64', value
)
},
values,
)

def test_single_datetime(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
datetime_str = '2021-10-08 01:02:03'
value = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S')
filter_ = GenericBQFilterTest(test_dt=GenericBQFilter(lt=value))
sql, values = filter_.to_sql()

self.assertEqual('test_dt < TIMESTAMP(@test_dt_lt)', sql)
self.assertDictEqual(
{
'test_dt_lt': bigquery.ScalarQueryParameter(
'test_dt_lt', 'STRING', datetime_str
)
},
values,
)

def test_single_enum(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = BGFilterTestEnum.ID
filter_ = GenericBQFilterTest(test_enum=GenericBQFilter(lte=value))
sql, values = filter_.to_sql()

self.assertEqual('test_enum <= @test_enum_lte', sql)
self.assertDictEqual(
{
'test_enum_lte': bigquery.ScalarQueryParameter(
'test_enum_lte', 'STRING', value.value
)
},
values,
)

def test_in_multiple_int(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = [1, 2]
filter_ = GenericBQFilterTest(test_int=GenericBQFilter(in_=value))
sql, values = filter_.to_sql()

self.assertEqual('test_int IN UNNEST(@test_int_in)', sql)
self.assertDictEqual(
{
'test_int_in': bigquery.ArrayQueryParameter(
'test_int_in', 'INT64', value
)
},
values,
)

def test_in_multiple_float(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = [1.0, 2.0]
filter_ = GenericBQFilterTest(test_float=GenericBQFilter(in_=value))
sql, values = filter_.to_sql()

self.assertEqual('test_float IN UNNEST(@test_float_in)', sql)
self.assertDictEqual(
{
'test_float_in': bigquery.ArrayQueryParameter(
'test_float_in', 'FLOAT64', value
)
},
values,
)

def test_in_multiple_str(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = ['A', 'B']
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value))
sql, values = filter_.to_sql()

self.assertEqual('test_string IN UNNEST(@test_string_in)', sql)
self.assertDictEqual(
{
'test_string_in': bigquery.ArrayQueryParameter(
'test_string_in', 'STRING', value
)
},
values,
)

def test_nin_multiple_str(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = ['A', 'B']
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value))
sql, values = filter_.to_sql()

self.assertEqual('test_string NOT IN UNNEST(@test_string_nin)', sql)
self.assertDictEqual(
{
'test_string_nin': bigquery.ArrayQueryParameter(
'test_string_nin', 'STRING', value
)
},
values,
)

def test_in_and_eq_multiple_str(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = ['A']
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value, eq='B'))
sql, values = filter_.to_sql()

self.assertEqual(
'test_string = @test_string_eq AND test_string = @test_string_in_eq',
sql,
)
self.assertDictEqual(
{
'test_string_eq': bigquery.ScalarQueryParameter(
'test_string_eq', 'STRING', 'B'
),
'test_string_in_eq': bigquery.ScalarQueryParameter(
'test_string_in_eq', 'STRING', 'A'
),
},
values,
)

def test_failed_in_multiple_str(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = 'Not a list'
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value))

# check if ValueError is raised
with self.assertRaises(ValueError) as context:
filter_.to_sql()

self.assertTrue('IN filter must be a list' in str(context.exception))

def test_failed_not_in_multiple_str(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = 'Not a list'
filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value))

# check if ValueError is raised
with self.assertRaises(ValueError) as context:
filter_.to_sql()

self.assertTrue('NIN filter must be a list' in str(context.exception))

def test_fail_none_in_tuple(self):
"""
Test that values filtered using the "in" operator convert as expected
"""
value = (None,)

# check if ValueError is raised
with self.assertRaises(ValueError) as context:
filter_ = GenericBQFilterTest(test_string=value)
filter_.to_sql()

self.assertTrue(
'There is very likely a trailing comma on the end of '
'GenericBQFilterTest.test_string. '
'If you actually want a tuple of length one with the value = (None,), '
'then use dataclasses.field(default_factory=lambda: (None,))'
in str(context.exception)
)
51 changes: 51 additions & 0 deletions test/test_generic_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,54 @@ def test_in_multiple(self):

self.assertEqual('test_int IN :test_int_in', sql)
self.assertDictEqual({'test_int_in': value}, values)

def test_gt_single(self):
"""
Test that a single value filtered using the "gt" operator
"""
filter_ = GenericFilterTest(test_int=GenericFilter(gt=123))
sql, values = filter_.to_sql()

self.assertEqual('test_int > :test_int_gt', sql)
self.assertDictEqual({'test_int_gt': 123}, values)

def test_gte_single(self):
"""
Test that a single value filtered using the "gte" operator
"""
filter_ = GenericFilterTest(test_int=GenericFilter(gte=123))
sql, values = filter_.to_sql()

self.assertEqual('test_int >= :test_int_gte', sql)
self.assertDictEqual({'test_int_gte': 123}, values)

def test_lt_single(self):
"""
Test that a single value filtered using the "lt" operator
"""
filter_ = GenericFilterTest(test_int=GenericFilter(lt=123))
sql, values = filter_.to_sql()

self.assertEqual('test_int < :test_int_lt', sql)
self.assertDictEqual({'test_int_lt': 123}, values)

def test_lte_single(self):
"""
Test that a single value filtered using the "lte" operator
"""
filter_ = GenericFilterTest(test_int=GenericFilter(lte=123))
sql, values = filter_.to_sql()

self.assertEqual('test_int <= :test_int_lte', sql)
self.assertDictEqual({'test_int_lte': 123}, values)

def test_not_in_multiple(self):
"""
Test that values filtered using the "nin" operator convert as expected
"""
value = [1, 2]
filter_ = GenericFilterTest(test_int=GenericFilter(nin=value))
sql, values = filter_.to_sql()

self.assertEqual('test_int NOT IN :test_int_nin', sql)
self.assertDictEqual({'test_int_nin': value}, values)

0 comments on commit 41e2ba6

Please sign in to comment.