diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml index 6c8cbbb31..9cdd535e3 100644 --- a/.github/workflows/deploy.yaml +++ b/.github/workflows/deploy.yaml @@ -43,7 +43,7 @@ jobs: - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" - uses: actions/setup-java@v3 with: diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 71c330c5a..3af8fa07e 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -18,7 +18,7 @@ jobs: - uses: actions/setup-python@v4 with: - python-version: "3.10" + python-version: "3.11" - uses: actions/setup-java@v2 with: diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 54176c391..1c81caa00 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,7 +20,7 @@ jobs: - uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: "3.11" - uses: actions/setup-java@v4 with: diff --git a/api/routes/analysis.py b/api/routes/analysis.py index 6aaa86645..b3298d2c2 100644 --- a/api/routes/analysis.py +++ b/api/routes/analysis.py @@ -79,19 +79,25 @@ class AnalysisQueryModel(BaseModel): def to_filter(self, project_id_map: dict[str, int]) -> AnalysisFilter: """Convert to internal analysis filter""" return AnalysisFilter( - sample_id=GenericFilter( - in_=sample_id_transform_to_raw_list(self.sample_ids) - ) - if self.sample_ids - else None, - sequencing_group_id=GenericFilter( - in_=sequencing_group_id_transform_to_raw_list(self.sequencing_group_ids) - ) - if self.sequencing_group_ids - else None, - project=GenericFilter(in_=[project_id_map.get(p) for p in self.projects]) - if self.projects - else None, + sample_id=( + GenericFilter(in_=sample_id_transform_to_raw_list(self.sample_ids)) + if self.sample_ids + else None + ), + sequencing_group_id=( + GenericFilter( + in_=sequencing_group_id_transform_to_raw_list( + self.sequencing_group_ids + ) + ) + if self.sequencing_group_ids + else None + 
), + project=( + GenericFilter(in_=[project_id_map.get(p) for p in self.projects]) + if self.projects + else None + ), type=GenericFilter(eq=self.type) if self.type else None, ) @@ -241,8 +247,9 @@ async def query_analyses( @router.get('/analysis-runner', operation_id='getAnalysisRunnerLog') async def get_analysis_runner_log( project_names: list[str] = Query(None), # type: ignore - author: str = None, + # author: str = None, # not implemented yet, uncomment when we do output_dir: str = None, + ar_guid: str = None, connection: Connection = get_projectless_db_connection, ) -> list[AnalysisInternal]: """ @@ -257,7 +264,10 @@ async def get_analysis_runner_log( ) results = await atable.get_analysis_runner_log( - project_ids=project_ids, author=author, output_dir=output_dir + project_ids=project_ids, + # author=author, + output_dir=output_dir, + ar_guid=ar_guid, ) return [a.to_external() for a in results] diff --git a/db/backup/backup.py b/db/backup/backup.py index 68cb6a9b6..79feb8ec2 100644 --- a/db/backup/backup.py +++ b/db/backup/backup.py @@ -1,9 +1,10 @@ -#!/usr/bin/python3.7 +#!/usr/bin/python3 # pylint: disable=broad-exception-caught,broad-exception-raised """ Daily back up function for databases within a local MariaDB instance """ import json +import os import subprocess from datetime import datetime from typing import Literal @@ -50,7 +51,7 @@ def perform_backup(): tmp_dir = f'backup_{timestamp_str}' subprocess.run(['mkdir', tmp_dir], check=True) # grant permissions, so that mariadb can read ib_logfile0 - subprocess.run(['sudo', 'chmod', '-R', '777', tmp_dir], check=True) + subprocess.run(['sudo', 'chmod', '-R', '770', tmp_dir], check=True) credentials = read_db_credentials() db_username = credentials['username'] @@ -66,10 +67,11 @@ def perform_backup(): '--backup', f'--target-dir={tmp_dir}/', f'--user={db_username}', - f'-p{db_password}', ], check=True, stderr=subprocess.DEVNULL, + # pass the password via the MYSQL_PWD environment variable to avoid it being visible in the process list +
env={'MYSQL_PWD': db_password, **os.environ}, ) except subprocess.CalledProcessError as e: @@ -83,7 +85,7 @@ def perform_backup(): # mariabackup creates awkward permissions for the output files, # so we'll grant appropriate permissions for tmp_dir to later remove it - subprocess.run(['sudo', 'chmod', '-R', '777', tmp_dir], check=True) + subprocess.run(['sudo', 'chmod', '-R', '770', tmp_dir], check=True) # tar the archive to make it easier to upload to GCS tar_archive_path = f'{tmp_dir}.tar.gz' diff --git a/db/python/layers/analysis.py b/db/python/layers/analysis.py index f54c91eaa..878873386 100644 --- a/db/python/layers/analysis.py +++ b/db/python/layers/analysis.py @@ -441,9 +441,9 @@ async def get_cram_sizes_between_range( sample_create_dates = await sglayer.get_samples_create_date_from_sgs( list(crams.keys()) ) - by_date: dict[ - SequencingGroupInternalId, list[tuple[datetime.date, int]] - ] = defaultdict(list) + by_date: dict[SequencingGroupInternalId, list[tuple[datetime.date, int]]] = ( + defaultdict(list) + ) for sg_id, analyses in crams.items(): if len(analyses) == 1: @@ -531,7 +531,9 @@ async def get_sgs_added_by_day_by_es_indices( return by_day - async def get_audit_logs_by_analysis_ids(self, analysis_ids: list[int]) -> dict[int, list[AuditLogInternal]]: + async def get_audit_logs_by_analysis_ids( + self, analysis_ids: list[int] + ) -> dict[int, list[AuditLogInternal]]: """Get audit logs for analysis IDs""" return await self.at.get_audit_log_for_analysis_ids(analysis_ids) @@ -594,12 +596,16 @@ async def update_analysis( async def get_analysis_runner_log( self, project_ids: list[int] = None, - author: str = None, + # author: str = None, output_dir: str = None, + ar_guid: str = None, ) -> list[AnalysisInternal]: """ Get log for the analysis-runner, useful for checking this history of analysis """ return await self.at.get_analysis_runner_log( - project_ids, author=author, output_dir=output_dir + project_ids, + # author=author, + output_dir=output_dir, + 
ar_guid=ar_guid, ) diff --git a/db/python/layers/billing.py b/db/python/layers/billing.py index 737be6857..1122259ba 100644 --- a/db/python/layers/billing.py +++ b/db/python/layers/billing.py @@ -18,7 +18,7 @@ class BillingLayer(BqBaseLayer): def table_factory( self, - source: BillingSource, + source: BillingSource | None = None, fields: list[BillingColumn] | None = None, filters: dict[BillingColumn, str | list | dict] | None = None, ) -> ( diff --git a/db/python/tables/analysis.py b/db/python/tables/analysis.py index ee445345e..7955cfadb 100644 --- a/db/python/tables/analysis.py +++ b/db/python/tables/analysis.py @@ -486,8 +486,9 @@ async def get_sample_cram_path_map_for_seqr( async def get_analysis_runner_log( self, project_ids: List[int] = None, - author: str = None, + # author: str = None, output_dir: str = None, + ar_guid: str = None, ) -> List[AnalysisInternal]: """ Get log for the analysis-runner, useful for checking this history of analysis @@ -501,15 +502,15 @@ async def get_analysis_runner_log( wheres.append('project in :project_ids') values['project_ids'] = project_ids - if author: - wheres.append('audit_log_id = :audit_log_id') - values['audit_log_id'] = await self.audit_log_id() - if output_dir: wheres.append('(output = :output OR output LIKE :output_like)') values['output'] = output_dir values['output_like'] = f'%{output_dir}' + if ar_guid: + wheres.append('JSON_EXTRACT(meta, "$.ar_guid") = :ar_guid') + values['ar_guid'] = ar_guid + wheres_str = ' AND '.join(wheres) _query = f'SELECT * FROM analysis WHERE {wheres_str}' rows = await self.connection.fetch_all(_query, values) diff --git a/db/python/tables/bq/billing_base.py b/db/python/tables/bq/billing_base.py index 335603c3b..350dfecab 100644 --- a/db/python/tables/bq/billing_base.py +++ b/db/python/tables/bq/billing_base.py @@ -117,8 +117,9 @@ def _execute_query( # otherwise return as BQ iterator return self._connection.connection.query(query, job_config=job_config) + @staticmethod def 
_query_to_partitioned_filter( - self, query: BillingTotalCostQueryModel + query: BillingTotalCostQueryModel, ) -> BillingFilter: """ By default views are partitioned by 'day', @@ -137,15 +138,18 @@ def _query_to_partitioned_filter( ) return billing_filter - def _filter_to_optimise_query(self) -> str: + @staticmethod + def _filter_to_optimise_query() -> str: """Filter string to optimise BQ query""" return 'day >= TIMESTAMP(@start_day) AND day <= TIMESTAMP(@last_day)' - def _last_loaded_day_filter(self) -> str: + @staticmethod + def _last_loaded_day_filter() -> str: """Last Loaded day filter string""" return 'day = TIMESTAMP(@last_loaded_day)' - def _convert_output(self, query_job_result): + @staticmethod + def _convert_output(query_job_result): """Convert query result to json""" if not query_job_result or query_job_result.result().total_rows == 0: # return empty list if no record found @@ -325,8 +329,8 @@ async def _execute_running_cost_query( self._execute_query(_query, query_params), ) + @staticmethod async def _append_total_running_cost( - self, field: BillingColumn, is_current_month: bool, last_loaded_day: str | None, @@ -437,9 +441,8 @@ async def _append_running_cost_records( return results - def _prepare_order_by_string( - self, order_by: dict[BillingColumn, bool] | None - ) -> str: + @staticmethod + def _prepare_order_by_string(order_by: dict[BillingColumn, bool] | None) -> str: """Prepare order by string""" if not order_by: return '' @@ -452,9 +455,8 @@ def _prepare_order_by_string( return f'ORDER BY {",".join(order_by_cols)}' if order_by_cols else '' - def _prepare_aggregation( - self, query: BillingTotalCostQueryModel - ) -> tuple[str, str]: + @staticmethod + def _prepare_aggregation(query: BillingTotalCostQueryModel) -> tuple[str, str]: """Prepare both fields for aggregation and group by string""" # Get columns to group by @@ -479,7 +481,8 @@ def _prepare_aggregation( return fields_selected, group_by - def _prepare_labels_function(self, query: 
BillingTotalCostQueryModel): + @staticmethod + def _prepare_labels_function(query: BillingTotalCostQueryModel): if not query.filters: return None @@ -558,7 +561,7 @@ async def get_total_cost( where_str = f'WHERE {where_str}' _query = f""" - {func_filter.fun_implementation if func_filter else ''} + {func_filter.func_implementation if func_filter else ''} WITH t AS ( SELECT {time_group.field}{time_group.separator} {fields_selected}, diff --git a/db/python/tables/bq/billing_filter.py b/db/python/tables/bq/billing_filter.py index 9a379817f..b78a8c960 100644 --- a/db/python/tables/bq/billing_filter.py +++ b/db/python/tables/bq/billing_filter.py @@ -2,6 +2,7 @@ import dataclasses import datetime +from typing import Any from db.python.tables.bq.generic_bq_filter import GenericBQFilter from db.python.tables.bq.generic_bq_filter_model import GenericBQFilterModel @@ -46,3 +47,17 @@ class BillingFilter(GenericBQFilterModel): goog_pipelines_worker: GenericBQFilter[str] = None wdl_task_name: GenericBQFilter[str] = None namespace: GenericBQFilter[str] = None + + def __eq__(self, other: Any) -> bool: + """Equality operator""" + result = super().__eq__(other) + if not result or not isinstance(other, BillingFilter): + return False + + # compare all attributes + for att in self.__dict__: + if getattr(self, att) != getattr(other, att): + return False + + # all attributes are equal + return True diff --git a/db/python/tables/bq/billing_gcp_daily.py b/db/python/tables/bq/billing_gcp_daily.py index b765547c3..0b691bc6f 100644 --- a/db/python/tables/bq/billing_gcp_daily.py +++ b/db/python/tables/bq/billing_gcp_daily.py @@ -1,4 +1,5 @@ from datetime import datetime, timedelta +from typing import Any from google.cloud import bigquery @@ -9,7 +10,7 @@ ) from db.python.tables.bq.billing_filter import BillingFilter from db.python.tables.bq.generic_bq_filter import GenericBQFilter -from models.models import BillingTotalCostQueryModel +from models.models import BillingColumn, 
BillingTotalCostQueryModel class BillingGcpDailyTable(BillingBaseTable): @@ -21,8 +22,9 @@ def get_table_name(self): """Get table name""" return self.table_name + @staticmethod def _query_to_partitioned_filter( - self, query: BillingTotalCostQueryModel + query: BillingTotalCostQueryModel, ) -> BillingFilter: """ add extra filter to limit materialized view partition @@ -77,7 +79,9 @@ async def _last_loaded_day(self): return None - def _prepare_daily_cost_subquery(self, field, query_params, last_loaded_day): + def _prepare_daily_cost_subquery( + self, field: BillingColumn, query_params: list[Any], last_loaded_day: str + ): """prepare daily cost subquery""" # add extra filter to limit materialized view partition diff --git a/db/python/tables/bq/billing_raw.py b/db/python/tables/bq/billing_raw.py index a82fa4eec..dde90388d 100644 --- a/db/python/tables/bq/billing_raw.py +++ b/db/python/tables/bq/billing_raw.py @@ -16,8 +16,9 @@ def get_table_name(self): """Get table name""" return self.table_name + @staticmethod def _query_to_partitioned_filter( - self, query: BillingTotalCostQueryModel + query: BillingTotalCostQueryModel, ) -> BillingFilter: """ Raw BQ billing table is partitioned by usage_end_time diff --git a/db/python/tables/bq/function_bq_filter.py b/db/python/tables/bq/function_bq_filter.py index f18f60211..5c8ac21d2 100644 --- a/db/python/tables/bq/function_bq_filter.py +++ b/db/python/tables/bq/function_bq_filter.py @@ -27,14 +27,14 @@ class FunctionBQFilter: def __init__(self, name: str, implementation: str): self.func_name = name - self.fun_implementation = implementation + self.func_implementation = implementation # param_id is a counter for parameterised values self._param_id = 0 def to_sql( self, column_name: BillingColumn, - func_params: str | list[Any] | dict[Any, Any], + func_params: str | list[Any] | dict[Any, Any] | None = None, func_operator: str = None, ) -> tuple[str, list[bigquery.ScalarQueryParameter | bigquery.ArrayQueryParameter]]: """ @@ 
-103,7 +103,9 @@ def _sql_value_prep(key: str, value: Any) -> bigquery.ScalarQueryParameter: if isinstance(value, float): return bigquery.ScalarQueryParameter(key, 'FLOAT64', value) if isinstance(value, datetime): - return bigquery.ScalarQueryParameter(key, 'STRING', value) + return bigquery.ScalarQueryParameter( + key, 'STRING', value.isoformat(timespec='seconds') + ) # otherwise as string parameter return bigquery.ScalarQueryParameter(key, 'STRING', value) diff --git a/db/python/tables/bq/generic_bq_filter.py b/db/python/tables/bq/generic_bq_filter.py index b0bfba973..ce54ac885 100644 --- a/db/python/tables/bq/generic_bq_filter.py +++ b/db/python/tables/bq/generic_bq_filter.py @@ -12,6 +12,19 @@ class GenericBQFilter(GenericFilter[T]): Generic BigQuery filter is BQ specific filter class, based on GenericFilter """ + def __eq__(self, other): + """Equality operator""" + if not isinstance(other, GenericBQFilter): + return False + + keys = ['eq', 'in_', 'nin', 'gt', 'gte', 'lt', 'lte'] + for att in keys: + if getattr(self, att) != getattr(other, att): + return False + + # all attributes are equal + return True + def to_sql( self, column: str, column_name: str = None ) -> tuple[str, dict[str, T | list[T] | Any | list[Any]]]: @@ -38,13 +51,17 @@ def to_sql( values[k] = self._sql_value_prep(k, self.in_[0]) else: k = self.generate_field_name(_column_name + '_in') - conditionals.append(f'{column} IN ({self._sql_cond_prep(k, self.in_)})') + conditionals.append( + f'{column} IN UNNEST({self._sql_cond_prep(k, self.in_)})' + ) values[k] = self._sql_value_prep(k, self.in_) if self.nin is not None: if not isinstance(self.nin, list): raise ValueError('NIN filter must be a list') k = self.generate_field_name(column + '_nin') - conditionals.append(f'{column} NOT IN ({self._sql_cond_prep(k, self.nin)})') + conditionals.append( + f'{column} NOT IN UNNEST({self._sql_cond_prep(k, self.nin)})' + ) values[k] = self._sql_value_prep(k, self.nin) if self.gt is not None: k = 
self.generate_field_name(column + '_gt') @@ -83,9 +100,14 @@ def _sql_value_prep(key, value): Overrides the default _sql_value_prep to handle BQ parameters """ if isinstance(value, list): - return bigquery.ArrayQueryParameter( - key, 'STRING', ','.join([str(v) for v in value]) - ) + if value and isinstance(value[0], int): + return bigquery.ArrayQueryParameter(key, 'INT64', value) + if value and isinstance(value[0], float): + return bigquery.ArrayQueryParameter(key, 'FLOAT64', value) + + # otherwise all list records as string + return bigquery.ArrayQueryParameter(key, 'STRING', [str(v) for v in value]) + if isinstance(value, Enum): return GenericBQFilter._sql_value_prep(key, value.value) if isinstance(value, int): diff --git a/docs/installation.md b/docs/installation.md index 89d35b38b..75ca1b521 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -63,7 +63,7 @@ choco install mariadb --version=10.8.3 - Additional dev requirements are listed in `requirements-dev.txt`. - Packages for the sever-side code are listed in `requirements.txt`. -We *STRONGLY* encourage the use of `pyenv` for managing Python versions. Debugging and the server will run on a minimum python version of 3.10. Refer to the [team-docs](https://github.com/populationgenomics/team-docs/blob/main/python.md) for more instructions on how to set this up. +We *STRONGLY* encourage the use of `pyenv` for managing Python versions. Debugging and the server will run on a minimum python version of 3.11. Refer to the [team-docs](https://github.com/populationgenomics/team-docs/blob/main/python.md) for more instructions on how to set this up. 
Use of a virtual environment to contain all requirements is highly recommended: diff --git a/metamist_infrastructure/driver.py b/metamist_infrastructure/driver.py index 2a12e0c3e..ede85d4f7 100644 --- a/metamist_infrastructure/driver.py +++ b/metamist_infrastructure/driver.py @@ -550,7 +550,10 @@ def _setup_etl_pubsub(self): @cached_property def etl_extract_function(self): """etl_extract_function""" - return self._etl_function('extract', self.etl_extract_service_account) + return self._etl_function( + 'extract', + self.etl_extract_service_account, + ) @cached_property def etl_load_function(self): @@ -646,6 +649,16 @@ def _etl_function( opts=pulumi.ResourceOptions(replace_on_changes=['*']), ) + # prepare custom audience_list + custom_audience_list = None + if ( + self.config.metamist.etl.custom_audience_list + and self.config.metamist.etl.custom_audience_list.get(f_name) + ): + custom_audience_list = json.dumps( + self.config.metamist.etl.custom_audience_list.get(f_name) + ) + fxn = gcp.cloudfunctionsv2.Function( f'metamist-etl-{f_name}', name=f'metamist-etl-{f_name}', @@ -694,6 +707,11 @@ def _etl_function( 'SM_ENVIRONMENT': self.config.metamist.etl.environment, 'CONFIGURATION_SECRET': self.etl_configuration_secret_version.id, }, # type: ignore + annotations=( + {'run.googleapis.com/custom-audiences': custom_audience_list} + if custom_audience_list + else None + ), ingress_settings='ALLOW_ALL', all_traffic_on_latest_revision=True, service_account_email=sa.email, diff --git a/models/models/billing.py b/models/models/billing.py index 9587ed44c..6b5a98fbf 100644 --- a/models/models/billing.py +++ b/models/models/billing.py @@ -243,7 +243,11 @@ def to_filter(self) -> BillingFilter: # add filters as attributes for fk, fv in self.filters.items(): # fk is BillColumn, fv is value - setattr(billing_filter, fk.value, GenericBQFilter(eq=fv)) + # if fv is a list, then use IN filter + if isinstance(fv, list): + setattr(billing_filter, fk.value, GenericBQFilter(in_=fv)) + else: 
+ setattr(billing_filter, fk.value, GenericBQFilter(eq=fv)) return billing_filter diff --git a/mypy.ini b/mypy.ini index 418e757ed..774a6c63e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ # Global options: [mypy] -python_version = 3.10 +python_version = 3.11 ; warn_return_any = True ; warn_unused_configs = True diff --git a/requirements-dev.txt b/requirements-dev.txt index e0bfc432f..3bb86aa87 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ flake8-bugbear nest-asyncio pre-commit pylint -testcontainers[mariadb] +testcontainers[mariadb]==3.7.1 types-PyMySQL # some strawberry dependency strawberry-graphql[debug-server]==0.206.0 @@ -15,3 +15,5 @@ strawberry-graphql[debug-server]==0.206.0 functions_framework google-cloud-bigquery google-cloud-pubsub +# following required to unit test some billing functions +pytz diff --git a/scripts/create_test_subset.py b/scripts/create_test_subset.py index 229eb5382..9d85acb53 100755 --- a/scripts/create_test_subset.py +++ b/scripts/create_test_subset.py @@ -201,7 +201,9 @@ def main( logger.info(f'Found {len(all_sids)} sample ids in {project}') # 3. Randomly select from the remaining sgs - additional_samples.update(random.sample(all_sids - additional_samples, samples_n)) + additional_samples.update( + random.sample(list(all_sids - additional_samples), samples_n) + ) # 4. Query all the samples from the selected sgs logger.info(f'Transfering {len(additional_samples)} samples. 
Querying metadata.') diff --git a/scripts/etl_caller_example.py b/scripts/etl_caller_example.py index 8f353bd91..235d03482 100644 --- a/scripts/etl_caller_example.py +++ b/scripts/etl_caller_example.py @@ -6,6 +6,7 @@ pip install requests google-auth requests urllib3 """ + import os import google.auth.transport.requests @@ -14,7 +15,7 @@ from requests.adapters import HTTPAdapter from urllib3 import Retry -URL = 'https://metamist-etl-mnrpw3mdza-ts.a.run.app' +URL = 'https://metamist-extract.popgen.rocks' TYPE = 'NAME_OF_EXTERNAL_PARTY/v1' @@ -60,4 +61,10 @@ def make_request(body: dict | list): if __name__ == '__main__': + # simple payload which would be sent to the server print(make_request({'test-value': 'test'})) + + # This will raise a 400 with an error message 'Potentially detected PII, the following keys were found in the body' + # it won't update anything in the database + payload = {'first_name': 'This is a test to demonstrate PII detection'} + make_request(payload) diff --git a/test/test_api_billing.py b/test/test_api_billing.py new file mode 100644 index 000000000..4811187a7 --- /dev/null +++ b/test/test_api_billing.py @@ -0,0 +1,269 @@ +# pylint: disable=protected-access too-many-public-methods +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from unittest.mock import patch + +from api.routes import billing +from models.models import ( + BillingColumn, + BillingCostBudgetRecord, + BillingHailBatchCostRecord, + BillingTotalCostQueryModel, + BillingTotalCostRecord, +) + +TEST_API_BILLING_USER = 'test_user' + + +class TestApiBilling(BqTest): + """ + Test API Billing routes + Billing routes are only calling layer functions and returning data it received + This set of tests only checks if all routes code has been called + and if data returned is the same as data received from layer functions + It does not check all possible combination of parameters as + Billing Layer and BQ Tables should be testing those + """ + + def setUp(self): + 
super().setUp() + + # make billing enabled by default for all the calls + patcher = patch('api.routes.billing.is_billing_enabled', return_value=True) + self.mockup_is_billing_enabled = patcher.start() + + @run_as_sync + @patch('api.routes.billing.is_billing_enabled', return_value=False) + async def test_get_gcp_projects_no_billing(self, _mockup_is_billing_enabled): + """ + Test get_gcp_projects function + This function should raise ValueError if billing is not enabled + """ + with self.assertRaises(ValueError) as context: + await billing.get_gcp_projects('test_user') + + self.assertTrue('Billing is not enabled' in str(context.exception)) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_cost_by_ar_guid') + async def test_get_cost_by_ar_guid( + self, mock_get_cost_by_ar_guid, mock_get_billing_layer + ): + """ + Test get_cost_by_ar_guid function + """ + ar_guid = 'test_ar_guid' + mockup_record = BillingHailBatchCostRecord( + ar_guid=ar_guid, batch_ids=None, costs=None + ) + mock_get_billing_layer.return_value = self.layer + mock_get_cost_by_ar_guid.return_value = mockup_record + records = await billing.get_cost_by_ar_guid( + ar_guid, author=TEST_API_BILLING_USER + ) + self.assertEqual(mockup_record, records) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_cost_by_batch_id') + async def test_get_cost_by_batch_id( + self, mock_get_cost_by_batch_id, mock_get_billing_layer + ): + """ + Test get_cost_by_batch_id function + """ + ar_guid = 'test_ar_guid' + batch_id = 'test_batch_id' + mockup_record = BillingHailBatchCostRecord( + ar_guid=ar_guid, batch_ids=[batch_id], costs=None + ) + mock_get_billing_layer.return_value = self.layer + mock_get_cost_by_batch_id.return_value = mockup_record + records = await billing.get_cost_by_batch_id( + batch_id, author=TEST_API_BILLING_USER + ) + self.assertEqual(mockup_record, records) + + 
@run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_total_cost') + async def test_get_total_cost(self, mock_get_total_cost, mock_get_billing_layer): + """ + Test get_total_cost function + """ + query = BillingTotalCostQueryModel(fields=[], start_date='', end_date='') + mockup_record = [{'cost': 123.45}, {'cost': 123}] + expected = [BillingTotalCostRecord.from_json(r) for r in mockup_record] + mock_get_billing_layer.return_value = self.layer + mock_get_total_cost.return_value = mockup_record + records = await billing.get_total_cost(query, author=TEST_API_BILLING_USER) + self.assertEqual(expected, records) + + @run_as_sync + @patch('api.routes.billing._get_billing_layer_from') + @patch('db.python.layers.billing.BillingLayer.get_running_cost') + async def test_get_running_cost( + self, mock_get_running_cost, mock_get_billing_layer + ): + """ + Test get_running_cost function + """ + mockup_record = [ + BillingCostBudgetRecord.from_json( + {'field': 'TOPIC1', 'total_monthly': 999.99, 'details': []} + ), + BillingCostBudgetRecord.from_json( + {'field': 'TOPIC2', 'total_monthly': 123, 'details': []} + ), + ] + mock_get_billing_layer.return_value = self.layer + mock_get_running_cost.return_value = mockup_record + records = await billing.get_running_costs( + field=BillingColumn.TOPIC, + invoice_month=None, + source=None, + author=TEST_API_BILLING_USER, + ) + self.assertEqual(mockup_record, records) + + @patch('api.routes.billing._get_billing_layer_from') + async def call_api_function( + self, + api_function, + mock_layer_function, + mock_get_billing_layer=None, + ): + """ + Common wrapper for all API calls, to avoid code duplication + API function is called with author=TEST_API_BILLING_USER + get_author function will be tested separately + We only testing if routes are calling layer functions and + returning data it received + """ + mock_get_billing_layer.return_value = self.layer + mockup_records = 
['RECORD1', 'RECORD2'] + mock_layer_function.return_value = mockup_records + records = await api_function(author=TEST_API_BILLING_USER) + self.assertEqual(mockup_records, records) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_gcp_projects') + async def test_get_gcp_projects(self, mock_get_gcp_projects): + """ + Test get_gcp_projects function + """ + await self.call_api_function(billing.get_gcp_projects, mock_get_gcp_projects) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_topics') + async def test_get_topics(self, mock_get_topics): + """ + Test get_topics function + """ + await self.call_api_function(billing.get_topics, mock_get_topics) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_cost_categories') + async def test_get_cost_categories(self, mock_get_cost_categories): + """ + Test get_cost_categories function + """ + await self.call_api_function( + billing.get_cost_categories, mock_get_cost_categories + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_skus') + async def test_get_skus(self, mock_get_skus): + """ + Test get_skus function + """ + await self.call_api_function(billing.get_skus, mock_get_skus) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_datasets') + async def test_get_datasets(self, mock_get_datasets): + """ + Test get_datasets function + """ + await self.call_api_function(billing.get_datasets, mock_get_datasets) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_sequencing_types') + async def test_get_sequencing_types(self, mock_get_sequencing_types): + """ + Test get_sequencing_types function + """ + await self.call_api_function( + billing.get_sequencing_types, mock_get_sequencing_types + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_stages') + async def test_get_stages(self, mock_get_stages): + """ + Test get_stages function + """ + await self.call_api_function(billing.get_stages, 
mock_get_stages) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_sequencing_groups') + async def test_get_sequencing_groups(self, mock_get_sequencing_groups): + """ + Test get_sequencing_groups function + """ + await self.call_api_function( + billing.get_sequencing_groups, mock_get_sequencing_groups + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_compute_categories') + async def test_get_compute_categories(self, mock_get_compute_categories): + """ + Test get_compute_categories function + """ + await self.call_api_function( + billing.get_compute_categories, mock_get_compute_categories + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_cromwell_sub_workflow_names') + async def test_get_cromwell_sub_workflow_names( + self, mock_get_cromwell_sub_workflow_names + ): + """ + Test get_cromwell_sub_workflow_names function + """ + await self.call_api_function( + billing.get_cromwell_sub_workflow_names, + mock_get_cromwell_sub_workflow_names, + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_wdl_task_names') + async def test_get_wdl_task_names(self, mock_get_wdl_task_names): + """ + Test get_wdl_task_names function + """ + await self.call_api_function( + billing.get_wdl_task_names, mock_get_wdl_task_names + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_invoice_months') + async def test_get_invoice_months(self, mock_get_invoice_months): + """ + Test get_invoice_months function + """ + await self.call_api_function( + billing.get_invoice_months, mock_get_invoice_months + ) + + @run_as_sync + @patch('db.python.layers.billing.BillingLayer.get_namespaces') + async def test_get_namespaces(self, mock_get_namespaces): + """ + Test get_namespaces function + """ + await self.call_api_function(billing.get_namespaces, mock_get_namespaces) diff --git a/test/test_api_utils.py b/test/test_api_utils.py new file mode 100644 index 000000000..6f02a913c --- /dev/null +++ 
b/test/test_api_utils.py @@ -0,0 +1,77 @@ +import unittest +from datetime import date, datetime + +from api.utils.dates import ( + get_invoice_month_range, + parse_date_only_string, + reformat_datetime, +) + + +class TestApiUtils(unittest.TestCase): + """Test API utils functions""" + + def test_parse_date_only_string(self): + """ + Test parse_date_only_string function + """ + result_none = parse_date_only_string(None) + self.assertEqual(None, result_none) + + result_date = parse_date_only_string('2021-01-10') + self.assertEqual(2021, result_date.year) + self.assertEqual(1, result_date.month) + self.assertEqual(10, result_date.day) + + # test exception + invalid_date_str = '123456789' + with self.assertRaises(ValueError) as context: + parse_date_only_string(invalid_date_str) + + self.assertTrue( + f'Date could not be converted: {invalid_date_str}' in str(context.exception) + ) + + def test_get_invoice_month_range(self): + """ + Test get_invoice_month_range function + """ + jan_2021 = datetime.strptime('2021-01-10', '%Y-%m-%d').date() + res_jan_2021 = get_invoice_month_range(jan_2021) + + # there is 3 (INVOICE_DAY_DIFF) days difference between invoice month st and end + self.assertEqual( + (date(2020, 12, 29), date(2021, 2, 3)), + res_jan_2021, + ) + + dec_2021 = datetime.strptime('2021-12-10', '%Y-%m-%d').date() + res_dec_2021 = get_invoice_month_range(dec_2021) + + # there is 3 (INVOICE_DAY_DIFF) days difference between invoice month st and end + self.assertEqual( + (date(2021, 11, 28), date(2022, 1, 3)), + res_dec_2021, + ) + + def test_reformat_datetime(self): + """ + Test reformat_datetime function + """ + in_format = '%Y-%m-%d' + out_format = '%d/%m/%Y' + + result_none = reformat_datetime(None, in_format, out_format) + self.assertEqual(None, result_none) + + result_formatted = reformat_datetime('2021-11-09', in_format, out_format) + self.assertEqual('09/11/2021', result_formatted) + + # test exception + invalid_date_str = '123456789' + with 
self.assertRaises(ValueError) as context: + reformat_datetime(invalid_date_str, in_format, out_format) + + self.assertTrue( + f'Date could not be converted: {invalid_date_str}' in str(context.exception) + ) diff --git a/test/test_bq_billing_ar_batch.py b/test/test_bq_billing_ar_batch.py new file mode 100644 index 000000000..568f0f1dc --- /dev/null +++ b/test/test_bq_billing_ar_batch.py @@ -0,0 +1,188 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_ar_batch import BillingArBatchTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingArBatchTable(BqTest): + """Test BillingArBatchTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingArBatchTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingArBatchTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) 
+ + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingArBatchTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingArBatchTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_get_batches_by_ar_guid_no_data(self): + """Test get_batches_by_ar_guid""" + + ar_guid = '1234567890' + + # mock BigQuery result + self.bq_result.result.return_value = [] + + # test get_batches_by_ar_guid function + (start_day, end_day, batch_ids) = await self.table_obj.get_batches_by_ar_guid( + ar_guid + ) + + self.assertEqual(None, start_day) + self.assertEqual(None, end_day) + self.assertEqual([], batch_ids) + + @run_as_sync + async def test_get_batches_by_ar_guid_one_record(self): + """Test get_batches_by_ar_guid""" + + ar_guid = '1234567890' + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + + # mock BigQuery result + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + start_day=given_start_day, + end_day=given_end_day, + batch_id='Batch1234', + ) + ] + + # test get_batches_by_ar_guid function + (start_day, end_day, batch_ids) = await self.table_obj.get_batches_by_ar_guid( + ar_guid + ) + + self.assertEqual(given_start_day, start_day) + # end day is the last day + 1 + self.assertEqual(given_end_day + datetime.timedelta(days=1), end_day) + self.assertEqual(['Batch1234'], batch_ids) + + @run_as_sync + async def test_get_batches_by_ar_guid_two_record(self): + """Test get_batches_by_ar_guid""" + + ar_guid = '1234567890' + given_start_day = 
datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + + # mock BigQuery result + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + start_day=given_start_day, + end_day=given_start_day, + batch_id='FirstBatch', + ), + mock.MagicMock( + spec=bq.Row, + start_day=given_start_day, + end_day=given_end_day, + batch_id='SecondBatch', + ), + ] + + # test get_batches_by_ar_guid function + (start_day, end_day, batch_ids) = await self.table_obj.get_batches_by_ar_guid( + ar_guid + ) + + self.assertEqual(given_start_day, start_day) + # end day is the last day + 1 + self.assertEqual(given_end_day + datetime.timedelta(days=1), end_day) + self.assertEqual(['FirstBatch', 'SecondBatch'], batch_ids) + + @run_as_sync + async def test_get_ar_guid_by_batch_id_no_data(self): + """Test get_ar_guid_by_batch_id""" + + batch_id = '1234567890' + + # mock BigQuery result + self.bq_result.result.return_value = [] + + # test get_ar_guid_by_batch_id function + ar_guid = await self.table_obj.get_ar_guid_by_batch_id(batch_id) + + self.assertEqual(None, ar_guid) + + @run_as_sync + async def test_get_ar_guid_by_batch_id_one_rec(self): + """Test get_ar_guid_by_batch_id""" + + batch_id = '1234567890' + expected_ar_guid = 'AR_GUID_1234' + + # mock BigQuery result + self.bq_result.result.return_value = [{'ar_guid': expected_ar_guid}] + + # test get_ar_guid_by_batch_id function + ar_guid = await self.table_obj.get_ar_guid_by_batch_id(batch_id) + + self.assertEqual(expected_ar_guid, ar_guid) diff --git a/test/test_bq_billing_base.py b/test/test_bq_billing_base.py new file mode 100644 index 000000000..655b40486 --- /dev/null +++ b/test/test_bq_billing_base.py @@ -0,0 +1,911 @@ +# pylint: disable=protected-access too-many-public-methods +from datetime import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from 
db.python.tables.bq.billing_base import ( + BillingBaseTable, + abbrev_cost_category, + prepare_time_periods, +) +from db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from models.enums import BillingTimePeriods +from models.models import ( + BillingColumn, + BillingCostBudgetRecord, + BillingCostDetailsRecord, + BillingTotalCostQueryModel, +) + + +def mock_execute_query_running_cost(query, *_args, **_kwargs): + """ + This is a mockup function for _execute_query function + This returns one mockup BQ query result + for 2 different SQL queries used by get_running_cost API point + Those 2 queries are: + 1. query to get last loaded day + 2. query to get aggregated monthly/daily cost + """ + if ' as last_loaded_day' in query: + # This is the 1st query to get last loaded day + # mockup BQ query result for last_loaded_day as list of rows + return [ + mock.MagicMock(spec=bq.Row, last_loaded_day='2024-01-01 00:00:00+00:00') + ] + + # This is the 2nd query to get aggregated monthly cost + # return mockup BQ query result as list of rows + return [ + mock.MagicMock( + spec=bq.Row, + field='TOPIC1', + cost_category='Compute Engine', + daily_cost=123.45, + monthly_cost=2345.67, + ) + ] + + +def mock_execute_query_get_total_cost(_query, *_args, **_kwargs): + """ + This is a mockup function for _execute_query function + This returns one mockup BQ query result + """ + # mockup BQ query result topic cost by invoice month, return row iterator + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 3 + mock_rows.__iter__.return_value = [ + {'day': '202301', 'topic': 'TOPIC1', 'cost': 123.10}, + {'day': '202302', 'topic': 'TOPIC1', 'cost': 223.20}, + {'day': '202303', 'topic': 'TOPIC1', 'cost': 323.30}, + ] + mock_result = mock.MagicMock(spec=bq.job.QueryJob) + mock_result.result.return_value = mock_rows + 
return mock_result + + +class TestBillingBaseTable(BqTest): + """Test BillingBaseTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + # base is abstract, so we need to use a child class + # DailyExtended is the simplest one, almost no overrides + self.table_obj = BillingDailyExtendedTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime(2023, 1, 1, 0, 0), + lte=datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingBaseTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_abbrev_cost_category(self): + """Test abbrev_cost_category""" + + # table name is set in the class + categories_to_expected = { + 'Cloud Storage': 'S', + 'Compute Engine': 'C', + 'Other': 'C', + } + + # test category to abreveation + for cat, expected_abrev in categories_to_expected.items(): + self.assertEqual(expected_abrev, abbrev_cost_category(cat)) + + def test_prepare_time_periods_by_day(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.DAY, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('FORMAT_DATE("%Y-%m-%d", day) as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y-%m-%d", day) as day', time_group.formula) + self.assertEqual(',', 
time_group.separator) + + def test_prepare_time_periods_by_week(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.WEEK, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('FORMAT_DATE("%Y%W", day) as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y%W", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_prepare_time_periods_by_month(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.MONTH, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('FORMAT_DATE("%Y%m", day) as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y%m", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_prepare_time_periods_by_invoice_month(self): + """Test prepare_time_periods""" + + query = BillingTotalCostQueryModel( + fields=[], + start_date='2024-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.INVOICE_MONTH, + ) + + time_group = prepare_time_periods(query) + + self.assertEqual('invoice_month as day', time_group.field) + self.assertEqual('PARSE_DATE("%Y%m", day) as day', time_group.formula) + self.assertEqual(',', time_group.separator) + + def test_filter_to_optimise_query(self): + """Test _filter_to_optimise_query""" + + result = BillingBaseTable._filter_to_optimise_query() + self.assertEqual( + 'day >= TIMESTAMP(@start_day) AND day <= TIMESTAMP(@last_day)', result + ) + + def test_last_loaded_day_filter(self): + """Test _last_loaded_day_filter""" + + result = BillingBaseTable._last_loaded_day_filter() + self.assertEqual('day = TIMESTAMP(@last_loaded_day)', result) + + def test_convert_output_empty_results(self): + """Test _convert_output - various empty results""" + + empty_results 
= BillingBaseTable._convert_output(None) + self.assertEqual([], empty_results) + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.total_rows = 0 + + empty_list = BillingBaseTable._convert_output(query_job_result) + self.assertEqual([], empty_list) + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.return_value = mock.MagicMock(spec=bq.table.RowIterator) + + empty_row_iterator = BillingBaseTable._convert_output(query_job_result) + self.assertEqual([], empty_row_iterator) + + def test_convert_output_one_record(self): + """Test _convert_output - one record result""" + + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [{}] + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.return_value = mock_rows + + single_row = BillingBaseTable._convert_output(query_job_result) + self.assertEqual([{}], single_row) + + def test_convert_output_label_record(self): + """Test _convert_output - test with label item""" + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [ + {'labels': [{'key': 'test_key', 'value': 'test_value'}]} + ] + + query_job_result = mock.MagicMock(spec=bq.job.QueryJob) + query_job_result.result.return_value = mock_rows + + row_iterator = BillingBaseTable._convert_output(query_job_result) + self.assertEqual( + [ + { + # keep the original labels + 'labels': [{'key': 'test_key', 'value': 'test_value'}], + # append the labels as key-value pairs + 'test_key': 'test_value', + } + ], + row_iterator, + ) + + def test_prepare_order_by_string_empty(self): + """Test _prepare_order_by_string - empty results""" + + self.assertEqual('', BillingBaseTable._prepare_order_by_string(None)) + + def test_prepare_order_by_string_order_by_one_column(self): + """Test _prepare_order_by_string""" + + # DESC order by column + self.assertEqual( + 'ORDER BY 
cost DESC', + BillingBaseTable._prepare_order_by_string({BillingColumn.COST: True}), + ) + + # ASC order by column + self.assertEqual( + 'ORDER BY cost ASC', + BillingBaseTable._prepare_order_by_string({BillingColumn.COST: False}), + ) + + def test_prepare_order_by_string_order_by_two_columns(self): + """Test _prepare_order_by_string - order by 2 columns""" + self.assertEqual( + 'ORDER BY cost ASC,day DESC', + BillingBaseTable._prepare_order_by_string( + {BillingColumn.COST: False, BillingColumn.DAY: True} + ), + ) + + def test_prepare_aggregation_default_group_by(self): + """Test _prepare_aggregation""" + + query = BillingTotalCostQueryModel( + fields=[], start_date='2024-01-01', end_date='2024-01-01' + ) + + fields_selected, group_by = BillingBaseTable._prepare_aggregation(query) + # no fields selected so it is empty + self.assertEqual('', fields_selected) + # by default results are grouped by day + self.assertEqual('GROUP BY day', group_by) + + def test_prepare_aggregation_default_no_grouping_by(self): + """Test _prepare_aggregation""" + + # test when query is not grouped by + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC], + start_date='2024-01-01', + end_date='2024-01-01', + group_by=False, + ) + + fields_selected, group_by = BillingBaseTable._prepare_aggregation(query) + # topic field is selected + self.assertEqual('topic', fields_selected) + # group by is switched off + self.assertEqual('', group_by) + + def test_prepare_aggregation_default_group_by_more_columns(self): + """Test _prepare_aggregation""" + + # test when query is grouped by, but column can not be grouped by + # cost can not be grouped by, so it is not present in the result + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC, BillingColumn.COST], + start_date='2024-01-01', + end_date='2024-01-01', + group_by=True, + ) + + fields_selected, group_by = BillingBaseTable._prepare_aggregation(query) + self.assertEqual('topic', fields_selected) + # always group by 
day and any field that can be grouped by + self.assertEqual('GROUP BY day,topic', group_by) + + def test_execute_query_results_as_list(self): + """Test _execute_query""" + + # we are not running SQL against real BQ, just a mocking, so we can use any query + sql_query = 'SELECT 1;' + sql_params: list[Any] = [] + + # test results_as_list=True + given_bq_results = [[], [123], ['a', 'b', 'c']] + for bq_result in given_bq_results: + self.bq_result.result.return_value = bq_result + results = self.table_obj._execute_query( + sql_query, sql_params, results_as_list=True + ) + self.assertEqual(bq_result, results) + + def test_execute_query_results_not_as_list(self): + """Test _execute_query""" + + # we are not running SQL against real BQ, just a mocking, so we can use any query + sql_query = 'SELECT 1;' + sql_params: list[Any] = [] + + # now test results_as_list=False + given_bq_results = [[], [123], ['a', 'b', 'c']] + for bq_result in given_bq_results: + # mock BigQuery result + self.bq_client.query.return_value = bq_result + results = self.table_obj._execute_query( + sql_query, sql_params, results_as_list=False + ) + self.assertEqual(bq_result, results) + + def test_execute_query_with_sql_params(self): + """Test _execute_query""" + + # now test results_as_list=False and with some dummy params + sql_query = 'SELECT 1;' + sql_params = [ + bq.ScalarQueryParameter('dummy_not_used', 'STRING', '2021-01-01 00:00:00') + ] + + given_bq_results = [[], [123], ['a', 'b', 'c']] + for bq_result in given_bq_results: + # mock BigQuery result + self.bq_client.query.return_value = bq_result + results = self.table_obj._execute_query( + sql_query, sql_params, results_as_list=False + ) + self.assertEqual(bq_result, results) + + @run_as_sync + async def test_append_total_running_cost_no_topic(self): + """Test _append_total_running_cost""" + + # test _append_total_running_cost function, no topic present + total_record = await BillingBaseTable._append_total_running_cost( + 
field=BillingColumn.TOPIC, + is_current_month=True, + last_loaded_day=None, + total_monthly={'C': {'ALL': 1000}, 'S': {'ALL': 2000}}, + total_daily={'C': {'ALL': 100}, 'S': {'ALL': 200}}, + total_monthly_category={}, + total_daily_category={}, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=3000.0, + total_daily=300.0, + compute_monthly=1000.0, + compute_daily=100.0, + storage_monthly=2000.0, + storage_daily=200.0, + details=[], + budget_spent=None, + budget=None, + last_loaded_day=None, + ) + ], + total_record, + ) + + @run_as_sync + async def test_append_total_running_cost_not_current_month(self): + """Test _append_total_running_cost""" + + # test _append_total_running_cost function, not current month + total_record = await BillingBaseTable._append_total_running_cost( + field=BillingColumn.TOPIC, + is_current_month=False, + last_loaded_day=None, + total_monthly={'C': {'ALL': 1000}, 'S': {'ALL': 2000}}, + total_daily=None, + total_monthly_category={}, + total_daily_category={}, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=3000.0, + total_daily=None, + compute_monthly=1000.0, + compute_daily=None, + storage_monthly=2000.0, + storage_daily=None, + details=[], + budget_spent=None, + budget=None, + last_loaded_day=None, + ) + ], + total_record, + ) + + @run_as_sync + async def test_append_total_running_cost_current_month(self): + """Test _append_total_running_cost""" + + total_record = await BillingBaseTable._append_total_running_cost( + field=BillingColumn.TOPIC, + is_current_month=True, + last_loaded_day=None, + total_monthly={'C': {'ALL': 1000}, 'S': {'ALL': 2000}}, + total_daily={'C': {'ALL': 100}, 'S': {'ALL': 200}}, + total_monthly_category={ + 'Compute Engine': 900, + 'Cloud Storage': 2000, + 'Other': 100, + }, + total_daily_category={ + 'Compute Engine': 90, + 'Cloud Storage': 200, + 'Other': 10, + }, + results=[], + ) + + 
self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=3000.0, + total_daily=300.0, + compute_monthly=1000.0, + compute_daily=100.0, + storage_monthly=2000.0, + storage_daily=200.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=90.0, + monthly_cost=900.0, + ), + BillingCostDetailsRecord( + cost_group='S', + cost_category='Cloud Storage', + daily_cost=200.0, + monthly_cost=2000.0, + ), + BillingCostDetailsRecord( + cost_group='C', + cost_category='Other', + daily_cost=10.0, + monthly_cost=100.0, + ), + ], + budget_spent=None, + budget=None, + last_loaded_day=None, + ) + ], + total_record, + ) + + @run_as_sync + async def test_budgets_by_gcp_project_empty_results(self): + """Test _budgets_by_gcp_project""" + + # Only GCP_PROJECT and current month has budget + empty_result = await self.table_obj._budgets_by_gcp_project( + BillingColumn.TOPIC, False + ) + self.assertEqual({}, empty_result) + + # GCP_PROJECT and current month, but BQ mockup setup as empty values + empty_result = await self.table_obj._budgets_by_gcp_project( + BillingColumn.GCP_PROJECT, True + ) + self.assertEqual({}, empty_result) + + @run_as_sync + async def test_budgets_by_gcp_project_with_results(self): + """Test _budgets_by_gcp_project""" + + # GCP_PROJECT and current month and Mockup set as 2 records + self.bq_result.result.return_value = [ + mock.MagicMock(spec=bq.Row, gcp_project='Project1', budget=1000.0), + mock.MagicMock(spec=bq.Row, gcp_project='Project2', budget=2000.0), + ] + + non_empty_result = await self.table_obj._budgets_by_gcp_project( + BillingColumn.GCP_PROJECT, True + ) + self.assertDictEqual({'Project1': 1000.0, 'Project2': 2000.0}, non_empty_result) + + @run_as_sync + async def test_execute_running_cost_query_invalid_months(self): + """Test _execute_running_cost_query""" + + # test invalid inputs + with self.assertRaises(ValueError) as context: + await 
self.table_obj._execute_running_cost_query(BillingColumn.TOPIC, None) + + self.assertTrue('Invalid invoice month' in str(context.exception)) + + with self.assertRaises(ValueError) as context: + await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, '12345678' + ) + + self.assertTrue('Invalid invoice month' in str(context.exception)) + + with self.assertRaises(ValueError) as context: + await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, '1024AA' + ) + self.assertTrue('Invalid invoice month' in str(context.exception)) + + @run_as_sync + async def test_execute_running_cost_query_empty_results_old_month(self): + """Test _execute_running_cost_query""" + + # no mocked BQ results, should return as empty + ( + is_current_month, + last_loaded_day, + query_job_result, + ) = await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, '202101' + ) + + self.assertEqual(False, is_current_month) + self.assertEqual(None, last_loaded_day) + self.assertEqual([], query_job_result) + + @run_as_sync + async def test_execute_running_cost_query_empty_results_current_month(self): + """Test _execute_running_cost_query""" + + # no mocked BQ results, should return as empty + # use current month to test the current month branch + current_month_as_string = datetime.now().strftime('%Y%m') + ( + is_current_month, + last_loaded_day, + query_job_result, + ) = await self.table_obj._execute_running_cost_query( + BillingColumn.TOPIC, current_month_as_string + ) + + self.assertEqual(True, is_current_month) + self.assertEqual(None, last_loaded_day) + self.assertEqual([], query_job_result) + + @run_as_sync + async def test_append_running_cost_records_empty_results(self): + """Test _append_running_cost_records""" + + # test empty results + empty_results = await self.table_obj._append_running_cost_records( + field=BillingColumn.TOPIC, + is_current_month=False, + last_loaded_day=None, + total_monthly={}, + total_daily={}, + field_details={}, + results=[], 
+ ) + + self.assertEqual([], empty_results) + + @run_as_sync + async def test_append_running_cost_records_simple_data(self): + """Test _append_running_cost_records""" + + # prepare simple input data + field_details: dict[str, Any] = { + 'Project1': [], + } + + simple_result = await self.table_obj._append_running_cost_records( + field=BillingColumn.GCP_PROJECT, + is_current_month=False, + last_loaded_day=None, + total_monthly={'C': {}, 'S': {}}, + total_daily={'C': {}, 'S': {}}, + field_details=field_details, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='Project1', + total_monthly=0.0, + compute_monthly=0.0, + compute_daily=0.0, + storage_monthly=0.0, + storage_daily=0.0, + details=[], + last_loaded_day=None, + total_daily=None, + budget_spent=None, + budget=None, + ) + ], + simple_result, + ) + + @run_as_sync + async def test_append_running_cost_records_with_details(self): + """Test _append_running_cost_records""" + + # prepare input data with more details + field_details = { + 'Project2': [ + { + 'cost_group': 'C', + 'cost_category': 'Compute Engine', + 'daily_cost': 90.0, + 'monthly_cost': 900.0, + } + ], + } + + detailed_result = await self.table_obj._append_running_cost_records( + field=BillingColumn.GCP_PROJECT, + is_current_month=False, + last_loaded_day=None, + total_monthly={'C': {}, 'S': {}}, + total_daily={'C': {}, 'S': {}}, + field_details=field_details, + results=[], + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='Project2', + total_monthly=0.0, + compute_monthly=0.0, + compute_daily=0.0, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=90.0, + monthly_cost=900.0, + ) + ], + last_loaded_day=None, + total_daily=None, + budget_spent=None, + budget=None, + ) + ], + detailed_result, + ) + + @run_as_sync + async def test_get_running_cost_invalid_input(self): + """Test get_running_cost""" + + # test invalid 
outputs + with self.assertRaises(ValueError) as context: + await self.table_obj.get_running_cost( + # not allowed field + field=BillingColumn.SKU, + invoice_month=None, + ) + + self.assertTrue( + ( + 'Invalid field only topic, dataset, gcp-project, compute_category, ' + 'wdl_task_name, cromwell_sub_workflow_name & namespace are allowed' + ) + in str(context.exception) + ) + + @run_as_sync + async def test_get_running_cost_empty_results(self): + """Test get_running_cost""" + + # test empty cost (no BQ mockup data provided) + empty_results = await self.table_obj.get_running_cost( + field=BillingColumn.TOPIC, + invoice_month='202301', + ) + + self.assertEqual([], empty_results) + + @run_as_sync + async def test_get_running_cost_older_month(self): + """Test get_running_cost""" + + # mockup BQ sql query result for _execute_running_cost_query function + self.table_obj._execute_query = mock.MagicMock( + side_effect=mock_execute_query_running_cost + ) + + one_record_result = await self.table_obj.get_running_cost( + field=BillingColumn.TOPIC, + invoice_month='202301', + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=2345.67, + total_daily=None, + compute_monthly=2345.67, + compute_daily=None, + storage_monthly=0.0, + storage_daily=None, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=None, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day=None, + ), + BillingCostBudgetRecord( + field='TOPIC1', + total_monthly=2345.67, + total_daily=None, + compute_monthly=2345.67, + compute_daily=0.0, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=None, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day=None, + ), + ], + one_record_result, + ) + + @run_as_sync + async def test_get_running_cost_current_month(self): + 
"""Test get_running_cost""" + + # mockup BQ sql query result for _execute_running_cost_query function + self.table_obj._execute_query = mock.MagicMock( + side_effect=mock_execute_query_running_cost + ) + + # use the current month to test the current month branch + current_month_as_string = datetime.now().strftime('%Y%m') + + current_month_result = await self.table_obj.get_running_cost( + field=BillingColumn.TOPIC, + invoice_month=current_month_as_string, + ) + + self.assertEqual( + [ + BillingCostBudgetRecord( + field='All Topics', + total_monthly=2345.67, + total_daily=123.45, + compute_monthly=2345.67, + compute_daily=123.45, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=123.45, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day='Jan 01', + ), + BillingCostBudgetRecord( + field='TOPIC1', + total_monthly=2345.67, + total_daily=123.45, + compute_monthly=2345.67, + compute_daily=123.45, + storage_monthly=0.0, + storage_daily=0.0, + details=[ + BillingCostDetailsRecord( + cost_group='C', + cost_category='Compute Engine', + daily_cost=123.45, + monthly_cost=2345.67, + ) + ], + budget_spent=None, + budget=None, + last_loaded_day='Jan 01', + ), + ], + current_month_result, + ) + + @run_as_sync + async def test_get_total_cost(self): + """Test get_total_cost""" + + # test invalid input + query = BillingTotalCostQueryModel( + fields=[], start_date='2023-01-01', end_date='2024-01-01' + ) + + with self.assertRaises(ValueError) as context: + await self.table_obj.get_total_cost(query) + + self.assertTrue('Date and Fields are required' in str(context.exception)) + + # test empty results + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC], + start_date='2023-01-01', + end_date='2024-01-01', + time_periods=BillingTimePeriods.INVOICE_MONTH, + ) + + # no BQ mockup data setup, returns empty list + empty_results = await 
self.table_obj.get_total_cost(query) + self.assertEqual([], empty_results) + + # mockup BQ sql query result for _execute_query to return 3 records. + # implementation is inside mock_execute_query function + self.table_obj._execute_query = mock.MagicMock( + side_effect=mock_execute_query_get_total_cost + ) + results = await self.table_obj.get_total_cost(query) + self.assertEqual( + [ + {'day': '202301', 'topic': 'TOPIC1', 'cost': 123.1}, + {'day': '202302', 'topic': 'TOPIC1', 'cost': 223.2}, + {'day': '202303', 'topic': 'TOPIC1', 'cost': 323.3}, + ], + results, + ) diff --git a/test/test_bq_billing_daily.py b/test/test_bq_billing_daily.py new file mode 100644 index 000000000..969b7c936 --- /dev/null +++ b/test/test_bq_billing_daily.py @@ -0,0 +1,255 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_daily import BillingDailyTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingDailyTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingDailyTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = 
BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingDailyTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingDailyTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingDailyTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_last_loaded_day_return_valid_day(self): + """Test _last_loaded_day""" + + given_last_day = '2021-01-01 00:00:00' + + # mock BigQuery result + + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + last_loaded_day=datetime.datetime.strptime( + given_last_day, '%Y-%m-%d %H:%M:%S' + ), + ) + ] # 2021-01-01 + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(given_last_day, last_loaded_day) + + @run_as_sync + async def test_last_loaded_day_return_none(self): + """Test _last_loaded_day as None""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(None, last_loaded_day) + + def test_prepare_daily_cost_subquery(self): + """Test _prepare_daily_cost_subquery""" + + self.table_obj.table_name = 'TEST_TABLE_NAME' + + # given + given_field = BillingColumn.COST + 
given_query_params: list[Any] = [] + given_last_loaded_day = '2021-01-01 00:00:00' + + ( + query_params, + daily_cost_field, + daily_cost_join, + ) = self.table_obj._prepare_daily_cost_subquery( + given_field, given_query_params, given_last_loaded_day + ) + + # expected + expected_daily_cost_join = """LEFT JOIN ( + SELECT + cost as field, + cost_category, + SUM(cost) as cost + FROM + `TEST_TABLE_NAME` + WHERE day = TIMESTAMP(@last_loaded_day) + GROUP BY + field, + cost_category + ) day + ON month.field = day.field + AND month.cost_category = day.cost_category + """ + + self.assertEqual( + [ + bq.ScalarQueryParameter( + 'last_loaded_day', 'STRING', '2021-01-01 00:00:00' + ) + ], + query_params, + ) + + self.assertEqual(', day.cost as daily_cost', daily_cost_field) + self.assertEqual(expected_daily_cost_join, daily_cost_join) + + @run_as_sync + async def test_get_entities_as_empty_list(self): + """ + Test get_topics, get_invoice_months, + get_cost_categories and get_skus as empty list + """ + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get_topics function + records = await self.table_obj.get_topics() + self.assertEqual([], records) + + # test get_invoice_months function + records = await self.table_obj.get_invoice_months() + self.assertEqual([], records) + + # test get_cost_categories function + records = await self.table_obj.get_cost_categories() + self.assertEqual([], records) + + # test get_skus function + records = await self.table_obj.get_skus() + self.assertEqual([], records) + + @run_as_sync + async def test_get_topics_return_valid_list(self): + """Test get_topics as empty list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'topic': 'TOPIC1'}, + {'topic': 'TOPIC2'}, + ] + + # test get_topics function + records = await self.table_obj.get_topics() + + self.assertEqual(['TOPIC1', 'TOPIC2'], records) + + @run_as_sync + async def 
test_get_invoice_months_return_valid_list(self): + """Test get_invoice_months as valid list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'invoice_month': '202401'}, + {'invoice_month': '202402'}, + ] + + # test get_invoice_months function + records = await self.table_obj.get_invoice_months() + + self.assertEqual(['202401', '202402'], records) + + @run_as_sync + async def test_get_cost_categories_return_valid_list(self): + """Test get_cost_categories as valid list""" + + # mock BigQuery result as list of 1 record + self.bq_result.result.return_value = [ + {'cost_category': 'CAT1'}, + ] + + # test get_cost_categories function + records = await self.table_obj.get_cost_categories() + + self.assertEqual(['CAT1'], records) + + @run_as_sync + async def test_get_skus_return_valid_list(self): + """Test get_skus as valid list""" + + # mock BigQuery result as list of 3 records + self.bq_result.result.return_value = [ + {'sku': 'SKU1'}, + {'sku': 'SKU2'}, + {'sku': 'SKU3'}, + ] + + # test get_skus function + records = await self.table_obj.get_skus() + self.assertEqual(['SKU1', 'SKU2', 'SKU3'], records) + + # test get_skus function with limit, + # limit is ignored in the test as we already have mockup data + records = await self.table_obj.get_skus(limit=3) + self.assertEqual(['SKU1', 'SKU2', 'SKU3'], records) + + # test get_skus function with limit & offset + # limit & offset are ignored in the test as we already have mockup data + records = await self.table_obj.get_skus(limit=3, offset=1) + self.assertEqual(['SKU1', 'SKU2', 'SKU3'], records) diff --git a/test/test_bq_billing_daily_extended.py b/test/test_bq_billing_daily_extended.py new file mode 100644 index 000000000..183d6c8c5 --- /dev/null +++ b/test/test_bq_billing_daily_extended.py @@ -0,0 +1,116 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from typing import Any + +from 
db.python.tables.bq.billing_daily_extended import BillingDailyExtendedTable +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingGcpDailyTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingDailyExtendedTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingDailyExtendedTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingDailyExtendedTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingDailyExtendedTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + 
self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_get_extended_values_return_empty_list(self): + """Test get_extended_values as empty list""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + records = await self.table_obj.get_extended_values('dataset') + + self.assertEqual([], records) + + @run_as_sync + async def test_get_extended_values_return_valid_list(self): + """Test get_extended_values as list of 2 records""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'dataset': 'DATA1'}, + {'dataset': 'DATA2'}, + ] + + # test get table name function + records = await self.table_obj.get_extended_values('dataset') + + self.assertEqual(['DATA1', 'DATA2'], records) + + @run_as_sync + async def test_get_extended_values_error(self): + """Test get_extended_values return exception as invalid ext column""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + with self.assertRaises(ValueError) as context: + await self.table_obj.get_extended_values('rubish') + + self.assertTrue('Invalid field value' in str(context.exception)) diff --git a/test/test_bq_billing_gcp_daily.py b/test/test_bq_billing_gcp_daily.py new file mode 100644 index 000000000..99e3caa10 --- /dev/null +++ b/test/test_bq_billing_gcp_daily.py @@ -0,0 +1,197 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from textwrap import dedent +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.billing_gcp_daily import BillingGcpDailyTable +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, 
BillingTotalCostQueryModel + + +class TestBillingGcpDailyTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingGcpDailyTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + day=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + part_time=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 8, 0, 0), # 7 days added + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingGcpDailyTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingGcpDailyTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingGcpDailyTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_TABLE_NAME' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) + + @run_as_sync + async def test_last_loaded_day_return_valid_day(self): + """Test _last_loaded_day""" + + given_last_day = '2021-01-01 00:00:00' + + # 
mock BigQuery result + + self.bq_result.result.return_value = [ + mock.MagicMock( + spec=bq.Row, + last_loaded_day=datetime.datetime.strptime( + given_last_day, '%Y-%m-%d %H:%M:%S' + ), + ) + ] # 2021-01-01 + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(given_last_day, last_loaded_day) + + @run_as_sync + async def test_last_loaded_day_return_none(self): + """Test _last_loaded_day as None""" + + # mock BigQuery result as empty list + self.bq_result.result.return_value = [] + + # test get table name function + last_loaded_day = await self.table_obj._last_loaded_day() + + self.assertEqual(None, last_loaded_day) + + def test_prepare_daily_cost_subquery(self): + """Test _prepare_daily_cost_subquery""" + + self.table_obj.table_name = 'TEST_TABLE_NAME' + + # given + given_field = BillingColumn.COST + given_query_params: list[Any] = [] + given_last_loaded_day = '2021-01-01 00:00:00' + + ( + query_params, + daily_cost_field, + daily_cost_join, + ) = self.table_obj._prepare_daily_cost_subquery( + given_field, given_query_params, given_last_loaded_day + ) + + # expected + expected_daily_cost_join = """LEFT JOIN ( + SELECT + cost as field, + cost_category, + SUM(cost) as cost + FROM + `TEST_TABLE_NAME` + WHERE day = TIMESTAMP(@last_loaded_day) + + AND part_time >= TIMESTAMP(@last_loaded_day) + AND part_time <= TIMESTAMP_ADD( + TIMESTAMP(@last_loaded_day), INTERVAL 7 DAY + ) + + GROUP BY + field, + cost_category + ) day + ON month.field = day.field + AND month.cost_category = day.cost_category + """ + + self.assertEqual( + [ + bq.ScalarQueryParameter( + 'last_loaded_day', 'STRING', '2021-01-01 00:00:00' + ) + ], + query_params, + ) + self.assertEqual(', day.cost as daily_cost', daily_cost_field) + self.assertEqual(dedent(expected_daily_cost_join), dedent(daily_cost_join)) + + @run_as_sync + async def test_get_gcp_projects_return_empty_list(self): + """Test get_gcp_projects as empty list""" + + # mock BigQuery 
result as empty list + self.bq_result.result.return_value = [] + + # test get_gcp_projects function + gcp_projects = await self.table_obj.get_gcp_projects() + + self.assertEqual([], gcp_projects) + + @run_as_sync + async def test_get_gcp_projects_return_valid_list(self): + """Test get_gcp_projects as valid list""" + + # mock BigQuery result as list of 2 records + self.bq_result.result.return_value = [ + {'gcp_project': 'PROJECT1'}, + {'gcp_project': 'PROJECT2'}, + ] + + # test get_gcp_projects function + gcp_projects = await self.table_obj.get_gcp_projects() + + self.assertEqual(['PROJECT1', 'PROJECT2'], gcp_projects) diff --git a/test/test_bq_billing_raw.py b/test/test_bq_billing_raw.py new file mode 100644 index 000000000..e1a5af50d --- /dev/null +++ b/test/test_bq_billing_raw.py @@ -0,0 +1,75 @@ +# pylint: disable=protected-access +import datetime +from test.testbqbase import BqTest +from typing import Any + +from db.python.tables.bq.billing_filter import BillingFilter +from db.python.tables.bq.billing_raw import BillingRawTable +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.utils import InternalError +from models.models import BillingColumn, BillingTotalCostQueryModel + + +class TestBillingRawTable(BqTest): + """Test BillingRawTable and its methods""" + + def setUp(self): + super().setUp() + + # setup table object + self.table_obj = BillingRawTable(self.connection) + + def test_query_to_partitioned_filter(self): + """Test query to partitioned filter conversion""" + + # given + start_date = '2023-01-01' + end_date = '2024-01-01' + filters: dict[BillingColumn, str | list[Any] | dict[Any, Any]] = { + BillingColumn.TOPIC: 'TEST_TOPIC' + } + + # expected + expected_filter = BillingFilter( + usage_end_time=GenericBQFilter( + gte=datetime.datetime(2023, 1, 1, 0, 0), + lte=datetime.datetime(2024, 1, 1, 0, 0), + ), + topic=GenericBQFilter(eq='TEST_TOPIC'), + ) + + query = BillingTotalCostQueryModel( + fields=[], # not relevant for this 
test, but can't be null generally + start_date=start_date, + end_date=end_date, + filters=filters, + ) + filter_ = BillingRawTable._query_to_partitioned_filter(query) + + # BillingFilter has __eq__ method, so we can compare them directly + self.assertEqual(expected_filter, filter_) + + def test_error_no_connection(self): + """Test No connection exception""" + + with self.assertRaises(InternalError) as context: + BillingRawTable(None) + + self.assertTrue( + 'No connection was provided to the table \'BillingRawTable\'' + in str(context.exception) + ) + + def test_get_table_name(self): + """Test get_table_name""" + + # table name is set in the class + given_table_name = 'TEST_BQ_AGGREG_RAW' + + # set table name + self.table_obj.table_name = given_table_name + + # test get table name function + table_name = self.table_obj.get_table_name() + + self.assertEqual(given_table_name, table_name) diff --git a/test/test_bq_function_filter.py b/test/test_bq_function_filter.py new file mode 100644 index 000000000..8c84f9af9 --- /dev/null +++ b/test/test_bq_function_filter.py @@ -0,0 +1,155 @@ +import unittest +from datetime import datetime +from enum import Enum + +import google.cloud.bigquery as bq +import pytz + +from db.python.tables.bq.function_bq_filter import FunctionBQFilter +from models.models import BillingColumn + + +class BGFunFilterTestEnum(str, Enum): + """Simple Enum class""" + + ID = 'id' + VALUE = 'value' + + +class TestFunctionBQFilter(unittest.TestCase): + """Test function filters SQL generation""" + + def test_empty_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql(BillingColumn.LABELS) + + self.assertEqual([], filter_.func_sql_parameters) + self.assertEqual('', filter_.func_where) + + def test_one_param_output(self): + 
"""Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql(BillingColumn.LABELS, {'label_name': 'Some Value'}) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name'), + bq.ScalarQueryParameter('value1', 'STRING', 'Some Value'), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + '(test_sql_func(labels,@param1) = @value1)', filter_.func_where + ) + + def test_two_param_default_operator_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql(BillingColumn.LABELS, {'label_name1': 10, 'label_name2': 123.45}) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name1'), + bq.ScalarQueryParameter('value1', 'INT64', 10), + bq.ScalarQueryParameter('param2', 'STRING', 'label_name2'), + bq.ScalarQueryParameter('value2', 'FLOAT64', 123.45), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + ( + '(test_sql_func(labels,@param1) = @value1 ' + 'AND test_sql_func(labels,@param2) = @value2)' + ), + filter_.func_where, + ) + + def test_two_param_operator_or_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + filter_.to_sql( + BillingColumn.LABELS, + { + 'label_name1': BGFunFilterTestEnum.ID, + 'label_name2': datetime(2024, 1, 1), + }, + 'OR', + ) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name1'), + 
bq.ScalarQueryParameter('value1', 'STRING', 'id'), + bq.ScalarQueryParameter('param2', 'STRING', 'label_name2'), + bq.ScalarQueryParameter('value2', 'STRING', '2024-01-01T00:00:00'), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + ( + '(test_sql_func(labels,@param1) = @value1 ' + 'OR test_sql_func(labels,@param2) = TIMESTAMP(@value2))' + ), + filter_.func_where, + ) + + def test_two_param_datime_with_zone_output(self): + """Test that function filter converts to SQL as expected""" + filter_ = FunctionBQFilter( + name='test_sql_func', + implementation='CREATE TEMP FUNCTION test_sql_func() AS (SELECT 1);', + ) + + # no params are present, should return empty SQL and Param list + NYC = pytz.timezone('America/New_York') + SYD = pytz.timezone('Australia/Sydney') + filter_.to_sql( + BillingColumn.LABELS, + { + 'label_name1': datetime(2024, 1, 1, tzinfo=pytz.UTC).astimezone(NYC), + 'label_name2': datetime(2024, 1, 1, tzinfo=pytz.UTC).astimezone(SYD), + }, + 'AND', + ) + + self.assertEqual( + [ + bq.ScalarQueryParameter('param1', 'STRING', 'label_name1'), + bq.ScalarQueryParameter( + 'value1', 'STRING', '2023-12-31T19:00:00-05:00' + ), + bq.ScalarQueryParameter('param2', 'STRING', 'label_name2'), + bq.ScalarQueryParameter( + 'value2', 'STRING', '2024-01-01T11:00:00+11:00' + ), + ], + filter_.func_sql_parameters, + ) + self.assertEqual( + ( + '(test_sql_func(labels,@param1) = TIMESTAMP(@value1) ' + 'AND test_sql_func(labels,@param2) = TIMESTAMP(@value2))' + ), + filter_.func_where, + ) diff --git a/test/test_bq_generic_filter.py b/test/test_bq_generic_filter.py new file mode 100644 index 000000000..5e5600d1f --- /dev/null +++ b/test/test_bq_generic_filter.py @@ -0,0 +1,270 @@ +import dataclasses +import unittest +from datetime import datetime +from enum import Enum +from typing import Any + +from google.cloud import bigquery + +from db.python.tables.bq.generic_bq_filter import GenericBQFilter +from db.python.tables.bq.generic_bq_filter_model import 
GenericBQFilterModel + + +@dataclasses.dataclass(kw_only=True) +class GenericBQFilterTest(GenericBQFilterModel): + """Test model for GenericBQFilter""" + + test_string: GenericBQFilter[str] | None = None + test_int: GenericBQFilter[int] | None = None + test_float: GenericBQFilter[float] | None = None + test_dt: GenericBQFilter[datetime] | None = None + test_dict: dict[str, GenericBQFilter[str]] | None = None + test_enum: GenericBQFilter[Enum] | None = None + test_any: Any | None = None + + +class BGFilterTestEnum(str, Enum): + """Simple Enum class""" + + ID = 'id' + VALUE = 'value' + + +class TestGenericBQFilter(unittest.TestCase): + """Test generic filter SQL generation""" + + def test_basic_no_override(self): + """Test that the basic filter converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_eq', sql) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_basic_override(self): + """Test that the basic filter with an override converts to SQL as expected""" + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(eq='test')) + sql, values = filter_.to_sql({'test_string': 't.string'}) + + self.assertEqual('t.string = @t_string_eq', sql) + self.assertDictEqual( + { + 't_string_eq': bigquery.ScalarQueryParameter( + 't_string_eq', 'STRING', 'test' + ) + }, + values, + ) + + def test_single_string(self): + """ + Test that a single value filtered using the "in" operator + gets converted to an eq operation + """ + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=['test'])) + sql, values = filter_.to_sql() + + self.assertEqual('test_string = @test_string_in_eq', sql) + self.assertDictEqual( + { + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'test' + ) + }, + values, + ) + + def 
test_single_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123 + filter_ = GenericBQFilterTest(test_int=GenericBQFilter(gt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > @test_int_gt', sql) + self.assertDictEqual( + { + 'test_int_gt': bigquery.ScalarQueryParameter( + 'test_int_gt', 'INT64', value + ) + }, + values, + ) + + def test_single_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = 123.456 + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(gte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float >= @test_float_gte', sql) + self.assertDictEqual( + { + 'test_float_gte': bigquery.ScalarQueryParameter( + 'test_float_gte', 'FLOAT64', value + ) + }, + values, + ) + + def test_single_datetime(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + datetime_str = '2021-10-08 01:02:03' + value = datetime.strptime(datetime_str, '%Y-%m-%d %H:%M:%S') + filter_ = GenericBQFilterTest(test_dt=GenericBQFilter(lt=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_dt < TIMESTAMP(@test_dt_lt)', sql) + self.assertDictEqual( + { + 'test_dt_lt': bigquery.ScalarQueryParameter( + 'test_dt_lt', 'STRING', datetime_str + ) + }, + values, + ) + + def test_single_enum(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = BGFilterTestEnum.ID + filter_ = GenericBQFilterTest(test_enum=GenericBQFilter(lte=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_enum <= @test_enum_lte', sql) + self.assertDictEqual( + { + 'test_enum_lte': bigquery.ScalarQueryParameter( + 'test_enum_lte', 'STRING', value.value + ) + }, + values, + ) + + def test_in_multiple_int(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1, 2] + filter_ = 
GenericBQFilterTest(test_int=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int IN UNNEST(@test_int_in)', sql) + self.assertDictEqual( + { + 'test_int_in': bigquery.ArrayQueryParameter( + 'test_int_in', 'INT64', value + ) + }, + values, + ) + + def test_in_multiple_float(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = [1.0, 2.0] + filter_ = GenericBQFilterTest(test_float=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_float IN UNNEST(@test_float_in)', sql) + self.assertDictEqual( + { + 'test_float_in': bigquery.ArrayQueryParameter( + 'test_float_in', 'FLOAT64', value + ) + }, + values, + ) + + def test_in_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string IN UNNEST(@test_string_in)', sql) + self.assertDictEqual( + { + 'test_string_in': bigquery.ArrayQueryParameter( + 'test_string_in', 'STRING', value + ) + }, + values, + ) + + def test_nin_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A', 'B'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_string NOT IN UNNEST(@test_string_nin)', sql) + self.assertDictEqual( + { + 'test_string_nin': bigquery.ArrayQueryParameter( + 'test_string_nin', 'STRING', value + ) + }, + values, + ) + + def test_in_and_eq_multiple_str(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = ['A'] + filter_ = GenericBQFilterTest(test_string=GenericBQFilter(in_=value, eq='B')) + sql, values = filter_.to_sql() + + self.assertEqual( + 'test_string = @test_string_eq AND test_string = @test_string_in_eq', + sql, + 
) + self.assertDictEqual( + { + 'test_string_eq': bigquery.ScalarQueryParameter( + 'test_string_eq', 'STRING', 'B' + ), + 'test_string_in_eq': bigquery.ScalarQueryParameter( + 'test_string_in_eq', 'STRING', 'A' + ), + }, + values, + ) + + def test_fail_none_in_tuple(self): + """ + Test that values filtered using the "in" operator convert as expected + """ + value = (None,) + + # check if ValueError is raised + with self.assertRaises(ValueError) as context: + filter_ = GenericBQFilterTest(test_any=value) + filter_.to_sql() + + self.assertTrue( + 'There is very likely a trailing comma on the end of ' + 'GenericBQFilterTest.test_any. ' + 'If you actually want a tuple of length one with the value = (None,), ' + 'then use dataclasses.field(default_factory=lambda: (None,))' + in str(context.exception) + ) diff --git a/test/test_generic_filters.py b/test/test_generic_filters.py index 2c1348076..b5598be54 100644 --- a/test/test_generic_filters.py +++ b/test/test_generic_filters.py @@ -53,3 +53,54 @@ def test_in_multiple(self): self.assertEqual('test_int IN :test_int_in', sql) self.assertDictEqual({'test_int_in': value}, values) + + def test_gt_single(self): + """ + Test that a single value filtered using the "gt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int > :test_int_gt', sql) + self.assertDictEqual({'test_int_gt': 123}, values) + + def test_gte_single(self): + """ + Test that a single value filtered using the "gte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(gte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int >= :test_int_gte', sql) + self.assertDictEqual({'test_int_gte': 123}, values) + + def test_lt_single(self): + """ + Test that a single value filtered using the "lt" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lt=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int < :test_int_lt', sql) 
+ self.assertDictEqual({'test_int_lt': 123}, values) + + def test_lte_single(self): + """ + Test that a single value filtered using the "lte" operator + """ + filter_ = GenericFilterTest(test_int=GenericFilter(lte=123)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int <= :test_int_lte', sql) + self.assertDictEqual({'test_int_lte': 123}, values) + + def test_not_in_multiple(self): + """ + Test that values filtered using the "nin" operator convert as expected + """ + value = [1, 2] + filter_ = GenericFilterTest(test_int=GenericFilter(nin=value)) + sql, values = filter_.to_sql() + + self.assertEqual('test_int NOT IN :test_int_nin', sql) + self.assertDictEqual({'test_int_nin': value}, values) diff --git a/test/test_layers_billing.py b/test/test_layers_billing.py new file mode 100644 index 000000000..be13f3fa3 --- /dev/null +++ b/test/test_layers_billing.py @@ -0,0 +1,470 @@ +# pylint: disable=protected-access +import datetime +from test.testbase import run_as_sync +from test.testbqbase import BqTest +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.layers.billing import BillingLayer +from models.enums import BillingSource +from models.models import ( + BillingColumn, + BillingHailBatchCostRecord, + BillingTotalCostQueryModel, +) + + +class TestBillingLayer(BqTest): + """Test BillingLayer and its methods""" + + def test_table_factory(self): + """Test table_factory""" + + layer = BillingLayer(self.connection) + + # test BillingSource types + table_obj = layer.table_factory() + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory(source=BillingSource.GCP_BILLING) + self.assertEqual('BillingGcpDailyTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory(source=BillingSource.RAW) + self.assertEqual('BillingRawTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory(source=BillingSource.AGGREGATE) + self.assertEqual('BillingDailyTable', 
table_obj.__class__.__name__) + + # base columns + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, fields=[BillingColumn.TOPIC] + ) + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, filters={BillingColumn.TOPIC: 'TOPIC1'} + ) + self.assertEqual('BillingDailyTable', table_obj.__class__.__name__) + + # columns from extended view + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, fields=[BillingColumn.AR_GUID] + ) + self.assertEqual('BillingDailyExtendedTable', table_obj.__class__.__name__) + + table_obj = layer.table_factory( + source=BillingSource.AGGREGATE, filters={BillingColumn.AR_GUID: 'AR_GUID1'} + ) + self.assertEqual('BillingDailyExtendedTable', table_obj.__class__.__name__) + + @run_as_sync + async def test_get_gcp_projects(self): + """Test get_gcp_projects""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_gcp_projects() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'gcp_project': 'PROJECT1'}, + {'gcp_project': 'PROJECT2'}, + ] + + records = await layer.get_gcp_projects() + self.assertEqual(['PROJECT1', 'PROJECT2'], records) + + @run_as_sync + async def test_get_topics(self): + """Test get_topics""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_topics() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'topic': 'TOPIC1'}, + {'topic': 'TOPIC2'}, + ] + + records = await layer.get_topics() + self.assertEqual(['TOPIC1', 'TOPIC2'], records) + + @run_as_sync + async def test_get_cost_categories(self): + """Test get_cost_categories""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_cost_categories() + self.assertEqual([], 
records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'cost_category': 'CAT1'}, + {'cost_category': 'CAT2'}, + ] + + records = await layer.get_cost_categories() + self.assertEqual(['CAT1', 'CAT2'], records) + + @run_as_sync + async def test_get_skus(self): + """Test get_skus""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_skus() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'sku': 'SKU1'}, + {'sku': 'SKU2'}, + ] + + records = await layer.get_skus() + self.assertEqual(['SKU1', 'SKU2'], records) + + @run_as_sync + async def test_get_datasets(self): + """Test get_datasets""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_datasets() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'dataset': 'DATA1'}, + {'dataset': 'DATA2'}, + ] + + records = await layer.get_datasets() + self.assertEqual(['DATA1', 'DATA2'], records) + + @run_as_sync + async def test_get_stages(self): + """Test get_stages""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_stages() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'stage': 'STAGE1'}, + {'stage': 'STAGE2'}, + ] + + records = await layer.get_stages() + self.assertEqual(['STAGE1', 'STAGE2'], records) + + @run_as_sync + async def test_get_sequencing_types(self): + """Test get_sequencing_types""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_sequencing_types() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'sequencing_type': 'SEQ1'}, + {'sequencing_type': 'SEQ2'}, + ] + + records = await layer.get_sequencing_types() + 
self.assertEqual(['SEQ1', 'SEQ2'], records) + + @run_as_sync + async def test_get_sequencing_groups(self): + """Test get_sequencing_groups""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_sequencing_groups() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'sequencing_group': 'GRP1'}, + {'sequencing_group': 'GRP2'}, + ] + + records = await layer.get_sequencing_groups() + self.assertEqual(['GRP1', 'GRP2'], records) + + @run_as_sync + async def test_get_compute_categories(self): + """Test get_compute_categories""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_compute_categories() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'compute_category': 'CAT1'}, + {'compute_category': 'CAT2'}, + ] + + records = await layer.get_compute_categories() + self.assertEqual(['CAT1', 'CAT2'], records) + + @run_as_sync + async def test_get_cromwell_sub_workflow_names(self): + """Test get_cromwell_sub_workflow_names""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_cromwell_sub_workflow_names() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'cromwell_sub_workflow_name': 'CROM1'}, + {'cromwell_sub_workflow_name': 'CROM2'}, + ] + + records = await layer.get_cromwell_sub_workflow_names() + self.assertEqual(['CROM1', 'CROM2'], records) + + @run_as_sync + async def test_get_wdl_task_names(self): + """Test get_wdl_task_names""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_wdl_task_names() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'wdl_task_name': 'WDL1'}, + {'wdl_task_name': 'WDL2'}, + ] + + 
records = await layer.get_wdl_task_names() + self.assertEqual(['WDL1', 'WDL2'], records) + + @run_as_sync + async def test_get_invoice_months(self): + """Test get_invoice_months""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_invoice_months() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'invoice_month': '202301'}, + {'invoice_month': '202302'}, + ] + + records = await layer.get_invoice_months() + self.assertEqual(['202301', '202302'], records) + + @run_as_sync + async def test_get_namespaces(self): + """Test get_namespaces""" + + layer = BillingLayer(self.connection) + + # test with no muckup data, should be empty + records = await layer.get_namespaces() + self.assertEqual([], records) + + # mockup BQ results + self.bq_result.result.return_value = [ + {'namespace': 'NAME1'}, + {'namespace': 'NAME2'}, + ] + + records = await layer.get_namespaces() + self.assertEqual(['NAME1', 'NAME2'], records) + + @run_as_sync + async def test_get_total_cost(self): + """Test get_total_cost""" + + layer = BillingLayer(self.connection) + + # test inparams exceptions: + query = BillingTotalCostQueryModel(fields=[], start_date='', end_date='') + + with self.assertRaises(ValueError) as context: + await layer.get_total_cost(query) + + self.assertTrue('Date and Fields are required' in str(context.exception)) + + # test with no muckup data, should be empty + query = BillingTotalCostQueryModel( + fields=[BillingColumn.TOPIC], start_date='2024-01-01', end_date='2024-01-03' + ) + records = await layer.get_total_cost(query) + self.assertEqual([], records) + + # get_total_cost with mockup data is tested in test/test_bq_billing_base.py + # BillingLayer is just wrapper for BQ tables + + @run_as_sync + async def test_get_running_cost(self): + """Test get_running_cost""" + + layer = BillingLayer(self.connection) + + # test inparams exceptions: + with 
self.assertRaises(ValueError) as context: + await layer.get_running_cost( + field=BillingColumn.TOPIC, invoice_month=None, source=None + ) + self.assertTrue('Invalid invoice month' in str(context.exception)) + + with self.assertRaises(ValueError) as context: + await layer.get_running_cost( + field=BillingColumn.TOPIC, invoice_month='2024', source=None + ) + self.assertTrue('Invalid invoice month' in str(context.exception)) + + # test with no muckup data, should be empty + records = await layer.get_running_cost( + field=BillingColumn.TOPIC, invoice_month='202401', source=None + ) + self.assertEqual([], records) + + # get_running_cost with mockup data is tested in test/test_bq_billing_base.py + # BillingLayer is just wrapper for BQ tables + + @run_as_sync + async def test_get_cost_by_ar_guid(self): + """Test get_cost_by_ar_guid""" + + layer = BillingLayer(self.connection) + + # ar_guid as None, return empty results + records = await layer.get_cost_by_ar_guid(ar_guid=None) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records + ) + + # dummy ar_guid, no mockup data, return empty results + dummy_ar_guid = '12345678' + records = await layer.get_cost_by_ar_guid(ar_guid=dummy_ar_guid) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=dummy_ar_guid, batch_ids=[], costs=[]), + records, + ) + + # dummy ar_guid, mockup batch_id + + # mock BigQuery result + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + dummy_batch_id = '12345' + + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [ + mock.MagicMock( + spec=bq.Row, + batch_id=dummy_batch_id, + start_day=given_start_day, + end_day=given_end_day, + ), + ] + self.bq_result.result.return_value = mock_rows + + records = await layer.get_cost_by_ar_guid(ar_guid=dummy_ar_guid) + # returns ar_guid, batch_id 
and empty cost as those were not mocked up + # we do not need to test cost calculation here, + # as those are tested in test/test_bq_billing_base.py + self.assertEqual( + BillingHailBatchCostRecord( + ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}] + ), + records, + ) + + @run_as_sync + async def test_get_cost_by_batch_id(self): + """Test get_cost_by_batch_id""" + + layer = BillingLayer(self.connection) + + # ar_guid as None, return empty results + records = await layer.get_cost_by_batch_id(batch_id=None) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), records + ) + + # dummy ar_guid, no mockup data, return empty results + dummy_batch_id = '12345' + records = await layer.get_cost_by_batch_id(batch_id=dummy_batch_id) + + # return empty record + self.assertEqual( + BillingHailBatchCostRecord(ar_guid=None, batch_ids=[], costs=[]), + records, + ) + + # dummy batch_id, mockup ar_guid + + # mock BigQuery result + given_start_day = datetime.datetime(2023, 1, 1, 0, 0) + given_end_day = datetime.datetime(2023, 1, 1, 2, 3) + dummy_batch_id = '12345' + dummy_ar_guid = '12345678' + + mock_rows = mock.MagicMock(spec=bq.table.RowIterator) + mock_rows.total_rows = 1 + mock_rows.__iter__.return_value = [ + mock.MagicMock( + spec=bq.Row, + ar_guid=dummy_ar_guid, + batch_id=dummy_batch_id, + start_day=given_start_day, + end_day=given_end_day, + # mockup __getitem__ to return dummy_ar_guid + __getitem__=mock.MagicMock(return_value=dummy_ar_guid), + ), + ] + self.bq_result.result.return_value = mock_rows + + records = await layer.get_cost_by_batch_id(batch_id=dummy_batch_id) + # returns ar_guid, batch_id and empty cost as those were not mocked up + # we do not need to test cost calculation here, + # as those are tested in test/test_bq_billing_base.py + self.assertEqual( + BillingHailBatchCostRecord( + ar_guid=dummy_ar_guid, batch_ids=[dummy_batch_id], costs=[{}] + ), + records, + ) diff --git 
a/test/testbqbase.py b/test/testbqbase.py new file mode 100644 index 000000000..6d25e489a --- /dev/null +++ b/test/testbqbase.py @@ -0,0 +1,44 @@ +import unittest +from typing import Any +from unittest import mock + +import google.cloud.bigquery as bq + +from db.python.gcp_connect import BqConnection +from db.python.layers.billing import BillingLayer + + +class BqTest(unittest.TestCase): + """Base class for Big Query integration tests""" + + # author and grp_project are not used in the BQ tests, but are required + # so some dummy values are preset + author: str = 'Author' + gcp_project: str = 'GCP_PROJECT' + + bq_result: bq.job.QueryJob + bq_client: bq.Client + connection: BqConnection + table_obj: Any | None = None + + def setUp(self) -> None: + super().setUp() + + # Mockup BQ results + self.bq_result = mock.MagicMock(spec=bq.job.QueryJob) + + # mock BigQuery client + self.bq_client = mock.MagicMock(spec=bq.Client) + self.bq_client.query.return_value = self.bq_result + + # Mock BqConnection + self.connection = mock.MagicMock(spec=BqConnection) + self.connection.gcp_project = self.gcp_project + self.connection.connection = self.bq_client + self.connection.author = self.author + + # Mockup BillingLayer + self.layer = BillingLayer(self.connection) + + # overwrite table object in inhereted tests: + self.table_obj = None diff --git a/web/src/Routes.tsx b/web/src/Routes.tsx index a307355e6..ed375b347 100644 --- a/web/src/Routes.tsx +++ b/web/src/Routes.tsx @@ -9,6 +9,7 @@ import { BillingCostByAnalysis, BillingInvoiceMonthCost, BillingCostByCategory, + BillingCostByMonth, } from './pages/billing' import DocumentationArticle from './pages/docs/Documentation' import SampleView from './pages/sample/SampleView' @@ -49,6 +50,14 @@ const Routes: React.FunctionComponent = () => ( /> } /> + + + + } + /> } /> { + const [searchParams] = useSearchParams() + + const [start, setStart] = React.useState( + searchParams.get('start') ?? 
getCurrentInvoiceYearStart() + ) + const [end, setEnd] = React.useState( + searchParams.get('end') ?? getCurrentInvoiceMonth() + ) + + // Data loading + const [isLoading, setIsLoading] = React.useState(true) + const [error, setError] = React.useState() + const [message, setMessage] = React.useState() + const [months, setMonths] = React.useState([]) + const [data, setData] = React.useState([]) + + // use navigate and update url params + const location = useLocation() + const navigate = useNavigate() + + const updateNav = (st: string, ed: string) => { + const url = generateUrl(location, { + start: st, + end: ed, + }) + navigate(url) + } + + const changeDate = (name: string, value: string) => { + let start_update = start + let end_update = end + if (name === 'start') start_update = value + if (name === 'end') end_update = value + setStart(start_update) + setEnd(end_update) + updateNav(start_update, end_update) + } + + const convertInvoiceMonth = (invoiceMonth: string, start: Boolean) => { + const year = invoiceMonth.substring(0, 4) + const month = invoiceMonth.substring(4, 6) + if (start) return `${year}-${month}-01` + // get last day of month + const lastDay = new Date(parseInt(year), parseInt(month), 0).getDate() + return `${year}-${month}-${lastDay}` + } + + const convertCostCategory = (costCategory: string) => { + if (costCategory === 'Cloud Storage') { + return 'Storage Cost' + } + return 'Compute Cost' + } + + const getData = (query: BillingTotalCostQueryModel) => { + setIsLoading(true) + setError(undefined) + setMessage(undefined) + new BillingApi() + .getTotalCost(query) + .then((response) => { + setIsLoading(false) + + // calc totals per topic, month and category + const recTotals: { [key: string]: { [key: string]: number } } = {} + const recMonths: string[] = [] + + response.data.forEach((item: BillingTotalCostRecord) => { + const { day, cost_category, topic, cost } = item + const ccat = convertCostCategory(cost_category) + if (recMonths.indexOf(day) === -1) 
{ + recMonths.push(day) + } + if (!recTotals[topic]) { + recTotals[topic] = {} + } + if (!recTotals[topic][day]) { + recTotals[topic][day] = {} + } + if (!recTotals[topic][day][ccat]) { + recTotals[topic][day][ccat] = 0 + } + recTotals[topic][day][ccat] += cost + }) + + setMonths(recMonths) + setData(recTotals) + }) + .catch((er) => setError(er.message)) + } + + const messageComponent = () => { + if (message) { + return ( + setError(undefined)}> + {message} + + ) + } + if (error) { + return ( + setError(undefined)}> + {error} +
+ +
+ ) + } + if (isLoading) { + return ( +
+ +

+ This query takes a while... +

+
+ ) + } + return null + } + + const dataComponent = () => { + if (message || error || isLoading) { + return null + } + + if (!message && !error && !isLoading && (!data || data.length === 0)) { + return ( + + No Data + + ) + } + + return ( + <> + + + + + ) + } + + const onMonthStart = (event: any, data: any) => { + changeDate('start', data.value) + } + + const onMonthEnd = (event: any, data: any) => { + changeDate('end', data.value) + } + + React.useEffect(() => { + if (Boolean(start) && Boolean(end)) { + // valid selection, retrieve data + getData({ + fields: [BillingColumn.Topic, BillingColumn.CostCategory], + start_date: getAdjustedDay(convertInvoiceMonth(start, true), -2), + end_date: getAdjustedDay(convertInvoiceMonth(end, false), 3), + order_by: { day: false }, + source: BillingSource.Aggregate, + time_periods: 'invoice_month', + filters: { + invoice_month: generateInvoiceMonths(start, end), + }, + }) + } else { + // invalid selection, + setIsLoading(false) + setError(undefined) + + if (start === undefined || start === null || start === '') { + setMessage('Please select Start date') + } else if (end === undefined || end === null || end === '') { + setMessage('Please select End date') + } + } + }, [start, end]) + + return ( + <> + +

+ Cost Across Invoice Months (Topic only) +

+ + + + + + + + + + +
+ + {messageComponent()} + + {dataComponent()} + + ) +} + +export default BillingCostByTime diff --git a/web/src/pages/billing/BillingInvoiceMonthCost.tsx b/web/src/pages/billing/BillingInvoiceMonthCost.tsx index e06324670..ee66d77d5 100644 --- a/web/src/pages/billing/BillingInvoiceMonthCost.tsx +++ b/web/src/pages/billing/BillingInvoiceMonthCost.tsx @@ -183,7 +183,7 @@ const BillingCurrentCost = () => { return ( <> -

Billing By Invoice Month

+

Cost By Invoice Month

diff --git a/web/src/pages/billing/components/BillingCostByMonthTable.tsx b/web/src/pages/billing/components/BillingCostByMonthTable.tsx new file mode 100644 index 000000000..008906b47 --- /dev/null +++ b/web/src/pages/billing/components/BillingCostByMonthTable.tsx @@ -0,0 +1,92 @@ +import { Header, Table as SUITable } from 'semantic-ui-react' +import Table from '../../../shared/components/Table' +import React from 'react' +import formatMoney from '../../../shared/utilities/formatMoney' +import LoadingDucks from '../../../shared/components/LoadingDucks/LoadingDucks' + +const date2Month = (dt: string): string => { + if (dt === undefined || dt === null) { + return '' + } + const date = new Date(dt) + return `${date.getFullYear()}${(date.getMonth() + 1).toString().padStart(2, '0')}` +} + +interface IBillingCostByMonthTableProps { + start: string + end: string + isLoading: boolean + data: any + months: string[] +} + +const BillingCostByMonthTable: React.FC = ({ + start, + end, + isLoading, + data, + months, +}) => { + if (isLoading) { + return ( +
+ +
+ ) + } + const compTypes = ['Compute Cost', 'Storage Cost'] + + const dataToBody = (data: any) => { + const sortedKeys = Object.keys(data).sort() + return sortedKeys.map((key) => ( + <> + {compTypes.map((compType, index) => ( + + + {index === 0 && {key}} + + {compType} + {months.map((month) => ( + + {data[key] && data[key][month] && data[key][month][compType] + ? formatMoney(data[key][month][compType]) + : null} + + ))} + + ))} + + )) + } + + return ( + <> +
+ SUM of Cost (AUD) By Topic from {start} to {end} +
+ + + + + + + Invoice Month + + + + Topic + Compute Type + {months.map((month) => ( + + {date2Month(month)} + + ))} + + + {dataToBody(data)} +
+ + ) +} + +export default BillingCostByMonthTable diff --git a/web/src/pages/billing/index.ts b/web/src/pages/billing/index.ts index 037d0b871..0e6dc599e 100644 --- a/web/src/pages/billing/index.ts +++ b/web/src/pages/billing/index.ts @@ -4,3 +4,4 @@ export { default as BillingCostByTime } from "./BillingCostByTime"; export { default as BillingCostByAnalysis } from "./BillingCostByAnalysis"; export { default as BillingCostByCategory } from "./BillingCostByCategory"; export { default as BillingInvoiceMonthCost } from "./BillingInvoiceMonthCost"; +export { default as BillingCostByMonth } from "./BillingCostByMonth"; diff --git a/web/src/shared/components/Header/NavBar.tsx b/web/src/shared/components/Header/NavBar.tsx index 8ee0f52b4..8a0efabd1 100644 --- a/web/src/shared/components/Header/NavBar.tsx +++ b/web/src/shared/components/Header/NavBar.tsx @@ -33,10 +33,15 @@ const billingPages = { icon: , }, { - title: 'Invoice Month Cost', + title: 'Cost By Invoice Month', url: '/billing/invoiceMonthCost', icon: , }, + { + title: 'Cost Across Invoice Months (Topics only)', + url: '/billing/costByMonth', + icon: , + }, { title: 'Cost By Time', url: '/billing/costByTime', diff --git a/web/src/shared/utilities/formatDates.ts b/web/src/shared/utilities/formatDates.ts new file mode 100644 index 000000000..f1ade6988 --- /dev/null +++ b/web/src/shared/utilities/formatDates.ts @@ -0,0 +1,54 @@ +const getAdjustedDay = (value: string, days: number): string => { + const date = new Date(value) + date.setDate(date.getDate() + days) + return [ + date.getFullYear(), + (date.getMonth() + 1).toString().padStart(2, '0'), + date.getDate().toString().padStart(2, '0') + ].join('-') +} + +const getCurrentInvoiceMonth = () => { + // get current month and year in the format YYYYMM + const date = new Date() + return [ + date.getFullYear(), + (date.getMonth() + 1).toString().padStart(2, '0') + ].join('') +} + +const getCurrentInvoiceYearStart = () => { + const date = new Date() + return [ + 
date.getFullYear(), + '01' + ].join('') +} + +const generateInvoiceMonths = (start: string, end: string): string[] => { + + const invoiceMonths = [] + const yearStart = start.substring(0, 4) + const yearEnd = end.substring(0, 4) + + const mthStart = start.substring(4, 6) + const mthEnd = end.substring(4, 6) + + const dateStart = new Date(yearStart + '-' + mthStart + '-01') + const dateEnd = new Date(yearEnd + '-' + mthEnd + '-01') + + for (let i = yearStart; i <= yearEnd; i++) { + const startMonth = i === yearStart ? dateStart.getMonth() : 0 + const endMonth = i === yearEnd ? dateEnd.getMonth() : 11 + for (let j = startMonth; j <= endMonth; j++) { + const month = j + 1 + const monthString = month.toString().padStart(2, '0') + const yearString = i.toString() + const dateString = `${yearString}${monthString}` + invoiceMonths.push(dateString) + } + } + return invoiceMonths; +} + +export {getAdjustedDay, generateInvoiceMonths, getCurrentInvoiceMonth, getCurrentInvoiceYearStart}