From 35f051d214cdebc9f343cd52ebddeb9d29d99025 Mon Sep 17 00:00:00 2001 From: Milo Hyben Date: Thu, 22 Feb 2024 15:51:52 +1100 Subject: [PATCH] Billing unit tests (#656) * Billing api extra labels (#619) * Added compute_category, cromwell_sub_workflow_name, cromwell_workflow_id, goog_pipelines_worker and wdl_task_name to extended view and created relevant filters and API points. * Added labels to all BQ queries, refactoring billing layer. * Added examples to billing-total-cost API regarding the new filters. * Billing - fixing styling issues after the first Billing release (#624) * Temporarily disable seqr and hail from /topics API. * Autoselect 1st topic / 1st project value from the DDL. * Merging Billing.css into index.css * Small fix - reusing extRecords in FieldSelector component. * Refactoring duplicated code in FieldSelector. * Added Stages to the Group by DDL. * Billing API IsBillingEnabled (#626) * Added API point to check if billing is enabled. * Added simple Total Cost By Batch Page. (#627) * Added simple Total Cost By Batch Page. * Billing cost by category (#629) * Added simple Total Cost By Batch Page. * Fixed autoselect day format. * Fixing day format for autoselect (missing leading 0) * Added first draft of billing page to show detail SKU per selected cost category over selected time periods (day, week, month or invoice month) * Small fix for BillingCostByBatch page, disable search if searchBy is empty or < 6 chars. * New: Billing API GET namespaces, added namespace to allowed fields for total cost. * Implemented HorizontalStackedBarChart, updated Billing By Invoice Month page to enable toggle between chart and table view. * Stacked Bars Chart with option to accumulate data. (#634) * Implemented Stacked bars with option to accumulate data. * Added budget bar to billing horizontal bar chart, added background color for the billing table to reflect the chart colours. * Added simple prediction of billing stacked bar chart. 
* Billing hail batch layout (#633) * Added simple Total Cost By Batch Page. * Removing debug prints. * Fixed autoselect day format. * Fixing day format for autoselect (missing leading 0) * Added first draft of billing page to show detail SKU per selected cost category over selected time periods (day, week, month or invoice month) * Small fix for BillingCostByBatch page, disable search if searchBy is empty or < 6 chars. * New: Billing API GET namespaces, added namespace to allowed fields for total cost. * Implemented HorizontalStackedBarChart, updated Billing By Invoice Month page to enable toggle between chart and table view. * ADD: Cost by Analysis page * ADD: add start of Analysis grid * ADD: add start of Analysis grid * FIX: table fixes for the HailBatchGrid * API: api changes to enable query of the raw table * API: fixed and working with updated get_total_cost endpoint * API: fix typing of get_total_cost (default return is now a list[dict] and can be converted in the layer/route to a specific output type * API: add endpoint to get costs by batch_id * API: done * IN PROGRESS: modifying Cost By Analysis to use new endpoints * IN PROGRESS: changes to Cost By Analysis, linking with backend API. * IN PROGRESS: changes to Cost By Analysis, grid grouping by ar/batch/job. * NEW: finalising Cost By Analysis page * ADD: durations to Cost By Analysis page --------- Co-authored-by: Milo Hyben * FIX: Billing - fixing time_column condition. * Removing draft billing page. * Remove unused API point & cleanup, changes as per code review. * Small Frontend refactoring, reflecting PR review. * Updating billing style for dark mode. * Optimised Frontend, replacing reduce with forEach where possible. * Refactoring Billing DB structures. * Cleaning up unused dependencies. * FIX: replaced button 'color=red' with 'negative' property. * FIX: replace HEX color for pattern with CSS var. * FIX: replace async call with sync for a simple function. * FIX: dark mode for Horizontal Stacked Bar. 
* FIX: billing cost by analysis page, esp. search control resizing and functionality. * FIX: duplicated keys in the grid on Billing Cost By Analysis page. * FIX: refactoring BQ tables, small fixes for billing pages. * FIX: BillingCostPageAnalysis, keeping the old record until loading of data finishes. * FIX: Billing StackedChart various issues. * Linting * FIX: missing filters checks, updating charts when loading. * FIX: silence linting no attribute msg for Middleware. * Refactoring filters, implemented first Billing GraphQL integration. * Fixing linting. * Added unit tests for BQ filters. * Fixing linting. * Added tests for billing routes. * Removing billing GraphQL, there will be another PR for this. * Adding doc strings to tests. * Changing to staticmethod where relevant, added more unit tests. * Linting * Refactoring string constant so both pylint and unittest are happy. * Unittests for BQ Function filter. * Unittests for BillingArBatchTable. * Added pytz dependency to dev so we can write unit test with timezone aware datetime. * Fixing timezone aware unit test for billing function filter. * Refactoring BQ Unittests, added more for BillingBaseTable class. * Linting * Added more unit tests for BillingBaseTable. * Add last missing unittest for BillingBaseTable. * Added unittests for BillingLayer. * Linting. * Merge dev try 2. * More unit tests for BillingLayer. * More unit tests for Billing routes functions. * More unit tests for Billing API routes. * Billing unit tests - use mockup author. * Added BillingLayer mockup to unit tests. * Billing unit tests refactoring, implementing feedback from PR. * Fixing billing unit tests. * More fixes for billing unit tests. * Removing tests related to old getLabelValue functions. * Update test/test_api_billing.py Co-authored-by: Michael Franklin <22381693+illusional@users.noreply.github.com> * Setting up mock of is_billing_enabled in test class setUp function. 
--------- Co-authored-by: Sabrina Yan <9669990+violetbrina@users.noreply.github.com> Co-authored-by: Michael Franklin <22381693+illusional@users.noreply.github.com> --- db/python/layers/billing.py | 2 +- db/python/tables/bq/billing_base.py | 29 +- db/python/tables/bq/billing_filter.py | 15 + db/python/tables/bq/billing_gcp_daily.py | 10 +- db/python/tables/bq/billing_raw.py | 3 +- db/python/tables/bq/function_bq_filter.py | 8 +- db/python/tables/bq/generic_bq_filter.py | 32 +- requirements-dev.txt | 2 + test/test_api_billing.py | 269 +++++++ test/test_api_utils.py | 77 ++ test/test_bq_billing_ar_batch.py | 188 +++++ test/test_bq_billing_base.py | 911 ++++++++++++++++++++++ test/test_bq_billing_daily.py | 255 ++++++ test/test_bq_billing_daily_extended.py | 116 +++ test/test_bq_billing_gcp_daily.py | 197 +++++ test/test_bq_billing_raw.py | 75 ++ test/test_bq_function_filter.py | 155 ++++ test/test_bq_generic_filter.py | 270 +++++++ test/test_generic_filters.py | 51 ++ test/test_layers_billing.py | 470 +++++++++++ test/testbqbase.py | 44 ++ 21 files changed, 3153 insertions(+), 26 deletions(-) create mode 100644 test/test_api_billing.py create mode 100644 test/test_api_utils.py create mode 100644 test/test_bq_billing_ar_batch.py create mode 100644 test/test_bq_billing_base.py create mode 100644 test/test_bq_billing_daily.py create mode 100644 test/test_bq_billing_daily_extended.py create mode 100644 test/test_bq_billing_gcp_daily.py create mode 100644 test/test_bq_billing_raw.py create mode 100644 test/test_bq_function_filter.py create mode 100644 test/test_bq_generic_filter.py create mode 100644 test/test_layers_billing.py create mode 100644 test/testbqbase.py diff --git a/db/python/layers/billing.py b/db/python/layers/billing.py index 737be6857..1122259ba 100644 --- a/db/python/layers/billing.py +++ b/db/python/layers/billing.py @@ -18,7 +18,7 @@ class BillingLayer(BqBaseLayer): def table_factory( self, - source: BillingSource, + source: BillingSource | None = None, 
fields: list[BillingColumn] | None = None, filters: dict[BillingColumn, str | list | dict] | None = None, ) -> ( diff --git a/db/python/tables/bq/billing_base.py b/db/python/tables/bq/billing_base.py index 335603c3b..350dfecab 100644 --- a/db/python/tables/bq/billing_base.py +++ b/db/python/tables/bq/billing_base.py @@ -117,8 +117,9 @@ def _execute_query( # otherwise return as BQ iterator return self._connection.connection.query(query, job_config=job_config) + @staticmethod def _query_to_partitioned_filter( - self, query: BillingTotalCostQueryModel + query: BillingTotalCostQueryModel, ) -> BillingFilter: """ By default views are partitioned by 'day', @@ -137,15 +138,18 @@ def _query_to_partitioned_filter( ) return billing_filter - def _filter_to_optimise_query(self) -> str: + @staticmethod + def _filter_to_optimise_query() -> str: """Filter string to optimise BQ query""" return 'day >= TIMESTAMP(@start_day) AND day <= TIMESTAMP(@last_day)' - def _last_loaded_day_filter(self) -> str: + @staticmethod + def _last_loaded_day_filter() -> str: """Last Loaded day filter string""" return 'day = TIMESTAMP(@last_loaded_day)' - def _convert_output(self, query_job_result): + @staticmethod + def _convert_output(query_job_result): """Convert query result to json""" if not query_job_result or query_job_result.result().total_rows == 0: # return empty list if no record found @@ -325,8 +329,8 @@ async def _execute_running_cost_query( self._execute_query(_query, query_params), ) + @staticmethod async def _append_total_running_cost( - self, field: BillingColumn, is_current_month: bool, last_loaded_day: str | None, @@ -437,9 +441,8 @@ async def _append_running_cost_records( return results - def _prepare_order_by_string( - self, order_by: dict[BillingColumn, bool] | None - ) -> str: + @staticmethod + def _prepare_order_by_string(order_by: dict[BillingColumn, bool] | None) -> str: """Prepare order by string""" if not order_by: return '' @@ -452,9 +455,8 @@ def _prepare_order_by_string( 
return f'ORDER BY {",".join(order_by_cols)}' if order_by_cols else '' - def _prepare_aggregation( - self, query: BillingTotalCostQueryModel - ) -> tuple[str, str]: + @staticmethod + def _prepare_aggregation(query: BillingTotalCostQueryModel) -> tuple[str, str]: """Prepare both fields for aggregation and group by string""" # Get columns to group by @@ -479,7 +481,8 @@ def _prepare_aggregation( return fields_selected, group_by - def _prepare_labels_function(self, query: BillingTotalCostQueryModel): + @staticmethod + def _prepare_labels_function(query: BillingTotalCostQueryModel): if not query.filters: return None @@ -558,7 +561,7 @@ async def get_total_cost( where_str = f'WHERE {where_str}' _query = f""" - {func_filter.fun_implementation if func_filter else ''} + {func_filter.func_implementation if func_filter else ''} WITH t AS ( SELECT {time_group.field}{time_group.separator} {fields_selected}, diff --git a/db/python/tables/bq/billing_filter.py b/db/python/tables/bq/billing_filter.py index 9a379817f..b78a8c960 100644 --- a/db/python/tables/bq/billing_filter.py +++ b/db/python/tables/bq/billing_filter.py @@ -2,6 +2,7 @@ import dataclasses import datetime +from typing import Any from db.python.tables.bq.generic_bq_filter import GenericBQFilter from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel @@ -46,3 +47,17 @@ class BillingFilter(GenericBQFilterModel): goog_pipelines_worker: GenericBQFilter[str] = None wdl_task_name: GenericBQFilter[str] = None namespace: GenericBQFilter[str] = None + + def __eq__(self, other: Any) -> bool: + """Equality operator""" + result = super().__eq__(other) + if not result or not isinstance(other, BillingFilter): + return False + + # compare all attributes + for att in self.__dict__: + if getattr(self, att) != getattr(other, att): + return False + + # all attributes are equal + return True diff --git a/db/python/tables/bq/billing_gcp_daily.py b/db/python/tables/bq/billing_gcp_daily.py index b765547c3..0b691bc6f 
100644 --- a/db/python/tables/bq/billing_gcp_daily.py +++ b/db/python/tables/bq/billing_gcp_daily.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Any from google.cloud import bigquery @@ -9,7 +10,7 @@ ) from db.python.tables.bq.billing_filter import BillingFilter from db.python.tables.bq.generic_bq_filter import GenericBQFilter -from models.models import BillingTotalCostQueryModel +from models.models import BillingColumn, BillingTotalCostQueryModel class BillingGcpDailyTable(BillingBaseTable): @@ -21,8 +22,9 @@ def get_table_name(self): """Get table name""" return self.table_name + @staticmethod def _query_to_partitioned_filter( - self, query: BillingTotalCostQueryModel + query: BillingTotalCostQueryModel, ) -> BillingFilter: """ add extra filter to limit materialized view partition @@ -77,7 +79,9 @@ async def _last_loaded_day(self): return None - def _prepare_daily_cost_subquery(self, field, query_params, last_loaded_day): + def _prepare_daily_cost_subquery( + self, field: BillingColumn, query_params: list[Any], last_loaded_day: str + ): """prepare daily cost subquery""" # add extra filter to limit materialized view partition diff --git a/db/python/tables/bq/billing_raw.py b/db/python/tables/bq/billing_raw.py index a82fa4eec..dde90388d 100644 --- a/db/python/tables/bq/billing_raw.py +++ b/db/python/tables/bq/billing_raw.py @@ -16,8 +16,9 @@ def get_table_name(self): """Get table name""" return self.table_name + @staticmethod def _query_to_partitioned_filter( - self, query: BillingTotalCostQueryModel + query: BillingTotalCostQueryModel, ) -> BillingFilter: """ Raw BQ billing table is partitioned by usage_end_time diff --git a/db/python/tables/bq/function_bq_filter.py b/db/python/tables/bq/function_bq_filter.py index f18f60211..5c8ac21d2 100644 --- a/db/python/tables/bq/function_bq_filter.py +++ b/db/python/tables/bq/function_bq_filter.py @@ -27,14 +27,14 @@ class FunctionBQFilter: def __init__(self, name: str, implementation: str): 
self.func_name = name - self.fun_implementation = implementation + self.func_implementation = implementation # param_id is a counter for parameterised values self._param_id = 0 def to_sql( self, column_name: BillingColumn, - func_params: str | list[Any] | dict[Any, Any], + func_params: str | list[Any] | dict[Any, Any] | None = None, func_operator: str = None, ) -> tuple[str, list[bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter]]: """ @@ -103,7 +103,9 @@ def _sql_value_prep(key: str, value: Any) -> bigquery.ScalarQueryParameter: if isinstance(value, float): return bigquery.ScalarQueryParameter(key, 'FLOAT64', value) if isinstance(value, datetime): - return bigquery.ScalarQueryParameter(key, 'STRING', value) + return bigquery.ScalarQueryParameter( + key, 'STRING', value.isoformat(timespec='seconds') + ) # otherwise as string parameter return bigquery.ScalarQueryParameter(key, 'STRING', value) diff --git a/db/python/tables/bq/generic_bq_filter.py b/db/python/tables/bq/generic_bq_filter.py index b0bfba973..ce54ac885 100644 --- a/db/python/tables/bq/generic_bq_filter.py +++ b/db/python/tables/bq/generic_bq_filter.py @@ -12,6 +12,19 @@ class GenericBQFilter(GenericFilter[T]): Generic BigQuery filter is BQ specific filter class, based on GenericFilter """ + def __eq__(self, other): + """Equality operator""" + if not isinstance(other, GenericBQFilter): + return False + + keys = ['eq', 'in_', 'nin', 'gt', 'gte', 'lt', 'lte'] + for att in keys: + if getattr(self, att) != getattr(other, att): + return False + + # all attributes are equal + return True + def to_sql( self, column: str, column_name: str = None ) -> tuple[str, dict[str, T | list[T] | Any | list[Any]]]: @@ -38,13 +51,17 @@ def to_sql( values[k] = self._sql_value_prep(k, self.in_[0]) else: k = self.generate_field_name(_column_name + '_in') - conditionals.append(f'{column} IN ({self._sql_cond_prep(k, self.in_)})') + conditionals.append( + f'{column} IN UNNEST({self._sql_cond_prep(k, self.in_)})' + ) 
values[k] = self._sql_value_prep(k, self.in_) if self.nin is not None: if not isinstance(self.nin, list): raise ValueError('NIN filter must be a list') k = self.generate_field_name(column + '_nin') - conditionals.append(f'{column} NOT IN ({self._sql_cond_prep(k, self.nin)})') + conditionals.append( + f'{column} NOT IN UNNEST({self._sql_cond_prep(k, self.nin)})' + ) values[k] = self._sql_value_prep(k, self.nin) if self.gt is not None: k = self.generate_field_name(column + '_gt') @@ -83,9 +100,14 @@ def _sql_value_prep(key, value): Overrides the default _sql_value_prep to handle BQ parameters """ if isinstance(value, list): - return bigquery.ArrayQueryParameter( - key, 'STRING', ','.join([str(v) for v in value]) - ) + if value and isinstance(value[0], int): + return bigquery.ArrayQueryParameter(key, 'INT64', value) + if value and isinstance(value[0], float): + return bigquery.ArrayQueryParameter(key, 'FLOAT64', value) + + # otherwise all list records as string + return bigquery.ArrayQueryParameter(key, 'STRING', [str(v) for v in value]) + if isinstance(value, Enum): return GenericBQFilter._sql_value_prep(key, value.value) if isinstance(value, int): diff --git a/requirements-dev.txt b/requirements-dev.txt index e0bfc432f..554e52241 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -15,3 +15,5 @@ strawberry-graphql[debug-server]==0.206.0 functions_framework google-cloud-bigquery google-cloud-pubsub +# following required to unit test some billing functions +pytz diff --git a/test/test_api_billing.py b/test/test_api_billing.py new file mode 100644 index 000000000..4811187a7 --- /dev/null +++ b/test/test_api_billing.py @@ -0,0 +1,269 @@ +# pylint: disable=protected-access too-many-public-methods +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from unittest.mock import patch + +from api.routes import billing +from models.models import ( + BillingColumn, + BillingCostBudgetRecord, + BillingHailBatchCostRecord, + 
BillingTotalCostQueryModel, + BillingTotalCostRecord, +) + +TEST_API_BILLING_USER = 'test_user' + + +class TestApiBilling(BqTest): + """ + Test API Billing routes + Billing routes are only calling layer functions and returning data it received + This set of tests only checks if all routes code has been called + and if data returned is the same as data received from layer functions + It does not check all possible combination of parameters as + Billing Layer and BQ Tables should be testing those + """ + + def setUp(self): + super().setUp() + + # make billing enabled by default for all the calls + patcher = patch('api.routes.billing.is_billing_enabled', return_value=True) + self.mockup_is_billing_enabled = patcher.start() + + @run_as_sync + @patch('api.routes.billing.is_billing_enabled', return_value=False) + async def test_get_gcp_projects_no_billing(self, _mockup_is_billing_enabled): + """ + Test get_gcp_projects function + This function should raise ValueError if billing is not enabled + """ + with self.assertRaises(ValueError) as context: + await billing.get_gcp_projects('test_user') + + self.assertTrue('Billing is not enabled' in str(context.exception)) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_cost_by_ar_guid') + async def test_get_cost_by_ar_guid( + self, mock_get_cost_by_ar_guid, mock_get_billing_layer + ): + """ + Test get_cost_by_ar_guid function + """ + ar_guid = 'test_ar_guid' + mockup_record = BillingHailBatchCostRecord( + ar_guid=ar_guid, batch_ids=None, costs=None + ) + mock_get_billing_layer.return_value = self.layer + mock_get_cost_by_ar_guid.return_value = mockup_record + records = await billing.get_cost_by_ar_guid( + ar_guid, author=TEST_API_BILLING_USER + ) + self.assertEqual(mockup_record, records) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_cost_by_batch_id') + async def 
test_get_cost_by_batch_id( + self, mock_get_cost_by_batch_id, mock_get_billing_layer + ): + """ + Test get_cost_by_batch_id function + """ + ar_guid = 'test_ar_guid' + batch_id = 'test_batch_id' + mockup_record = BillingHailBatchCostRecord( + ar_guid=ar_guid, batch_ids=[batch_id], costs=None + ) + mock_get_billing_layer.return_value = self.layer + mock_get_cost_by_batch_id.return_value = mockup_record + records = await billing.get_cost_by_batch_id( + batch_id, author=TEST_API_BILLING_USER + ) + self.assertEqual(mockup_record, records) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_total_cost') + async def test_get_total_cost(self, mock_get_total_cost, mock_get_billing_layer): + """ + Test get_total_cost function + """ + query = BillingTotalCostQueryModel(fields=[], start_date='', end_date='') + mockup_record = [{'cost': 123.45}, {'cost': 123}] + expected = [BillingTotalCostRecord.from_json(r) for r in mockup_record] + mock_get_billing_layer.return_value = self.layer + mock_get_total_cost.return_value = mockup_record + records = await billing.get_total_cost(query, author=TEST_API_BILLING_USER) + self.assertEqual(expected, records) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_running_cost') + async def test_get_running_cost( + self, mock_get_running_cost, mock_get_billing_layer + ): + """ + Test get_running_cost function + """ + mockup_record = [ + BillingCostBudgetRecord.from_json( + {'field': 'TOPIC1', 'total_monthly': 999.99, 'details': []} + ), + BillingCostBudgetRecord.from_json( + {'field': 'TOPIC2', 'total_monthly': 123, 'details': []} + ), + ] + mock_get_billing_layer.return_value = self.layer + mock_get_running_cost.return_value = mockup_record + records = await billing.get_running_costs( + field=BillingColumn.TOPIC, + invoice_month=None, + source=None, + author=TEST_API_BILLING_USER, + ) + 
self.assertEqual(mockup_record, records) + + @patch('api.routes.billing._get_billing_layer_from') + async def call_api_function( + self, + api_function, + mock_layer_function, + mock_get_billing_layer=None, + ): + """ + Common wrapper for all API calls, to avoid code duplication + API function is called with author=TEST_API_BILLING_USER + get_author function will be tested separately + We only testing if routes are calling layer functions and + returning data it received + """ + mock_get_billing_layer.return_value = self.layer + mockup_records = ['RECORD1', 'RECORD2'] + mock_layer_function.return_value = mockup_records + records = await api_function(author=TEST_API_BILLING_USER) + self.assertEqual(mockup_records, records) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_gcp_projects') + async def test_get_gcp_projects(self, mock_get_gcp_projects): + """ + Test get_gcp_projects function + """ + await self.call_api_function(billing.get_gcp_projects, mock_get_gcp_projects) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_topics') + async def test_get_topics(self, mock_get_topics): + """ + Test get_topics function + """ + await self.call_api_function(billing.get_topics, mock_get_topics) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_cost_categories') + async def test_get_cost_categories(self, mock_get_cost_categories): + """ + Test get_cost_categories function + """ + await self.call_api_function( + billing.get_cost_categories, mock_get_cost_categories + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_skus') + async def test_get_skus(self, mock_get_skus): + """ + Test get_skus function + """ + await self.call_api_function(billing.get_skus, mock_get_skus) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_datasets') + async def test_get_datasets(self, mock_get_datasets): + """ + Test get_datasets function + """ + await self.call_api_function(billing.get_datasets, 
mock_get_datasets) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_sequencing_types') + async def test_get_sequencing_types(self, mock_get_sequencing_types): + """ + Test get_sequencing_types function + """ + await self.call_api_function( + billing.get_sequencing_types, mock_get_sequencing_types + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_stages') + async def test_get_stages(self, mock_get_stages): + """ + Test get_stages function + """ + await self.call_api_function(billing.get_stages, mock_get_stages) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_sequencing_groups') + async def test_get_sequencing_groups(self, mock_get_sequencing_groups): + """ + Test get_sequencing_groups function + """ + await self.call_api_function( + billing.get_sequencing_groups, mock_get_sequencing_groups + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_compute_categories') + async def test_get_compute_categories(self, mock_get_compute_categories): + """ + Test get_compute_categories function + """ + await self.call_api_function( + billing.get_compute_categories, mock_get_compute_categories + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_cromwell_sub_workflow_names') + async def test_get_cromwell_sub_workflow_names( + self, mock_get_cromwell_sub_workflow_names + ): + """ + Test get_cromwell_sub_workflow_names function + """ + await self.call_api_function( + billing.get_cromwell_sub_workflow_names, + mock_get_cromwell_sub_workflow_names, + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_wdl_task_names') + async def test_get_wdl_task_names(self, mock_get_wdl_task_names): + """ + Test get_wdl_task_names function + """ + await self.call_api_function( + billing.get_wdl_task_names, mock_get_wdl_task_names + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_invoice_months') + async def test_get_invoice_months(self, 
mock_get_invoice_months): + """ + Test get_invoice_months function + """ + await self.call_api_function( + billing.get_invoice_months, mock_get_invoice_months + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_namespaces') + async def test_get_namespaces(self, mock_get_namespaces): + """ + Test get_namespaces function + """ + await self.call_api_function(billing.get_namespaces, mock_get_namespaces) diff --git a/test/test_api_utils.py b/test/test_api_utils.py new file mode 100644 index 000000000..6f02a913c --- /dev/null +++ b/test/test_api_utils.py @@ -0,0 +1,77 @@ +import unittest +from datetime import date, datetime + +from api.utils.dates import ( + get_invoice_month_range, + parse_date_only_string, + reformat_datetime, +) + + +class TestApiUtils(unittest.TestCase): + """Test API utils functions""" + + def test_parse_date_only_string(self): + """ + Test parse_date_only_string function + """ + result_none = parse_date_only_string(None) + self.assertEqual(None, result_none) + + result_date = parse_date_only_string('2021-01-10') + self.assertEqual(2021, result_date.year) + self.assertEqual(1, result_date.month) + self.assertEqual(10, result_date.day) + + # test exception + invalid_date_str = '123456789' + with self.assertRaises(ValueError) as context: + parse_date_only_string(invalid_date_str) + + self.assertTrue( + f'Date could not be converted: {invalid_date_str}' in str(context.exception) + ) + + def test_get_invoice_month_range(self): + """ + Test get_invoice_month_range function + """ + jan_2021 = datetime.strptime('2021-01-10', '%Y-%m-%d').date() + res_jan_2021 = get_invoice_month_range(jan_2021) + + # there is 3 (INVOICE_DAY_DIFF) days difference between invoice month st and end + self.assertEqual( + (date(2020, 12, 29), date(2021, 2, 3)), + res_jan_2021, + ) + + dec_2021 = datetime.strptime('2021-12-10', '%Y-%m-%d').date() + res_dec_2021 = get_invoice_month_range(dec_2021) + + # there is 3 (INVOICE_DAY_DIFF) days difference between 
invoice month st and end + self.assertEqual( + (date(2021, 11, 28), date(2022, 1, 3)), + res_dec_2021, + ) + + def test_reformat_datetime(self): + """ + Test reformat_datetime function + """ + in_format = '%Y-%m-%d' + out_format = '%d/%m/%Y' + + result_none = reformat_datetime(None, in_format, out_format) + self.assertEqual(None, result_none) + + result_formatted = reformat_datetime('2021-11-09', in_format, out_format) + self.assertEqual('09/11/2021', result_formatted) + + # test exception + invalid_date_str = '123456789' + with self.assertRaises(ValueError) as context: + reformat_datetime(invalid_date_str, in_format, out_format) + + self.assertTrue( + f'Date could not be converted: {invalid_date_str}' in str(context.exception) + ) diff --git a/test/test_bq_billing_ar_batch.py b/test/test_bq_billing_ar_batch.py new file mode 100644 index 000000000..568f0f1dc --- /dev/null +++ b/test/test_bq_billing_ar_batch.py @@ -0,0 +1,188 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_ar_batch import BillingArBatchTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingArBatchTable(BqTest): + """Test BillingArBatchTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingArBatchTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = 
BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingArBatchTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingArBatchTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingArBatchTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_get_batches_by_ar_guid_no_data(self): + """Test get_batches_by_ar_guid""" + + ar_guid = '1234567890' + + # mock BigQuery result + self.bq_result.result.return_value = [] + + # test get_batches_by_ar_guid function + (start_day, end_day, batch_ids) = await self.table_obj.get_batches_by_ar_guid( + ar_guid + ) + + self.assertEqual(None, start_day) + self.assertEqual(None, end_day) + self.assertEqual([], batch_ids) + + @run_as_sync + async def test_get_batches_by_ar_guid_one_record(self): + """Test get_batches_by_ar_guid""" + + ar_guid = '1234567890' + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + + # mock BigQuery result + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + start_day=given_start_day, + 
end_day=given_end_day, + batch_id='Batch1234', + ) + ] + + # test get_batches_by_ar_guid function + (start_day, end_day, batch_ids) = await self.table_obj.get_batches_by_ar_guid( + ar_guid + ) + + self.assertEqual(given_start_day, start_day) + # end day is the last day + 1 + self.assertEqual(given_end_day + datetime.timedelta(days=1), end_day) + self.assertEqual(['Batch1234'], batch_ids) + + @run_as_sync + async def test_get_batches_by_ar_guid_two_record(self): + """Test get_batches_by_ar_guid""" + + ar_guid = '1234567890' + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + + # mock BigQuery result + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + start_day=given_start_day, + end_day=given_start_day, + batch_id='FirstBatch', + ), + mock.MagicMock( + spec=bq.Row, + start_day=given_start_day, + end_day=given_end_day, + batch_id='SecondBatch', + ), + ] + + # test get_batches_by_ar_guid function + (start_day, end_day, batch_ids) = await self.table_obj.get_batches_by_ar_guid( + ar_guid + ) + + self.assertEqual(given_start_day, start_day) + # end day is the last day + 1 + self.assertEqual(given_end_day + datetime.timedelta(days=1), end_day) + self.assertEqual(['FirstBatch', 'SecondBatch'], batch_ids) + + @run_as_sync + async def test_get_ar_guid_by_batch_id_no_data(self): + """Test get_ar_guid_by_batch_id""" + + batch_id = '1234567890' + + # mock BigQuery result + self.bq_result.result.return_value = [] + + # test get_ar_guid_by_batch_id function + ar_guid = await self.table_obj.get_ar_guid_by_batch_id(batch_id) + + self.assertEqual(None, ar_guid) + + @run_as_sync + async def test_get_ar_guid_by_batch_id_one_rec(self): + """Test get_ar_guid_by_batch_id""" + + batch_id = '1234567890' + expected_ar_guid = 'AR_GUID_1234' + + # mock BigQuery result + self.bq_result.result.return_value = [{'ar_guid': expected_ar_guid}] + + # test get_ar_guid_by_batch_id function + ar_guid = await 
self.table_obj.get_ar_guid_by_batch_id(batch_id) + + self.assertEqual(expected_ar_guid, ar_guid) diff --git a/test/test_bq_billing_base.py b/test/test_bq_billing_base.py new file mode 100644 index 000000000..655b40486 --- /dev/null +++ b/test/test_bq_billing_base.py @@ -0,0 +1,911 @@ +# pylint: disable=protected-access too-many-public-methods +from datetime import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_base import ( + BillingBaseTable, + abbrev_cost_category, + prepare_time_periods, +) +from db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from models.enums import BillingTimePeriods +from models.models import ( + BillingColumn, + BillingCostBudgetRecord, + BillingCostDetailsRecord, + BillingTotalCostQueryModel, +) + + +def mock_execute_query_running_cost(query, *_args, **_kwargs): + """ + This is a mockup function for _execute_query function + This returns one mockup BQ query result + for 2 different SQL queries used by get_running_cost API point + Those 2 queries are: + 1. query to get last loaded day + 2. 
query to get aggregated monthly/daily cost + """ + if ' as last_loaded_day' in query: + # This is the 1st query to get last loaded day + # mockup BQ query result for last_loaded_day as list of rows + return [ + mock.MagicMock(spec=bq.Row, last_loaded_day='2024-01-01 00:00:00+00:00') + ] + + # This is the 2nd query to get aggregated monthly cost + # return mockup BQ query result as list of rows + return [ + mock.MagicMock( + spec=bq.Row, + field='TOPIC1', + cost_category='Compute Engine', + daily_cost=123.45, + monthly_cost=2345.67, + ) + ] + + +def mock_execute_query_get_total_cost(_query, *_args, **_kwargs): + """ + This is a mockup function for _execute_query function + This returns one mockup BQ query result + """ + # mockup BQ query result topic cost by invoice month, return row iterator + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 3 + mock_rows.__iter__.return_value = [ + {'day': '202301', 'topic': 'TOPIC1', 'cost': 123.10}, + {'day': '202302', 'topic': 'TOPIC1', 'cost': 223.20}, + {'day': '202303', 'topic': 'TOPIC1', 'cost': 323.30}, + ] + mock_result = mock.MagicMock(spec=bq.job.QueryJob) + mock_result.result.return_value = mock_rows + return mock_result + + +class TestBillingBaseTable(BqTest): + """Test BillingBaseTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + # base is abstract, so we need to use a child class + # DailyExtended is the simplest one, almost no overrides + self.table_obj = BillingDailyExtendedTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime(2023, 1, 1, 0, 0), + lte=datetime(2024, 1, 1, 0, 0), + ), + 
topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingBaseTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_abbrev_cost_category(self): + """Test abbrev_cost_category""" + + # table name is set in the class + categories_to_expected = { + 'Cloud Storage': 'S', + 'Compute Engine': 'C', + 'Other': 'C', + } + + # test category to abreveation + for cat, expected_abrev in categories_to_expected.items(): + self.assertEqual(expected_abrev, abbrev_cost_category(cat)) + + def test_prepare_time_periods_by_day(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.DAY, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('FORMAT_DATE("%Y-%m-%d", day) as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y-%m-%d", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_prepare_time_periods_by_week(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.WEEK, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('FORMAT_DATE("%Y%W", day) as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y%W", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_prepare_time_periods_by_month(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.MONTH, + ) + + time_group = prepare_time_periods(query) + 
+ self.assertEqual('FORMAT_DATE("%Y%m", day) as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y%m", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_prepare_time_periods_by_invoice_month(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.INVOICE_MONTH, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('invoice_month as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y%m", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_filter_to_optimise_query(self): + """Test _filter_to_optimise_query""" + + result = BillingBaseTable._filter_to_optimise_query() + self.assertEqual( + 'day >= TIMESTAMP(@start_day) AND day <= TIMESTAMP(@last_day)', result + ) + + def test_last_loaded_day_filter(self): + """Test _last_loaded_day_filter""" + + result = BillingBaseTable._last_loaded_day_filter() + self.assertEqual('day = TIMESTAMP(@last_loaded_day)', result) + + def test_convert_output_empty_results(self): + """Test _convert_output - various empty results""" + + empty_results = BillingBaseTable._convert_output(None) + self.assertEqual([], empty_results) + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.total_rows = 0 + + empty_list = BillingBaseTable._convert_output(query_job_result) + self.assertEqual([], empty_list) + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.return_value = mock.MagicMock(spec=bq.table.RowIterator) + + empty_row_iterator = BillingBaseTable._convert_output(query_job_result) + self.assertEqual([], empty_row_iterator) + + def test_convert_output_one_record(self): + """Test _convert_output - one record result""" + + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [{}] + + 
query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.return_value = mock_rows + + single_row = BillingBaseTable._convert_output(query_job_result) + self.assertEqual([{}], single_row) + + def test_convert_output_label_record(self): + """Test _convert_output - test with label item""" + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [ + {'labels': [{'key': 'test_key', 'value': 'test_value'}]} + ] + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.return_value = mock_rows + + row_iterator = BillingBaseTable._convert_output(query_job_result) + self.assertEqual( + [ + { + # keep the original labels + 'labels': [{'key': 'test_key', 'value': 'test_value'}], + # append the labels as key-value pairs + 'test_key': 'test_value', + } + ], + row_iterator, + ) + + def test_prepare_order_by_string_empty(self): + """Test _prepare_order_by_string - empty results""" + + self.assertEqual('', BillingBaseTable._prepare_order_by_string(None)) + + def test_prepare_order_by_string_order_by_one_column(self): + """Test _prepare_order_by_string""" + + # DESC order by column + self.assertEqual( + 'ORDER BY cost DESC', + BillingBaseTable._prepare_order_by_string({BillingColumn.COST: True}), + ) + + # ASC order by column + self.assertEqual( + 'ORDER BY cost ASC', + BillingBaseTable._prepare_order_by_string({BillingColumn.COST: False}), + ) + + def test_prepare_order_by_string_order_by_two_columns(self): + """Test _prepare_order_by_string - order by 2 columns""" + self.assertEqual( + 'ORDER BY cost ASC,day DESC', + BillingBaseTable._prepare_order_by_string( + {BillingColumn.COST: False, BillingColumn.DAY: True} + ), + ) + + def test_prepare_aggregation_default_group_by(self): + """Test _prepare_aggregation""" + + query = BillingTotalCostQueryModel( + fields=[], start_date='2024-01-01', end_date='2024-01-01' + ) + + fields_selected, group_by = 
BillingBaseTable._prepare_aggregation(query) + # no fields selected so it is empty + self.assertEqual('', fields_selected) + # by default results are grouped by day + self.assertEqual('GROUP BY day', group_by) + + def test_prepare_aggregation_default_no_grouping_by(self): + """Test _prepare_aggregation""" + + # test when query is not grouped by + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC], + start_date='2024-01-01', + end_date='2024-01-01', + group_by=False, + ) + + fields_selected, group_by = BillingBaseTable._prepare_aggregation(query) + # topic field is selected + self.assertEqual('topic', fields_selected) + # group by is switched off + self.assertEqual('', group_by) + + def test_prepare_aggregation_default_group_by_more_columns(self): + """Test _prepare_aggregation""" + + # test when query is grouped by, but column can not be grouped by + # cost can not be grouped by, so it is not present in the result + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC, BillingColumn.COST], + start_date='2024-01-01', + end_date='2024-01-01', + group_by=True, + ) + + fields_selected, group_by = BillingBaseTable._prepare_aggregation(query) + self.assertEqual('topic', fields_selected) + # always group by day and any field that can be grouped by + self.assertEqual('GROUP BY day,topic', group_by) + + def test_execute_query_results_as_list(self): + """Test _execute_query""" + + # we are not running SQL against real BQ, just a mocking, so we can use any query + sql_query = 'SELECT 1;' + sql_params: list[Any] = [] + + # test results_as_list=True + given_bq_results = [[], [123], ['a', 'b', 'c']] + for bq_result in given_bq_results: + self.bq_result.result.return_value = bq_result + results = self.table_obj._execute_query( + sql_query, sql_params, results_as_list=True + ) + self.assertEqual(bq_result, results) + + def test_execute_query_results_not_as_list(self): + """Test _execute_query""" + + # we are not running SQL against real BQ, just a 
mocking, so we can use any query + sql_query = 'SELECT 1;' + sql_params: list[Any] = [] + + # now test results_as_list=False + given_bq_results = [[], [123], ['a', 'b', 'c']] + for bq_result in given_bq_results: + # mock BigQuery result + self.bq_client.query.return_value = bq_result + results = self.table_obj._execute_query( + sql_query, sql_params, results_as_list=False + ) + self.assertEqual(bq_result, results) + + def test_execute_query_with_sql_params(self): + """Test _execute_query""" + + # now test results_as_list=False and with some dummy params + sql_query = 'SELECT 1;' + sql_params = [ + bq.ScalarQueryParameter('dummy_not_used', 'STRING', '2021-01-01 00:00:00') + ] + + given_bq_results = [[], [123], ['a', 'b', 'c']] + for bq_result in given_bq_results: + # mock BigQuery result + self.bq_client.query.return_value = bq_result + results = self.table_obj._execute_query( + sql_query, sql_params, results_as_list=False + ) + self.assertEqual(bq_result, results) + + @run_as_sync + async def test_append_total_running_cost_no_topic(self): + """Test _append_total_running_cost""" + + # test _append_total_running_cost function, no topic present + total_record = await BillingBaseTable._append_total_running_cost( + field=BillingColumn.TOPIC, + is_current_month=True, + last_loaded_day=None, + total_monthly={'C': {'ALL': 1000}, 'S': {'ALL': 2000}}, + total_daily={'C': {'ALL': 100}, 'S': {'ALL': 200}}, + total_monthly_category={}, + total_daily_category={}, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=3000.0, + total_daily=300.0, + compute_monthly=1000.0, + compute_daily=100.0, + storage_monthly=2000.0, + storage_daily=200.0, + details=[], + budget_spent=None, + budget=None, + last_loaded_day=None, + ) + ], + total_record, + ) + + @run_as_sync + async def test_append_total_running_cost_not_current_month(self): + """Test _append_total_running_cost""" + + # test _append_total_running_cost function, not current 
month + total_record = await BillingBaseTable._append_total_running_cost( + field=BillingColumn.TOPIC, + is_current_month=False, + last_loaded_day=None, + total_monthly={'C': {'ALL': 1000}, 'S': {'ALL': 2000}}, + total_daily=None, + total_monthly_category={}, + total_daily_category={}, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=3000.0, + total_daily=None, + compute_monthly=1000.0, + compute_daily=None, + storage_monthly=2000.0, + storage_daily=None, + details=[], + budget_spent=None, + budget=None, + last_loaded_day=None, + ) + ], + total_record, + ) + + @run_as_sync + async def test_append_total_running_cost_current_month(self): + """Test _append_total_running_cost""" + + total_record = await BillingBaseTable._append_total_running_cost( + field=BillingColumn.TOPIC, + is_current_month=True, + last_loaded_day=None, + total_monthly={'C': {'ALL': 1000}, 'S': {'ALL': 2000}}, + total_daily={'C': {'ALL': 100}, 'S': {'ALL': 200}}, + total_monthly_category={ + 'Compute Engine': 900, + 'Cloud Storage': 2000, + 'Other': 100, + }, + total_daily_category={ + 'Compute Engine': 90, + 'Cloud Storage': 200, + 'Other': 10, + }, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=3000.0, + total_daily=300.0, + compute_monthly=1000.0, + compute_daily=100.0, + storage_monthly=2000.0, + storage_daily=200.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=90.0, + monthly_cost=900.0, + ), + BillingCostDetailsRecord( + cost_group='S', + cost_category='Cloud Storage', + daily_cost=200.0, + monthly_cost=2000.0, + ), + BillingCostDetailsRecord( + cost_group='C', + cost_category='Other', + daily_cost=10.0, + monthly_cost=100.0, + ), + ], + budget_spent=None, + budget=None, + last_loaded_day=None, + ) + ], + total_record, + ) + + @run_as_sync + async def test_budgets_by_gcp_project_empty_results(self): + """Test 
_budgets_by_gcp_project""" + + # Only GCP_PROJECT and current month has budget + empty_result = await self.table_obj._budgets_by_gcp_project( + BillingColumn.TOPIC, False + ) + self.assertEqual({}, empty_result) + + # GCP_PROJECT and current month, but BQ mockup setup as empty values + empty_result = await self.table_obj._budgets_by_gcp_project( + BillingColumn.GCP_PROJECT, True + ) + self.assertEqual({}, empty_result) + + @run_as_sync + async def test_budgets_by_gcp_project_with_results(self): + """Test _budgets_by_gcp_project""" + + # GCP_PROJECT and current month and Mockup set as 2 records + self.bq_result.result.return_value = [ + mock.MagicMock(spec=bq.Row, gcp_project='Project1', budget=1000.0), + mock.MagicMock(spec=bq.Row, gcp_project='Project2', budget=2000.0), + ] + + non_empty_result = await self.table_obj._budgets_by_gcp_project( + BillingColumn.GCP_PROJECT, True + ) + self.assertDictEqual({'Project1': 1000.0, 'Project2': 2000.0}, non_empty_result) + + @run_as_sync + async def test_execute_running_cost_query_invalid_months(self): + """Test _execute_running_cost_query""" + + # test invalid inputs + with self.assertRaises(ValueError) as context: + await self.table_obj._execute_running_cost_query(BillingColumn.TOPIC, None) + + self.assertTrue('Invalid invoice month' in str(context.exception)) + + with self.assertRaises(ValueError) as context: + await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, '12345678' + ) + + self.assertTrue('Invalid invoice month' in str(context.exception)) + + with self.assertRaises(ValueError) as context: + await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, '1024AA' + ) + self.assertTrue('Invalid invoice month' in str(context.exception)) + + @run_as_sync + async def test_execute_running_cost_query_empty_results_old_month(self): + """Test _execute_running_cost_query""" + + # no mocked BQ results, should return as empty + ( + is_current_month, + last_loaded_day, + query_job_result, + ) = 
await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, '202101' + ) + + self.assertEqual(False, is_current_month) + self.assertEqual(None, last_loaded_day) + self.assertEqual([], query_job_result) + + @run_as_sync + async def test_execute_running_cost_query_empty_results_current_month(self): + """Test _execute_running_cost_query""" + + # no mocked BQ results, should return as empty + # use current month to test the current month branch + current_month_as_string = datetime.now().strftime('%Y%m') + ( + is_current_month, + last_loaded_day, + query_job_result, + ) = await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, current_month_as_string + ) + + self.assertEqual(True, is_current_month) + self.assertEqual(None, last_loaded_day) + self.assertEqual([], query_job_result) + + @run_as_sync + async def test_append_running_cost_records_empty_results(self): + """Test _append_running_cost_records""" + + # test empty results + empty_results = await self.table_obj._append_running_cost_records( + field=BillingColumn.TOPIC, + is_current_month=False, + last_loaded_day=None, + total_monthly={}, + total_daily={}, + field_details={}, + results=[], + ) + + self.assertEqual([], empty_results) + + @run_as_sync + async def test_append_running_cost_records_simple_data(self): + """Test _append_running_cost_records""" + + # prepare simple input data + field_details: dict[str, Any] = { + 'Project1': [], + } + + simple_result = await self.table_obj._append_running_cost_records( + field=BillingColumn.GCP_PROJECT, + is_current_month=False, + last_loaded_day=None, + total_monthly={'C': {}, 'S': {}}, + total_daily={'C': {}, 'S': {}}, + field_details=field_details, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='Project1', + total_monthly=0.0, + compute_monthly=0.0, + compute_daily=0.0, + storage_monthly=0.0, + storage_daily=0.0, + details=[], + last_loaded_day=None, + total_daily=None, + budget_spent=None, + budget=None, + ) + 
], + simple_result, + ) + + @run_as_sync + async def test_append_running_cost_records_with_details(self): + """Test _append_running_cost_records""" + + # prepare input data with more details + field_details = { + 'Project2': [ + { + 'cost_group': 'C', + 'cost_category': 'Compute Engine', + 'daily_cost': 90.0, + 'monthly_cost': 900.0, + } + ], + } + + detailed_result = await self.table_obj._append_running_cost_records( + field=BillingColumn.GCP_PROJECT, + is_current_month=False, + last_loaded_day=None, + total_monthly={'C': {}, 'S': {}}, + total_daily={'C': {}, 'S': {}}, + field_details=field_details, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='Project2', + total_monthly=0.0, + compute_monthly=0.0, + compute_daily=0.0, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=90.0, + monthly_cost=900.0, + ) + ], + last_loaded_day=None, + total_daily=None, + budget_spent=None, + budget=None, + ) + ], + detailed_result, + ) + + @run_as_sync + async def test_get_running_cost_invalid_input(self): + """Test get_running_cost""" + + # test invalid outputs + with self.assertRaises(ValueError) as context: + await self.table_obj.get_running_cost( + # not allowed field + field=BillingColumn.SKU, + invoice_month=None, + ) + + self.assertTrue( + ( + 'Invalid field only topic, dataset, gcp-project, compute_category, ' + 'wdl_task_name, cromwell_sub_workflow_name & namespace are allowed' + ) + in str(context.exception) + ) + + @run_as_sync + async def test_get_running_cost_empty_results(self): + """Test get_running_cost""" + + # test empty cost (no BQ mockup data provided) + empty_results = await self.table_obj.get_running_cost( + field=BillingColumn.TOPIC, + invoice_month='202301', + ) + + self.assertEqual([], empty_results) + + @run_as_sync + async def test_get_running_cost_older_month(self): + """Test get_running_cost""" + + # mockup BQ sql query result 
for _execute_running_cost_query function + self.table_obj._execute_query = mock.MagicMock( + side_effect=mock_execute_query_running_cost + ) + + one_record_result = await self.table_obj.get_running_cost( + field=BillingColumn.TOPIC, + invoice_month='202301', + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=2345.67, + total_daily=None, + compute_monthly=2345.67, + compute_daily=None, + storage_monthly=0.0, + storage_daily=None, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=None, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day=None, + ), + BillingCostBudgetRecord( + field='TOPIC1', + total_monthly=2345.67, + total_daily=None, + compute_monthly=2345.67, + compute_daily=0.0, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=None, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day=None, + ), + ], + one_record_result, + ) + + @run_as_sync + async def test_get_running_cost_current_month(self): + """Test get_running_cost""" + + # mockup BQ sql query result for _execute_running_cost_query function + self.table_obj._execute_query = mock.MagicMock( + side_effect=mock_execute_query_running_cost + ) + + # use the current month to test the current month branch + current_month_as_string = datetime.now().strftime('%Y%m') + + current_month_result = await self.table_obj.get_running_cost( + field=BillingColumn.TOPIC, + invoice_month=current_month_as_string, + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=2345.67, + total_daily=123.45, + compute_monthly=2345.67, + compute_daily=123.45, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=123.45, + monthly_cost=2345.67, + 
) + ], + budget_spent=None, + budget=None, + last_loaded_day='Jan 01', + ), + BillingCostBudgetRecord( + field='TOPIC1', + total_monthly=2345.67, + total_daily=123.45, + compute_monthly=2345.67, + compute_daily=123.45, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=123.45, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day='Jan 01', + ), + ], + current_month_result, + ) + + @run_as_sync + async def test_get_total_cost(self): + """Test get_total_cost""" + + # test invalid input + query = BillingTotalCostQueryModel( + fields=[], start_date='2023-01-01', end_date='2024-01-01' + ) + + with self.assertRaises(ValueError) as context: + await self.table_obj.get_total_cost(query) + + self.assertTrue('Date and Fields are required' in str(context.exception)) + + # test empty results + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC], + start_date='2023-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.INVOICE_MONTH, + ) + + # no BQ mockup data setup, returns empty list + empty_results = await self.table_obj.get_total_cost(query) + self.assertEqual([], empty_results) + + # mockup BQ sql query result for _execute_query to return 3 records. 
+ # implementation is inside mock_execute_query function + self.table_obj._execute_query = mock.MagicMock( + side_effect=mock_execute_query_get_total_cost + ) + results = await self.table_obj.get_total_cost(query) + self.assertEqual( + [ + {'day': '202301', 'topic': 'TOPIC1', 'cost': 123.1}, + {'day': '202302', 'topic': 'TOPIC1', 'cost': 223.2}, + {'day': '202303', 'topic': 'TOPIC1', 'cost': 323.3}, + ], + results, + ) diff --git a/test/test_bq_billing_daily.py b/test/test_bq_billing_daily.py new file mode 100644 index 000000000..969b7c936 --- /dev/null +++ b/test/test_bq_billing_daily.py @@ -0,0 +1,255 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_daily import BillingDailyTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingDailyTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingDailyTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + 
filters=filters, + ) + filter_ = BillingDailyTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingDailyTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingDailyTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_last_loaded_day_return_valid_day(self): + """Test _last_loaded_day""" + + given_last_day = '2021-01-01 00:00:00' + + # mock BigQuery result + + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + last_loaded_day=datetime.datetime.strptime( + given_last_day, '%Y-%m-%d %H:%M:%S' + ), + ) + ] # 2021-01-01 + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(given_last_day, last_loaded_day) + + @run_as_sync + async def test_last_loaded_day_return_none(self): + """Test _last_loaded_day as None""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(None, last_loaded_day) + + def test_prepare_daily_cost_subquery(self): + """Test _prepare_daily_cost_subquery""" + + self.table_obj.table_name = 'TEST_TABLE_NAME' + + # given + given_field = BillingColumn.COST + given_query_params: list[Any] = [] + given_last_loaded_day = '2021-01-01 00:00:00' + + ( + query_params, + daily_cost_field, + daily_cost_join, + ) = 
self.table_obj._prepare_daily_cost_subquery( + given_field, given_query_params, given_last_loaded_day + ) + + # expected + expected_daily_cost_join = """LEFT JOIN ( + SELECT + cost as field, + cost_category, + SUM(cost) as cost + FROM + `TEST_TABLE_NAME` + WHERE day = TIMESTAMP(@last_loaded_day) + GROUP BY + field, + cost_category + ) day + ON month.field = day.field + AND month.cost_category = day.cost_category + """ + + self.assertEqual( + [ + bq.ScalarQueryParameter( + 'last_loaded_day', 'STRING', '2021-01-01 00:00:00' + ) + ], + query_params, + ) + + self.assertEqual(', day.cost as daily_cost', daily_cost_field) + self.assertEqual(expected_daily_cost_join, daily_cost_join) + + @run_as_sync + async def test_get_entities_as_empty_list(self): + """ + Test get_topics, get_invoice_months, + get_cost_categories and get_skus as empty list + """ + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get_topics function + records = await self.table_obj.get_topics() + self.assertEqual([], records) + + # test get_invoice_months function + records = await self.table_obj.get_invoice_months() + self.assertEqual([], records) + + # test get_cost_categories function + records = await self.table_obj.get_cost_categories() + self.assertEqual([], records) + + # test get_skus function + records = await self.table_obj.get_skus() + self.assertEqual([], records) + + @run_as_sync + async def test_get_topics_return_valid_list(self): + """Test get_topics as empty list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'topic': 'TOPIC1'}, + {'topic': 'TOPIC2'}, + ] + + # test get_topics function + records = await self.table_obj.get_topics() + + self.assertEqual(['TOPIC1', 'TOPIC2'], records) + + @run_as_sync + async def test_get_invoice_months_return_valid_list(self): + """Test get_invoice_months as empty list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + 
{'invoice_month': '202401'}, + {'invoice_month': '202402'}, + ] + + # test get_invoice_months function + records = await self.table_obj.get_invoice_months() + + self.assertEqual(['202401', '202402'], records) + + @run_as_sync + async def test_get_cost_categories_return_valid_list(self): + """Test get_cost_categories as empty list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'cost_category': 'CAT1'}, + ] + + # test get_cost_categories function + records = await self.table_obj.get_cost_categories() + + self.assertEqual(['CAT1'], records) + + @run_as_sync + async def test_get_skus_return_valid_list(self): + """Test get_skus as empty list""" + + # mock BigQuery result as list of 3 records + self.bq_result.result.return_value = [ + {'sku': 'SKU1'}, + {'sku': 'SKU2'}, + {'sku': 'SKU3'}, + ] + + # test get_skus function + records = await self.table_obj.get_skus() + self.assertEqual(['SKU1', 'SKU2', 'SKU3'], records) + + # test get_skus function with limit, + # limit is ignored in the test as we already have mockup data + records = await self.table_obj.get_skus(limit=3) + self.assertEqual(['SKU1', 'SKU2', 'SKU3'], records) + + # test get_skus function with limit & offset + # limit & offset are ignored in the test as we already have mockup data + records = await self.table_obj.get_skus(limit=3, offset=1) + self.assertEqual(['SKU1', 'SKU2', 'SKU3'], records) diff --git a/test/test_bq_billing_daily_extended.py b/test/test_bq_billing_daily_extended.py new file mode 100644 index 000000000..183d6c8c5 --- /dev/null +++ b/test/test_bq_billing_daily_extended.py @@ -0,0 +1,116 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any + +from db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import 
GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingGcpDailyTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingDailyExtendedTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingDailyExtendedTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingDailyExtendedTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingDailyExtendedTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_get_extended_values_return_empty_list(self): + """Test get_extended_values as empty list""" + + # mock BigQuery result as empty list + 
self.bq_result.result.return_value = [] + + # test get table name function + records = await self.table_obj.get_extended_values('dataset') + + self.assertEqual([], records) + + @run_as_sync + async def test_get_extended_values_return_valid_list(self): + """Test get_extended_values as list of 2 records""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'dataset': 'DATA1'}, + {'dataset': 'DATA2'}, + ] + + # test get table name function + records = await self.table_obj.get_extended_values('dataset') + + self.assertEqual(['DATA1', 'DATA2'], records) + + @run_as_sync + async def test_get_extended_values_error(self): + """Test get_extended_values return exception as invalid ext column""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + with self.assertRaises(ValueError) as context: + await self.table_obj.get_extended_values('rubish') + + self.assertTrue('Invalid field value' in str(context.exception)) diff --git a/test/test_bq_billing_gcp_daily.py b/test/test_bq_billing_gcp_daily.py new file mode 100644 index 000000000..99e3caa10 --- /dev/null +++ b/test/test_bq_billing_gcp_daily.py @@ -0,0 +1,197 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from textwrap import dedent +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.billing_gcp_daily import BillingGcpDailyTable +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingGcpDailyTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingGcpDailyTable(self.connection) 
+ + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + part_time=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 8, 0, 0), # 7 days added + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingGcpDailyTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingGcpDailyTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingGcpDailyTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_last_loaded_day_return_valid_day(self): + """Test _last_loaded_day""" + + given_last_day = '2021-01-01 00:00:00' + + # mock BigQuery result + + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + last_loaded_day=datetime.datetime.strptime( + given_last_day, '%Y-%m-%d %H:%M:%S' + ), + ) + ] # 2021-01-01 + + # test get table name 
function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(given_last_day, last_loaded_day) + + @run_as_sync + async def test_last_loaded_day_return_none(self): + """Test _last_loaded_day as None""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(None, last_loaded_day) + + def test_prepare_daily_cost_subquery(self): + """Test _prepare_daily_cost_subquery""" + + self.table_obj.table_name = 'TEST_TABLE_NAME' + + # given + given_field = BillingColumn.COST + given_query_params: list[Any] = [] + given_last_loaded_day = '2021-01-01 00:00:00' + + ( + query_params, + daily_cost_field, + daily_cost_join, + ) = self.table_obj._prepare_daily_cost_subquery( + given_field, given_query_params, given_last_loaded_day + ) + + # expected + expected_daily_cost_join = """LEFT JOIN ( + SELECT + cost as field, + cost_category, + SUM(cost) as cost + FROM + `TEST_TABLE_NAME` + WHERE day = TIMESTAMP(@last_loaded_day) + + AND part_time >= TIMESTAMP(@last_loaded_day) + AND part_time <= TIMESTAMP_ADD( + TIMESTAMP(@last_loaded_day), INTERVAL 7 DAY + ) + + GROUP BY + field, + cost_category + ) day + ON month.field = day.field + AND month.cost_category = day.cost_category + """ + + self.assertEqual( + [ + bq.ScalarQueryParameter( + 'last_loaded_day', 'STRING', '2021-01-01 00:00:00' + ) + ], + query_params, + ) + self.assertEqual(', day.cost as daily_cost', daily_cost_field) + self.assertEqual(dedent(expected_daily_cost_join), dedent(daily_cost_join)) + + @run_as_sync + async def test_get_gcp_projects_return_empty_list(self): + """Test get_gcp_projects as empty list""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + gcp_projects = await self.table_obj.get_gcp_projects() + + self.assertEqual([], gcp_projects) + + @run_as_sync + async def 
test_get_gcp_projects_return_valid_list(self): + """Test get_gcp_projects as empty list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'gcp_project': 'PROJECT1'}, + {'gcp_project': 'PROJECT2'}, + ] + + # test get table name function + gcp_projects = await self.table_obj.get_gcp_projects() + + self.assertEqual(['PROJECT1', 'PROJECT2'], gcp_projects) diff --git a/test/test_bq_billing_raw.py b/test/test_bq_billing_raw.py new file mode 100644 index 000000000..e1a5af50d --- /dev/null +++ b/test/test_bq_billing_raw.py @@ -0,0 +1,75 @@ +# pylint: disable=protected-access +import datetime +from test.testbqbase import BqTest +from typing import Any + +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.billing_raw import BillingRawTable +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingRawTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingRawTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + usage_end_time=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingRawTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare 
them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingRawTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingRawTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_BQ_AGGREG_RAW' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) diff --git a/test/test_bq_function_filter.py b/test/test_bq_function_filter.py new file mode 100644 index 000000000..8c84f9af9 --- /dev/null +++ b/test/test_bq_function_filter.py @@ -0,0 +1,155 @@ +import unittest +from datetime import datetime +from enum import Enum + +import google.cloud.bigquery as bq +import pytz + +from db.python.tables.bq.function_bq_filter import FunctionBQFilter +from models.models import BillingColumn + + +class BGFunFilterTestEnum(str, Enum): + """Simple Enum classs""" + + ID = 'id' + VALUE = 'value' + + +class TestFunctionBQFilter(unittest.TestCase): + """Test function filters SQL generation""" + + def test_empty_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql(BillingColumn.LABELS) + + self.assertEqual([], filter_.func_sql_parameters) + self.assertEqual('', filter_.func_where) + + def test_one_param_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should 
return empty SQL and Param list + filter_.to_sql(BillingColumn.LABELS, {'label_name': 'Some Value'}) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name'), + bq.ScalarQueryParameter('value1', 'STRING', 'Some Value'), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + '(test_sql_func(labels,@param1) = @value1)', filter_.func_where + ) + + def test_two_param_default_operator_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql(BillingColumn.LABELS, {'label_name1': 10, 'label_name2': 123.45}) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name1'), + bq.ScalarQueryParameter('value1', 'INT64', 10), + bq.ScalarQueryParameter('param2', 'STRING', 'label_name2'), + bq.ScalarQueryParameter('value2', 'FLOAT64', 123.45), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + ( + '(test_sql_func(labels,@param1) = @value1 ' + 'AND test_sql_func(labels,@param2) = @value2)' + ), + filter_.func_where, + ) + + def test_two_param_operator_or_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql( + BillingColumn.LABELS, + { + 'label_name1': BGFunFilterTestEnum.ID, + 'label_name2': datetime(2024, 1, 1), + }, + 'OR', + ) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name1'), + bq.ScalarQueryParameter('value1', 'STRING', 'id'), + bq.ScalarQueryParameter('param2', 'STRING', 'label_name2'), + bq.ScalarQueryParameter('value2', 'STRING', '2024-01-01T00:00:00'), + ], + filter_.func_sql_parameters, + ) + 
self.assertEqual( + ( + '(test_sql_func(labels,@param1) = @value1 ' + 'OR test_sql_func(labels,@param2) = TIMESTAMP(@value2))' + ), + filter_.func_where, + ) + + def test_two_param_datime_with_zone_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + NYC = pytz.timezone('America/New_York') + SYD = pytz.timezone('Australia/Sydney') + filter_.to_sql( + BillingColumn.LABELS, + { + 'label_name1': datetime(2024, 1, 1, tzinfo=pytz.UTC).astimezone(NYC), + 'label_name2': datetime(2024, 1, 1, tzinfo=pytz.UTC).astimezone(SYD), + }, + 'AND', + ) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name1'), + bq.ScalarQueryParameter( + 'value1', 'STRING', '2023-12-31T19:00:00-05:00' + ), + bq.ScalarQueryParameter('param2', 'STRING', 'label_name2'), + bq.ScalarQueryParameter( + 'value2', 'STRING', '2024-01-01T11:00:00+11:00' + ), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + ( + '(test_sql_func(labels,@param1) = TIMESTAMP(@value1) ' + 'AND test_sql_func(labels,@param2) = TIMESTAMP(@value2))' + ), + filter_.func_where, + ) diff --git a/test/test_bq_generic_filter.py b/test/test_bq_generic_filter.py new file mode 100644 index 000000000..5e5600d1f --- /dev/null +++ b/test/test_bq_generic_filter.py @@ -0,0 +1,270 @@ +import dataclasses +import unittest +from datetime import datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel + + +@dataclasses.dataclass(kw_only=True) +class GenericBQFilterTest(GenericBQFilterModel): + """Test model for GenericBQFilter""" + + test_string: GenericBQFilter[str] | None = None + test_int: GenericBQFilter[int] 
| None = None + test_float: GenericBQFilter[float] | None = None + test_dt: GenericBQFilter[datetime] | None = None + test_dict: dict[str, GenericBQFilter[str]] | None = None + test_enum: GenericBQFilter[Enum] | None = None + test_any: Any | None = None + + +class BGFilterTestEnum(str, Enum): + """Simple Enum class""" + + ID = 'id' + VALUE = 'value' + + +class TestGenericBQFilter(unittest.TestCase): + """Test generic filter SQL generation""" + + def test_basic_no_override(self): + """Test that the basic filter converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_eq', sql) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_basic_override(self): + """Test that the basic filter with an override converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql({'test_string': 't.string'}) + + self.assertEqual('t.string = @t_string_eq', sql) + self.assertDictEqual( + { + 't_string_eq': bigquery.ScalarQueryParameter( + 't_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_string(self): + """ + Test that a single value filtered using the "in" operator + gets converted to an eq operation + """ + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=['test'])) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_in_eq', sql) + self.assertDictEqual( + { + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_int(self): + """ + Test that values filtered using the "gt" operator convert as expected + """ + value = 123 + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(gt=value)) + sql, values = filter_.to_sql() + + 
self.assertEqual('test_int > @test_int_gt', sql) + self.assertDictEqual( + { + 'test_int_gt': bigquery.ScalarQueryParameter( + 'test_int_gt', 'INT64', value + ) + }, + values, + ) + + def test_single_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123.456 + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(gte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float >= @test_float_gte', sql) + self.assertDictEqual( + { + 'test_float_gte': bigquery.ScalarQueryParameter( + 'test_float_gte', 'FLOAT64', value + ) + }, + values, + ) + + def test_single_datetime(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + datetime_str = '2021-10-08 01:02:03' + value = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S') + filter_ = GenericBQFilterTest(test_dt=GenericBQFilter(lt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_dt < TIMESTAMP(@test_dt_lt)', sql) + self.assertDictEqual( + { + 'test_dt_lt': bigquery.ScalarQueryParameter( + 'test_dt_lt', 'STRING', datetime_str + ) + }, + values, + ) + + def test_single_enum(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = BGFilterTestEnum.ID + filter_ = GenericBQFilterTest(test_enum=GenericBQFilter(lte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_enum <= @test_enum_lte', sql) + self.assertDictEqual( + { + 'test_enum_lte': bigquery.ScalarQueryParameter( + 'test_enum_lte', 'STRING', value.value + ) + }, + values, + ) + + def test_in_multiple_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1, 2] + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int IN UNNEST(@test_int_in)', sql) + self.assertDictEqual( + { + 'test_int_in': bigquery.ArrayQueryParameter( + 'test_int_in', 'INT64', value + ) + 
}, + values, + ) + + def test_in_multiple_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1.0, 2.0] + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float IN UNNEST(@test_float_in)', sql) + self.assertDictEqual( + { + 'test_float_in': bigquery.ArrayQueryParameter( + 'test_float_in', 'FLOAT64', value + ) + }, + values, + ) + + def test_in_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string IN UNNEST(@test_string_in)', sql) + self.assertDictEqual( + { + 'test_string_in': bigquery.ArrayQueryParameter( + 'test_string_in', 'STRING', value + ) + }, + values, + ) + + def test_nin_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string NOT IN UNNEST(@test_string_nin)', sql) + self.assertDictEqual( + { + 'test_string_nin': bigquery.ArrayQueryParameter( + 'test_string_nin', 'STRING', value + ) + }, + values, + ) + + def test_in_and_eq_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value, eq='B')) + sql, values = filter_.to_sql() + + self.assertEqual( + 'test_string = @test_string_eq AND test_string = @test_string_in_eq', + sql, + ) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'B' + ), + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'A' + ), + }, + values, + ) + + def 
test_fail_none_in_tuple(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = (None,) + + # check if ValueError is raised + with self.assertRaises(ValueError) as context: + filter_ = GenericBQFilterTest(test_any=value) + filter_.to_sql() + + self.assertTrue( + 'There is very likely a trailing comma on the end of ' + 'GenericBQFilterTest.test_any. ' + 'If you actually want a tuple of length one with the value = (None,), ' + 'then use dataclasses.field(default_factory=lambda: (None,))' + in str(context.exception) + ) diff --git a/test/test_generic_filters.py b/test/test_generic_filters.py index 2c1348076..b5598be54 100644 --- a/test/test_generic_filters.py +++ b/test/test_generic_filters.py @@ -53,3 +53,54 @@ def test_in_multiple(self): self.assertEqual('test_int IN :test_int_in', sql) self.assertDictEqual({'test_int_in': value}, values) + + def test_gt_single(self): + """ + Test that a single value filtered using the "gt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > :test_int_gt', sql) + self.assertDictEqual({'test_int_gt': 123}, values) + + def test_gte_single(self): + """ + Test that a single value filtered using the "gte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int >= :test_int_gte', sql) + self.assertDictEqual({'test_int_gte': 123}, values) + + def test_lt_single(self): + """ + Test that a single value filtered using the "lt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int < :test_int_lt', sql) + self.assertDictEqual({'test_int_lt': 123}, values) + + def test_lte_single(self): + """ + Test that a single value filtered using the "lte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lte=123)) + sql, values = 
filter_.to_sql() + + self.assertEqual('test_int <= :test_int_lte', sql) + self.assertDictEqual({'test_int_lte': 123}, values) + + def test_not_in_multiple(self): + """ + Test that values filtered using the "nin" operator convert as expected + """ + value = [1, 2] + filter_ = GenericFilterTest(test_int=GenericFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int NOT IN :test_int_nin', sql) + self.assertDictEqual({'test_int_nin': value}, values) diff --git a/test/test_layers_billing.py b/test/test_layers_billing.py new file mode 100644 index 000000000..be13f3fa3 --- /dev/null +++ b/test/test_layers_billing.py @@ -0,0 +1,470 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.layers.billing import BillingLayer +from models.enums import BillingSource +from models.models import ( + BillingColumn, + BillingHailBatchCostRecord, + BillingTotalCostQueryModel, +) + + +class TestBillingLayer(BqTest): + """Test BillingLayer and its methods""" + + def test_table_factory(self): + """Test table_factory""" + + layer = BillingLayer(self.connection) + + # test BillingSource types + table_obj = layer.table_factory() + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory(source=BillingSource.GCP_BILLING) + self.assertEqual('BillingGcpDailyTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory(source=BillingSource.RAW) + self.assertEqual('BillingRawTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory(source=BillingSource.AGGREGATE) + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + # base columns + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, fields=[BillingColumn.TOPIC] + ) + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + table_obj = 
layer.table_factory( + source=BillingSource.AGGREGATE, filters={BillingColumn.TOPIC: 'TOPIC1'} + ) + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + # columns from extended view + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, fields=[BillingColumn.AR_GUID] + ) + self.assertEqual('BillingDailyExtendedTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, filters={BillingColumn.AR_GUID: 'AR_GUID1'} + ) + self.assertEqual('BillingDailyExtendedTable', table_obj.__class__.__name__) + + @run_as_sync + async def test_get_gcp_projects(self): + """Test get_gcp_projects""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_gcp_projects() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'gcp_project': 'PROJECT1'}, + {'gcp_project': 'PROJECT2'}, + ] + + records = await layer.get_gcp_projects() + self.assertEqual(['PROJECT1', 'PROJECT2'], records) + + @run_as_sync + async def test_get_topics(self): + """Test get_topics""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_topics() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'topic': 'TOPIC1'}, + {'topic': 'TOPIC2'}, + ] + + records = await layer.get_topics() + self.assertEqual(['TOPIC1', 'TOPIC2'], records) + + @run_as_sync + async def test_get_cost_categories(self): + """Test get_cost_categories""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_cost_categories() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'cost_category': 'CAT1'}, + {'cost_category': 'CAT2'}, + ] + + records = await layer.get_cost_categories() + self.assertEqual(['CAT1', 'CAT2'], records) + + 
@run_as_sync + async def test_get_skus(self): + """Test get_skus""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_skus() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'sku': 'SKU1'}, + {'sku': 'SKU2'}, + ] + + records = await layer.get_skus() + self.assertEqual(['SKU1', 'SKU2'], records) + + @run_as_sync + async def test_get_datasets(self): + """Test get_datasets""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_datasets() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'dataset': 'DATA1'}, + {'dataset': 'DATA2'}, + ] + + records = await layer.get_datasets() + self.assertEqual(['DATA1', 'DATA2'], records) + + @run_as_sync + async def test_get_stages(self): + """Test get_stages""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_stages() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'stage': 'STAGE1'}, + {'stage': 'STAGE2'}, + ] + + records = await layer.get_stages() + self.assertEqual(['STAGE1', 'STAGE2'], records) + + @run_as_sync + async def test_get_sequencing_types(self): + """Test get_sequencing_types""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_sequencing_types() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'sequencing_type': 'SEQ1'}, + {'sequencing_type': 'SEQ2'}, + ] + + records = await layer.get_sequencing_types() + self.assertEqual(['SEQ1', 'SEQ2'], records) + + @run_as_sync + async def test_get_sequencing_groups(self): + """Test get_sequencing_groups""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await 
layer.get_sequencing_groups() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'sequencing_group': 'GRP1'}, + {'sequencing_group': 'GRP2'}, + ] + + records = await layer.get_sequencing_groups() + self.assertEqual(['GRP1', 'GRP2'], records) + + @run_as_sync + async def test_get_compute_categories(self): + """Test get_compute_categories""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_compute_categories() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'compute_category': 'CAT1'}, + {'compute_category': 'CAT2'}, + ] + + records = await layer.get_compute_categories() + self.assertEqual(['CAT1', 'CAT2'], records) + + @run_as_sync + async def test_get_cromwell_sub_workflow_names(self): + """Test get_cromwell_sub_workflow_names""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_cromwell_sub_workflow_names() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'cromwell_sub_workflow_name': 'CROM1'}, + {'cromwell_sub_workflow_name': 'CROM2'}, + ] + + records = await layer.get_cromwell_sub_workflow_names() + self.assertEqual(['CROM1', 'CROM2'], records) + + @run_as_sync + async def test_get_wdl_task_names(self): + """Test get_wdl_task_names""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, should be empty + records = await layer.get_wdl_task_names() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'wdl_task_name': 'WDL1'}, + {'wdl_task_name': 'WDL2'}, + ] + + records = await layer.get_wdl_task_names() + self.assertEqual(['WDL1', 'WDL2'], records) + + @run_as_sync + async def test_get_invoice_months(self): + """Test get_invoice_months""" + + layer = BillingLayer(self.connection) + + # test with no mockup data, 
should be empty + records = await layer.get_invoice_months() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'invoice_month': '202301'}, + {'invoice_month': '202302'}, + ] + + records = await layer.get_invoice_months() + self.assertEqual(['202301', '202302'], records) + + @run_as_sync + async def test_get_namespaces(self): + """Test get_namespaces""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_namespaces() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'namespace': 'NAME1'}, + {'namespace': 'NAME2'}, + ] + + records = await layer.get_namespaces() + self.assertEqual(['NAME1', 'NAME2'], records) + + @run_as_sync + async def test_get_total_cost(self): + """Test get_total_cost""" + + layer = BillingLayer(self.connection) + + # test inparams exceptions: + query = BillingTotalCostQueryModel(fields=[], start_date='', end_date='') + + with self.assertRaises(ValueError) as context: + await layer.get_total_cost(query) + + self.assertTrue('Date and Fields are required' in str(context.exception)) + + # test with no muckup data, should be empty + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC], start_date='2024-01-01', end_date='2024-01-03' + ) + records = await layer.get_total_cost(query) + self.assertEqual([], records) + + # get_total_cost with mockup data is tested in test/test_bq_billing_base.py + # BillingLayer is just wrapper for BQ tables + + @run_as_sync + async def test_get_running_cost(self): + """Test get_running_cost""" + + layer = BillingLayer(self.connection) + + # test inparams exceptions: + with self.assertRaises(ValueError) as context: + await layer.get_running_cost( + field=BillingColumn.TOPIC, invoice_month=None, source=None + ) + self.assertTrue('Invalid invoice month' in str(context.exception)) + + with self.assertRaises(ValueError) as context: + await 
layer.get_running_cost( + field=BillingColumn.TOPIC, invoice_month='2024', source=None + ) + self.assertTrue('Invalid invoice month' in str(context.exception)) + + # test with no muckup data, should be empty + records = await layer.get_running_cost( + field=BillingColumn.TOPIC, invoice_month='202401', source=None + ) + self.assertEqual([], records) + + # get_running_cost with mockup data is tested in test/test_bq_billing_base.py + # BillingLayer is just wrapper for BQ tables + + @run_as_sync + async def test_get_cost_by_ar_guid(self): + """Test get_cost_by_ar_guid""" + + layer = BillingLayer(self.connection) + + # ar_guid as None, return empty results + records = await layer.get_cost_by_ar_guid(ar_guid=None) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records + ) + + # dummy ar_guid, no mockup data, return empty results + dummy_ar_guid = '12345678' + records = await layer.get_cost_by_ar_guid(ar_guid=dummy_ar_guid) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=dummy_ar_guid, batch_ids=[], costs=[]), + records, + ) + + # dummy ar_guid, mockup batch_id + + # mock BigQuery result + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + dummy_batch_id = '12345' + + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [ + mock.MagicMock( + spec=bq.Row, + batch_id=dummy_batch_id, + start_day=given_start_day, + end_day=given_end_day, + ), + ] + self.bq_result.result.return_value = mock_rows + + records = await layer.get_cost_by_ar_guid(ar_guid=dummy_ar_guid) + # returns ar_guid, batch_id and empty cost as those were not mocked up + # we do not need to test cost calculation here, + # as those are tested in test/test_bq_billing_base.py + self.assertEqual( + BillingHailBatchCostRecord( + ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}] + ), 
+ records, + ) + + @run_as_sync + async def test_get_cost_by_batch_id(self): + """Test get_cost_by_batch_id""" + + layer = BillingLayer(self.connection) + + # ar_guid as None, return empty results + records = await layer.get_cost_by_batch_id(batch_id=None) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records + ) + + # dummy ar_guid, no mockup data, return empty results + dummy_batch_id = '12345' + records = await layer.get_cost_by_batch_id(batch_id=dummy_batch_id) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), + records, + ) + + # dummy batch_id, mockup ar_guid + + # mock BigQuery result + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + dummy_batch_id = '12345' + dummy_ar_guid = '12345678' + + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [ + mock.MagicMock( + spec=bq.Row, + ar_guid=dummy_ar_guid, + batch_id=dummy_batch_id, + start_day=given_start_day, + end_day=given_end_day, + # mockup __getitem__ to return dummy_ar_guid + __getitem__=mock.MagicMock(return_value=dummy_ar_guid), + ), + ] + self.bq_result.result.return_value = mock_rows + + records = await layer.get_cost_by_batch_id(batch_id=dummy_batch_id) + # returns ar_guid, batch_id and empty cost as those were not mocked up + # we do not need to test cost calculation here, + # as those are tested in test/test_bq_billing_base.py + self.assertEqual( + BillingHailBatchCostRecord( + ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}] + ), + records, + ) diff --git a/test/testbqbase.py b/test/testbqbase.py new file mode 100644 index 000000000..6d25e489a --- /dev/null +++ b/test/testbqbase.py @@ -0,0 +1,44 @@ +import unittest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.gcp_connect 
import unittest
from typing import Any
from unittest import mock

import google.cloud.bigquery as bq

from db.python.gcp_connect import BqConnection
from db.python.layers.billing import BillingLayer


class BqTest(unittest.TestCase):
    """Base class for BigQuery integration tests.

    Provides a mocked BQ client/connection and a mocked query result that
    subclasses populate via ``self.bq_result.result.return_value``.
    """

    # author and gcp_project are not used in the BQ tests, but are required
    # by the mocked connection, so some dummy values are preset
    author: str = 'Author'
    gcp_project: str = 'GCP_PROJECT'

    bq_result: bq.job.QueryJob
    bq_client: bq.Client
    connection: BqConnection
    table_obj: Any | None = None

    def setUp(self) -> None:
        super().setUp()

        # mockup BQ results
        self.bq_result = mock.MagicMock(spec=bq.job.QueryJob)

        # mock BigQuery client; every query returns the shared mocked result
        self.bq_client = mock.MagicMock(spec=bq.Client)
        self.bq_client.query.return_value = self.bq_result

        # mock BqConnection
        self.connection = mock.MagicMock(spec=BqConnection)
        self.connection.gcp_project = self.gcp_project
        self.connection.connection = self.bq_client
        self.connection.author = self.author

        # mockup BillingLayer
        self.layer = BillingLayer(self.connection)

        # overwrite table object in inherited tests:
        self.table_obj = None